Is there a way to use the multiple dispatch of the dispatcher object returned by `numba.jit` to dispatch to different jitted functions?
I’m faced with a situation where a function argument can be either a scalar or an array at runtime, and each case requires a slightly different function definition. I don’t want the user to have to keep track of which function to call; I’d like multiple dispatch to pick the right one.
The example below is a toy example, but the idea is that I always have a for-loop that iterates through the first argument, which is known to be a 1D array, while the other arguments can be either arrays of the same dimensionality or scalars.
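```python
from numbers import Number
from typing import Iterable

import numba
from numba import double

@numba.njit()
def func(always_arr: Iterable[Number], arr_or_scalar: Iterable[Number], out_arr: Iterable[Number]):
    for n in range(always_arr.shape[0]):
        out_arr[n] = always_arr[n] + arr_or_scalar[n]

# Hypothetical API (modeled on functools.singledispatch): a separate body for a scalar 2nd argument
@func.register([(double[:], double, double[:])])
def _(always_arr: Iterable[Number], arr_or_scalar: Number, out_arr: Iterable[Number]):
    for n in range(always_arr.shape[0]):
        out_arr[n] = always_arr[n] + arr_or_scalar
```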
I could use a custom outer dispatcher (e.g. the `multipledispatch` package), but since Numba’s dispatcher already does multiple dispatch, it would be ideal to reuse Numba’s implementation.
Thanks!!