Why is this code failing? (generating all the indices of an array)

Hi @ziofil,

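The reason the code won’t compile at present is that it isn’t type stable: a variable ends up with a different type (here, a different number of array dimensions) depending on which path is taken through the function, and Numba needs a single type for each variable. It’s a bit like doing this: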

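from numba import njit
import numpy as np

@njit
def foo(n):
    r = np.arange(n) # r is 1D
    for k in range(1): # Numba doesn't know if this loop will execute
        r = np.expand_dims(r, axis=0) # r is 2D
    return r # What's r? 1D or 2D?

foo(10) # fails with a TypingError: Numba can't unify the 1D and 2D array types of r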

I think this will do what you are looking for:

from numba import njit
import numpy as np

@njit
def all_indices(shape):
    return [x for x in np.ndindex(shape)]

print(all_indices(shape=(1, 2, 3)))

which gives:

$ python <elided>.py 
[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]

Hope this helps?
