I have a question about how Numba handles functions that return arrays of different dimensions depending on runtime conditions.
Can Numba efficiently compile and optimize a function whose output dimension depends on its arguments (i.e., one that returns a 1D, 2D, …, ND array)?
Here is a simple example:
```python
import numpy as np
from numba import njit

@njit(['int64[:](boolean)', 'int64[:,:](boolean)'])
def func(is_1d: bool) -> np.ndarray[np.int64]:
    if is_1d:
        return np.array([1,2,3])
    return np.array([[1,2,3], [1,2,3]])

print(func(is_1d=True))
print(func(is_1d=False))
# TypingError: No conversion from array(int64, 2d, C) to array(int64, 1d, A)
```
```python
# Workaround: pass an array of the desired shape
@njit(['int64[:](int64[:], boolean)', 'int64[:,:](int64[:,:], boolean)'])
def func_v2(dummy: np.ndarray[int], is_1d: bool) -> np.ndarray[np.int64]:
    out = np.empty_like(dummy)
    if is_1d:
        out[:] = np.array([1,2,3])
    else:
        out[:] = np.array([[1,2,3], [1,2,3]])
    return out

print(func_v2(dummy=np.empty(shape=3, dtype=np.int64), is_1d=True))
# [1 2 3]
print(func_v2(dummy=np.empty(shape=(2,3), dtype=np.int64), is_1d=False))
# [[1 2 3]
#  [1 2 3]]
```
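No, this is not possible, because the number of array dimensions is part of the array's type and Numba must know all types at compile time. A runtime boolean such as is_1d cannot change the return type: with explicit signatures, Numba types the whole function body once per signature, so under 'int64[:](boolean)' the 2D return in the second branch still has to be converted to a 1D array, which is exactly the TypingError above. This also means, however, that if you can infer the number of dimensions from the argument types, you can make it work, for example by deriving the dimension from the type of an input array.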
I just wanted to say thanks for your creative idea. Your workaround using @numba.extending.overload and @numba.njit is one way to handle dynamic output dimensions. Your input is greatly appreciated!
To my personal taste, passing the output array and modifying its contents in place seems more straightforward, as long as Numba does not offer a more direct way to do this.