Why is this code failing? (generating all the indices of an array)

I don’t understand why this code fails or how to fix it. What I want to do is to generate all of the tuples of indices of an array of shape (s1, s2, ..., sn), where n is not known in advance.
Example:

>>> all_indices(shape=(1,2,3))
[(0,0,0), (0,0,1), (0,0,2), (0,1,0), (0,1,1), (0,1,2)]
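For reference, the ordering above is the Cartesian product of range(s_i) over every axis, with the last axis varying fastest (NumPy's C order). A minimal plain-Python sketch, without Numba, is below; the name all_indices_reference is made up for illustration.

```python
from itertools import product

def all_indices_reference(shape):
    # Cartesian product of range(s) for each axis; the last axis varies fastest
    return list(product(*(range(s) for s in shape)))

print(all_indices_reference((1, 2, 3)))
# [(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]
```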

This is my solution:

import numpy as np

def all_broadcasted_arange(shape):
    'for s_i in shape, return arange(s_i) broadcasted to i-th axis'
    for i,s in enumerate(shape):
        r = np.arange(s, dtype=np.int32)
        for k in range(i):
            r = np.expand_dims(r, axis=0)
        for k in range(len(shape)-i-1):
            r = np.expand_dims(r, axis=-1)
        yield r

def all_indices(shape: list):
    'multiply ones((s_1...s_n)) by arange(s_i) on the i-th axis, then stack and reshape.'
    coords = []
    for r in all_broadcasted_arange(shape):
        coords.append(np.ones(shape, dtype=np.int32)*r)
    return np.reshape(np.stack(coords, axis=-1), (-1, len(shape)))

When I @njit the first function, I get the error

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Cannot unify array(int32, 1d, C) and array(int32, 2d, C) for 'r.3', defined at ...

I think this means Numba is unhappy with r being redefined? How can I get around it? The problem is that np.expand_dims in Numba doesn’t support a list/tuple axis (so I have to use loops?), and np.meshgrid, which would essentially output ones*r on each axis, is not supported.
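For comparison, in plain NumPy (outside Numba) the ones*arange, stack and reshape construction in all_indices is what np.indices computes. This is only a sketch for checking results; I don't believe np.indices is available in nopython mode, so it is not meant to be decorated with @njit. The name all_indices_np is hypothetical.

```python
import numpy as np

def all_indices_np(shape):
    # np.indices returns shape (n, s1, ..., sn); slice k holds the k-th coordinate
    grid = np.indices(shape)
    # Flatten the grid axes, then put the coordinate axis last: (s1*...*sn, n)
    return grid.reshape(len(shape), -1).T

print(all_indices_np((1, 2, 3)))
```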

Hi @ziofil,

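The reason the code won’t compile at present is that it isn’t type stable: r starts as a 1D array and becomes a 2D (or higher) array inside the loop, but Numba needs every variable to have a single type, and an array’s number of dimensions is part of its type. It’s a bit like doing this: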

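from numba import njit
import numpy as np

@njit
def foo(n):
    r = np.arange(n) # r is 1D
    for k in range(1): # Numba doesn't know if this loop will execute
        r = np.expand_dims(r, axis=0) # r is 2D
    return r # What's r? 1D or 2D?

foo(10) # fails with a TypingError: Numba cannot unify the 1D and 2D array types of r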
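A minimal sketch of one way to make foo type stable: give the 2D result its own name so that no variable changes its number of dimensions. The name foo_stable and the fallback decorator (used only so the snippet also runs as plain Python where Numba isn't installed) are illustrative assumptions.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback for environments without Numba: run as plain Python
    def njit(f):
        return f

@njit
def foo_stable(n):
    r = np.arange(n)                 # r is always 1D
    r2d = np.expand_dims(r, axis=0)  # the 2D array gets its own variable
    return r2d                       # always 2D, so the return type is fixed

print(foo_stable(10).shape)  # (1, 10)
```

The same constraint is why all_broadcasted_arange can’t build an array whose number of dimensions depends on a loop over shape: each yielded r would need a different type.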

I think this will do what you are looking for:

from numba import njit
import numpy as np

@njit
def all_indices(shape):
    return [x for x in np.ndindex(shape)]

print(all_indices(shape=(1, 2, 3)))

which gives:

$ python <elided>.py 
[(0, 0, 0), (0, 0, 1), (0, 0, 2), (0, 1, 0), (0, 1, 1), (0, 1, 2)]
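If you’d rather have a 2D NumPy array, as your original all_indices returned, one option is to preallocate a (s1*...*sn, n) array and fill it from np.ndindex. Its type is always a 2D integer array, so it stays type stable. A sketch, assuming shape is passed as a tuple; the name all_indices_array and the plain-Python fallback for njit are illustrative.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback for environments without Numba: run as plain Python
    def njit(f):
        return f

@njit
def all_indices_array(shape):
    n = len(shape)
    total = 1
    for s in shape:
        total *= s
    # Always 2D, so the result's type is known at compile time
    out = np.empty((total, n), dtype=np.int64)
    row = 0
    for idx in np.ndindex(shape):
        for ax in range(n):
            out[row, ax] = idx[ax]
        row += 1
    return out

print(all_indices_array((1, 2, 3)))
```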

Hope this helps?
