Hi all! In a project which I maintain, I make extensive use of literal types to dispatch specific implementations of certain functions. Unfortunately, many of the functions take a large number of arguments, and the documented approach to literal dispatch (using @overload) often requires stamping out the function signature over and over again. An example of how this may look (based on the example in the Notes on Literal Types section of the docs) is as follows:
from numba import njit, literally
from numba.core.types import Literal
from numba.extending import overload


@njit
def test_f(arg1, arg2, arg3, mode):
    return f(arg1, arg2, arg3, mode)


def f(arg1, arg2, arg3, mode):
    raise NotImplementedError


@overload(f)
def nb_f(arg1, arg2, arg3, mode):
    if isinstance(mode, Literal):
        # Select the implementation at compile time from the literal value.
        if mode.literal_value == 0:
            def impl(arg1, arg2, arg3, mode):
                return arg1
        elif mode.literal_value == 1:
            def impl(arg1, arg2, arg3, mode):
                return arg2
        else:
            def impl(arg1, arg2, arg3, mode):
                return arg3
    else:
        # mode is not a literal yet: request that it be typed as one.
        return lambda arg1, arg2, arg3, mode: literally(mode)
    return impl


print(test_f(1, 2, 3, 3))
This example isn’t too problematic: there are only four arguments, and the function and variable names are short. In more practical applications this may not be the case; in particular, the lambda function may become quite cumbersome. In an effort to keep things a little tidier, I have been using the following (noting that coerce_literal will typically be imported from elsewhere):
import inspect

from numba import njit
from numba.extending import overload, SentryLiteralArgs


def coerce_literal(func, literals):
    # Read the argument types from the calling overload's local variables.
    func_locals = inspect.currentframe().f_back.f_locals  # One frame up.
    arg_types = [func_locals[k] for k in inspect.signature(func).parameters]
    # Request literal typing for the named arguments if they are not literals.
    SentryLiteralArgs(literals).for_function(func).bind(*arg_types)


@njit
def test_f(arg1, arg2, arg3, mode):
    return f(arg1, arg2, arg3, mode)


def f(arg1, arg2, arg3, mode):
    raise NotImplementedError


@overload(f)
def nb_f(arg1, arg2, arg3, mode):
    coerce_literal(nb_f, ["mode"])

    if mode.literal_value == 0:
        def impl(arg1, arg2, arg3, mode):
            return arg1
    elif mode.literal_value == 1:
        def impl(arg1, arg2, arg3, mode):
            return arg2
    else:
        def impl(arg1, arg2, arg3, mode):
            return arg3
    return impl


print(test_f(1, 2, 3, 1))
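To isolate the part of coerce_literal I am least sure about, here is a minimal sketch of just the frame-inspection step, with Numba removed. The names collect_caller_args and demo are hypothetical and exist only for this illustration; the sketch assumes the helper is called directly from the body of the function whose arguments it should read.

```python
import inspect


def collect_caller_args(func):
    """Return the caller's current values for func's parameters, in signature order."""
    caller_locals = inspect.currentframe().f_back.f_locals  # One frame up.
    return [caller_locals[name] for name in inspect.signature(func).parameters]


def demo(arg1, arg2, arg3, mode):
    # Hypothetical stand-in for an @overload body. The helper must be called
    # directly from here: a call made from a nested helper would read that
    # helper's locals instead of these arguments.
    return collect_caller_args(demo)


print(demo("a", "b", "c", 1))  # ['a', 'b', 'c', 1]
```

Two things follow from this pattern: the helper sees whatever the parameter names are bound to at the moment of the call (so rebinding an argument beforehand changes what is collected), and it depends on the extra call depth being exactly one frame.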
My question is whether this is a sensible approach, or whether it may interact with compilation in a way I don’t understand. Thanks in advance!