I don’t understand why this code fails or how to fix it. What I want to do is generate all index tuples of an array of shape (s1, s2, ..., sn), where n is not known in advance.
Example:
```python
>>> all_indices(shape=(1, 2, 3))
[(0,0,0), (0,0,1), (0,0,2), (0,1,0), (0,1,1), (0,1,2)]
```
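For reference, the target output is the Cartesian product of range(s_i) over all axes, with the last axis varying fastest (row-major order). The following minimal pure-Python sketch pins down that target using itertools.product. It is not Numba-compatible, and all_indices_reference is just a hypothetical name for illustration.

```python
from itertools import product


def all_indices_reference(shape):
    # Cartesian product of range(s_1), ..., range(s_n); the last axis varies fastest
    return list(product(*(range(s) for s in shape)))


print(all_indices_reference((1, 2, 3)))
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]
```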
This is my solution:
```python
import numpy as np

def all_broadcasted_arange(shape):
    'for s_i in shape, return arange(s_i) broadcasted to i-th axis'
    for i, s in enumerate(shape):
        r = np.arange(s, dtype=np.int32)
        # add i leading singleton axes...
        for k in range(i):
            r = np.expand_dims(r, axis=0)
        # ...and n-i-1 trailing ones, so r has shape (1, ..., s_i, ..., 1)
        for k in range(len(shape)-i-1):
            r = np.expand_dims(r, axis=-1)
        yield r

def all_indices(shape: list):
    'multiply ones((s_1...s_n)) by arange(s_i) on the i-th axis, then stack and reshape.'
    coords = []
    for r in all_broadcasted_arange(shape):
        # broadcasting fills a full-shape array with the i-th coordinate
        coords.append(np.ones(shape, dtype=np.int32)*r)
    # stack to shape (s_1, ..., s_n, n), then flatten to (s_1*...*s_n, n)
    return np.reshape(np.stack(coords, axis=-1), (-1, len(shape)))
```
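Outside Numba, the ones * r construction for every axis is exactly what np.meshgrid(..., indexing="ij") computes. In plain NumPy, the whole function therefore reduces to the short sketch below. all_indices_numpy is a hypothetical name, and this version does not help inside @njit.

```python
import numpy as np


def all_indices_numpy(shape):
    # One coordinate grid per axis, each of full shape `shape` ("ij" = matrix indexing)
    grids = np.meshgrid(*(np.arange(s) for s in shape), indexing="ij")
    # Stack along a new last axis, then flatten to (prod(shape), len(shape)) in C order
    return np.stack(grids, axis=-1).reshape(-1, len(shape))


print(all_indices_numpy((1, 2, 3)).tolist())
# [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [0, 1, 1], [0, 1, 2]]
```

np.indices(shape).reshape(len(shape), -1).T gives the same array.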
When I decorate the first function with @njit, I get the following error:
```
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Cannot unify array(int32, 1d, C) and array(int32, 2d, C) for 'r.3', defined at ...
```
I think this means Numba is unhappy that r is reassigned to arrays with a different number of dimensions. How can I get around this? The problem is that np.expand_dims doesn’t support a list/tuple axis (so I have to use loops?), and np.meshgrid, which would essentially output ones*r along each axis, is not supported.
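One way around the unification error is to never change an array's number of dimensions. You can allocate the (prod(shape), n) output once and fill it by decomposing each flat index into per-axis coordinates (mixed-radix counting, last axis fastest). This is a different technique from the broadcasting approach above, sketched under some assumptions. First, shape is passed as a homogeneous tuple of ints, so its length is fixed at compile time. Second, all_indices_flat is a hypothetical name. Third, the fallback decorator exists only so the sketch also runs as plain Python where Numba is not installed. The body uses only loops, integer arithmetic, tuple indexing and np.empty.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback (assumption): run as plain Python when Numba is not installed
    def njit(func):
        return func


@njit
def all_indices_flat(shape):
    n = len(shape)
    # Total number of index tuples = s_1 * s_2 * ... * s_n
    total = 1
    for s in shape:
        total *= s
    # Always 2-D, so every variable keeps a single type
    out = np.empty((total, n), dtype=np.int64)
    for flat in range(total):
        rem = flat
        # Peel off coordinates starting from the last (fastest-varying) axis
        for ax in range(n - 1, -1, -1):
            out[flat, ax] = rem % shape[ax]
            rem //= shape[ax]
    return out


print(all_indices_flat((1, 2, 3)).tolist())
# [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [0, 1, 1], [0, 1, 2]]
```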