Handling Functions with Dynamic Array Dimensions

Hey Numba Ninjas,

I have a question about how Numba handles functions that return arrays of different dimensions depending on runtime conditions.
Can Numba efficiently compile and optimize functions whose output dimensionality depends on their arguments (i.e., returning 1D, 2D…ND arrays)?

Here is a simple example:

import numpy as np
from numba import njit


@njit(['int64[:](boolean)', 'int64[:,:](boolean)'])
def func(is_1d: bool) -> np.ndarray[np.int64]:
    if is_1d:
        return np.array([1,2,3])
    return np.array([[1,2,3], [1,2,3]])

print(func(is_1d=True))
print(func(is_1d=False))
# TypingError at compile time: No conversion from array(int64, 2d, C) to array(int64, 1d, A)


# Workaround passing array in desired shape

@njit(['int64[:](int64[:], boolean)', 'int64[:,:](int64[:,:], boolean)'])
def func_v2(dummy: np.ndarray[int], is_1d: bool) -> np.ndarray[np.int64]:
    out = np.empty_like(dummy)
    if is_1d:
        out[:] = np.array([1,2,3])
    else:
        out[:] = np.array([[1,2,3], [1,2,3]])
    return out

print(func_v2(dummy=np.empty(shape=3, dtype=np.int64), is_1d=True))
# [1 2 3]
print(func_v2(dummy=np.empty(shape=(2,3), dtype=np.int64), is_1d=False))
# [[1 2 3]
#  [1 2 3]]

@Oyibo

No, this is not possible because the number of array dimensions is part of the type and Numba must know all types at compile time. But this also means that if you can somehow infer the number of dimensions from the argument types, you can make it work. See for example this case where we derive the dimension from the input array:

import numpy
import numba

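def foo(x):
    # Pure-Python stub; the Numba implementation is registered via overload below.
    pass

@numba.extending.overload(foo)
def foo_impl(x):
    # This runs at compile time: x is a Numba type, so x.ndim is a constant here.
    shape = (2,) * x.ndim

    def impl(x):
        return numpy.ones(shape)

    return impl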

@numba.njit
def example(a):
    return foo(a)
        
a = numpy.empty((2, 2, 2))
print(example(a))

Or this case where the input variable is a literal:

import numpy
import numba

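def foo(x):
    pass

@numba.extending.overload(foo)
def foo_impl(is_1d):
    # Only a compile-time constant can select the return type;
    # returning None rejects non-literal booleans.
    if not isinstance(is_1d, numba.types.BooleanLiteral):
        return None
    is_1d_literal_value = is_1d.literal_value

    def impl(is_1d):
        if is_1d_literal_value:
            return numpy.ones(2)
        return numpy.ones((2, 2))

    return impl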

@numba.njit
def example1():
    return foo(True)

@numba.njit
def example2():
    return foo(False)

print(example1())
print(example2())

Hey @sschaer,

I just wanted to say thanks for your creative idea. Your workaround using @numba.extending.overload and @numba.njit could be a way to handle dynamic output dimensions. Your input is greatly appreciated!
For my taste, passing in the output array and modifying its contents seems more straightforward, as long as Numba does not support other techniques.