Why is this code failing? (generating all the indices of an array)

I don’t understand why this code fails or how to fix it. I want to generate all tuples of indices of an array of shape `(s1, s2, ..., sn)`, where `n` is not known in advance.
Example:

```python
>>> all_indices(shape=(1, 2, 3))
[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]
```

This is my solution:

```python
import numpy as np

def all_broadcasted_arange(shape):
    'for s_i in shape, return arange(s_i) broadcasted to i-th axis'
    for i, s in enumerate(shape):
        r = np.arange(s, dtype=np.int32)
        for k in range(i):
            r = np.expand_dims(r, axis=0)
        for k in range(len(shape) - i - 1):
            r = np.expand_dims(r, axis=-1)
        yield r

def all_indices(shape: list):
    'multiply ones((s_1...s_n)) by arange(s_i) on the i-th axis, then stack and reshape.'
    coords = []
    for r in all_broadcasted_arange(shape):  # one broadcastable arange per axis
        coords.append(np.ones(shape, dtype=np.int32) * r)
    return np.reshape(np.stack(coords, axis=-1), (-1, len(shape)))
```

When I `@njit` the first function, I get the error

```
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Cannot unify array(int32, 1d, C) and array(int32, 2d, C) for 'r.3', defined at ...
```

which I think means that numba is unhappy with `r` being redefined? How can I get around it? The problem is that `np.expand_dims` doesn’t support a list/tuple `axis` (so I have to use loops?), and `np.meshgrid`, which would essentially output `ones*r` on each axis, is not supported.
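For comparison, in plain NumPy (not compiled by numba) the "arange on each axis, broadcast, then stack" idea fits in a single call to `np.indices`, which returns one index grid per axis. This is only a reference sketch: the helper name `all_indices_numpy` is hypothetical, and it makes no claim that `np.indices` is available in nopython mode.

```python
import numpy as np

def all_indices_numpy(shape):
    # np.indices returns an array of shape (len(shape), *shape);
    # component i holds the i-th index of every element of the grid.
    grid = np.indices(shape)
    # Flatten each component in C order, then transpose so each row is one index tuple.
    return grid.reshape(len(shape), -1).T

print([tuple(int(v) for v in row) for row in all_indices_numpy((1, 2, 3))])
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]
```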

Hi @ziofil,

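The reason the code won’t compile at present is that it isn’t type stable: inside the loops, `r` goes from a 1D array to a 2D array (and so on), but Numba needs each variable to have a single type, which is what "Cannot unify array(int32, 1d, C) and array(int32, 2d, C)" is complaining about. It’s a bit like doing this: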

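```python
from numba import njit
import numpy as np

@njit
def foo(n):
    r = np.arange(n)  # r is 1D
    for k in range(1):  # Numba doesn't know if this loop will execute
        r = np.expand_dims(r, axis=0)  # r is 2D
    return r  # What's r? 1D or 2D?

foo(10)  # fails to compile with a TypingError (cannot unify 1d and 2d arrays)
```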
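One way to stay type stable is to never change an array's number of dimensions: allocate the final `(N, n)` result once and fill it, converting each flat position into a multi-index by repeated division (last axis varying fastest, i.e. C order). This is a minimal sketch, assuming `shape` is a homogeneous tuple of ints; `all_indices_array` is a hypothetical name, and the `try`/`except` fallback is only there so the sketch also runs where numba isn't installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # assumption: fall back to plain Python if numba is missing
    def njit(func):
        return func

@njit
def all_indices_array(shape):
    # Hypothetical helper: returns an (N, ndim) int64 array with N = prod(shape).
    ndim = len(shape)
    total = 1
    for s in shape:
        total *= s
    # The output is always 2D, so its type never changes inside the loops.
    out = np.empty((total, ndim), dtype=np.int64)
    for flat in range(total):
        rem = flat
        # Peel indices off starting from the last (fastest-varying) axis.
        for axis in range(ndim - 1, -1, -1):
            out[flat, axis] = rem % shape[axis]
            rem //= shape[axis]
    return out

print(all_indices_array((1, 2, 3)).tolist())
# [[0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 1, 0], [0, 1, 1], [0, 1, 2]]
```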

I think this will do what you are looking for:

```python
from numba import njit
import numpy as np

@njit
def all_indices(shape):
    return [x for x in np.ndindex(shape)]

print(all_indices(shape=(1, 2, 3)))
```

which gives:

```
$ python <elided>.py
[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]
```
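If you want to check the result outside numba, the same C-order sequence (last axis varying fastest) can be produced in plain Python with `itertools.product`, using one `range` per axis. A small sketch; `all_indices_py` is a hypothetical helper name.

```python
from itertools import product

def all_indices_py(shape):
    # One range per axis; product varies the last axis fastest,
    # which matches the order np.ndindex iterates in.
    return list(product(*(range(s) for s in shape)))

print(all_indices_py((1, 2, 3)))
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]
```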

Hope this helps?
