Hi!
I have a hypercube of side 2 and n dimensions, i.e. an array with shape (2,2,...,2) (with n 2’s). I don’t know the value of n in advance. I need a function that splits the array in two along an arbitrary axis. In NumPy it would look like this:
def split_hypercube_v1(arr, axis):
    return np.take(arr, 0, axis), np.take(arr, 1, axis)
But Numba does not support the axis argument for np.take.
This does not work either:
def split_hypercube_v2(arr, axis):
    a, b = np.split(arr, 2, axis)
    return np.squeeze(a), np.squeeze(b)
because Numba does not support np.squeeze.
This does not work either:
def split_hypercube_v3(arr, axis):
    tmp = np.moveaxis(arr, axis, 0)
    return tmp[0], tmp[1]
because Numba does not support np.moveaxis. I thought of using np.transpose, but I can’t build the axes tuple when its length isn’t known in advance.
I feel kinda stuck, is there an obvious solution that I’m missing? Any help is greatly appreciated!
from numba import njit
from numba.cpython.unsafe.tuple import tuple_setitem, build_full_slice_tuple

@njit
def split_hypercube(arr, axis):
    ndim = arr.ndim
    slice_tup = build_full_slice_tuple(ndim)
    for i in range(ndim):
        if i == axis:
            tuple_setitem(slice_tup, i, slice(0, 1))
        else:
            tuple_setitem(slice_tup, i, slice(None))
    return arr[slice_tup]
Can you try something like this?
Hey, thanks for your answer. I tried it, but it doesn’t seem to work:
arr = np.random.random((2,2,2,2))
a,b = split_hypercube(arr, 1)
assert np.allclose(a, arr[:,0])
assert np.allclose(b, arr[:,1])
>>> AssertionError
My example won’t work as is; you’ll have to modify it. First off, it only returns one array, so you’ll have to add the code for the “b” part. Also, I think your test expects the output to have one dimension fewer than the input, which won’t happen here: the part selecting the 0 or 1 position is itself a slice, and slicing keeps that axis (with length 1) in the output. You’ll have to figure out how to get rid of that axis or deal with it as it is.
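For what it’s worth, one way to sidestep building an index tuple altogether is to go through a reshape. The following is only a minimal sketch, not a tested Numba recipe: the name split_hypercube_reshape is made up for illustration, and it assumes a non-negative axis, at least 2 dimensions, and all sides equal (true for the hypercube here). It views the array as (left, 2, right), picks the middle index, and reshapes back to arr.shape[1:], which equals the shape with one axis dropped only because every side is the same. As far as I know, np.ascontiguousarray, .copy(), .reshape() and constant tuple slicing are all available in nopython mode, but treat that as an assumption to check against your Numba version. The try/except around the import is only there so the sketch also runs as plain NumPy where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs as plain NumPy without Numba.
    def njit(func):
        return func


@njit
def split_hypercube_reshape(arr, axis):
    # Hypothetical helper; assumes 0 <= axis < arr.ndim, arr.ndim >= 2,
    # and that every side of arr has the same length.
    left = 1
    for i in range(axis):
        left *= arr.shape[i]          # product of the sides before `axis`
    mid = arr.shape[axis]             # 2 for the hypercube
    right = arr.size // (left * mid)  # product of the sides after `axis`

    # Collapse everything before and after `axis` into single dimensions.
    flat = np.ascontiguousarray(arr).reshape((left, mid, right))

    # Shape with one axis removed; valid only because all sides are equal.
    out_shape = arr.shape[1:]

    # Copy so each half is contiguous before reshaping it back.
    a = flat[:, 0, :].copy().reshape(out_shape)
    b = flat[:, 1, :].copy().reshape(out_shape)
    return a, b
```

Unlike np.take on a view, this returns copies, so writing into a or b will not modify arr. With arr = np.random.random((2,2,2,2)) and axis 1, the two results should match arr[:,0] and arr[:,1].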