Using numba's Multiple Dispatch to call different jit-ed functions?

Is there a way to take advantage of the multiple dispatch performed by the Dispatcher object that numba.jit returns, in order to dispatch to different jit-ed functions?

I’m faced with a situation where one of the function’s arguments can be either a scalar or an array at runtime, and each case requires a slightly different function definition. I don’t want the user to have to keep track of which function to call, so I’d like to use multiple dispatch to call the right one.

The example below is a toy, but the idea is that I always have a for-loop iterating over the first argument, which is known to be a 1D array, while the other arguments can be either arrays of the same dimensionality or scalars.

import numba
from numba import double
from numbers import Number
from typing import Iterable

@numba.njit()
def func(always_arr: Iterable[Number], arr_or_scalar: Iterable[Number], out_arr: Iterable[Number]):
    for n in range(always_arr.shape[0]):
        out_arr[n] = always_arr[n] + arr_or_scalar[n]

# Desired (hypothetical) API: Numba dispatchers don't actually expose .register
@func.register([(double[:], double, double[:])])
def _(always_arr: Iterable[Number], arr_or_scalar: Number, out_arr: Iterable[Number]):
    for n in range(always_arr.shape[0]):
        out_arr[n] = always_arr[n] + arr_or_scalar

I could write a custom outer dispatch layer (e.g. with the multipledispatch package), but since Numba is already doing multiple dispatch internally, it would be ideal to just reuse Numba's implementation. The kind of outer dispatch I'd rather avoid is sketched below.
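For concreteness, here is a minimal sketch of such a hand-rolled outer dispatch, routing in plain Python to two separately jitted kernels (the names are just illustrative):

import numpy as np
import numba

@numba.njit
def _func_arr(always_arr, arr, out_arr):
    # Element-wise add when the second argument is an array.
    for n in range(always_arr.shape[0]):
        out_arr[n] = always_arr[n] + arr[n]

@numba.njit
def _func_scalar(always_arr, scalar, out_arr):
    # Broadcast add when the second argument is a scalar.
    for n in range(always_arr.shape[0]):
        out_arr[n] = always_arr[n] + scalar

def func(always_arr, arr_or_scalar, out_arr):
    # Dispatch happens in plain Python, outside the jitted code.
    if isinstance(arr_or_scalar, np.ndarray):
        _func_arr(always_arr, arr_or_scalar, out_arr)
    else:
        _func_scalar(always_arr, arr_or_scalar, out_arr)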

Thanks!!

Hey @TK-21st,

I think this might be what you are looking for: Flexible specializations with @generated_jit — Numba documentation
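A minimal sketch (untested) of how your toy example could map onto @generated_jit: the decorated function is called at compile time with the Numba types of the arguments and returns the implementation to compile.

import numba
from numba import types

@numba.generated_jit(nopython=True)
def func(always_arr, arr_or_scalar, out_arr):
    # arr_or_scalar is a Numba *type* here, inspected at compile time.
    if isinstance(arr_or_scalar, types.Array):
        def impl(always_arr, arr_or_scalar, out_arr):
            # Array case: index into the second argument.
            for n in range(always_arr.shape[0]):
                out_arr[n] = always_arr[n] + arr_or_scalar[n]
    else:
        def impl(always_arr, arr_or_scalar, out_arr):
            # Scalar case: broadcast the second argument.
            for n in range(always_arr.shape[0]):
                out_arr[n] = always_arr[n] + arr_or_scalar
    return impl

Calls like func(a, b, out) and func(a, 2.0, out) should then each compile the matching specialization.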

Luk

Ah, yes indeed! Thanks @luk-f-a!

For anyone looking, support for this feature on the CUDA target is being tracked in Support generated_jit for cuda target · Issue #2754 · numba/numba · GitHub and in the PR generated_jit for CUDA kernels by sjperkins · Pull Request #3450 · numba/numba · GitHub.