Coercing literal types for dispatch

Hi all! In a project I maintain, I make extensive use of literal types to dispatch to specific implementations of certain functions. Unfortunately, many of these functions take a large number of arguments, and the documented approach to literal dispatch (using @overload) often requires repeating the function signature over and over. An example of how this may look, based on the example in the Notes on Literal Types section of the docs, is as follows:

from numba import njit, literally
from numba.core.types import Literal
from numba.extending import overload


@njit
def test_f(arg1, arg2, arg3, mode):
    return f(arg1, arg2, arg3, mode)


def f(arg1, arg2, arg3, mode):
    raise NotImplementedError


@overload(f)
def nb_f(arg1, arg2, arg3, mode):

    if isinstance(mode, Literal):
        if mode.literal_value == 0:
            def impl(arg1, arg2, arg3, mode):
                return arg1
        elif mode.literal_value == 1:
            def impl(arg1, arg2, arg3, mode):
                return arg2
        else:
            def impl(arg1, arg2, arg3, mode):
                return arg3
    else:
        return lambda arg1, arg2, arg3, mode: literally(mode)

    return impl


print(test_f(1, 2, 3, 3))

This example isn’t too problematic: there are only four arguments, and the function and variable names are short. In more practical applications this may not be the case; in particular, the lambda function can become quite cumbersome. To keep things a little tidier, I have been using the following (noting that coerce_literal will typically be imported from elsewhere):

import inspect
from numba import njit
from numba.extending import overload, SentryLiteralArgs


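def coerce_literal(func, literals):
    # Grab the typing-time arguments from the calling overload's frame.
    func_locals = inspect.currentframe().f_back.f_locals  # One frame up.
    arg_types = [func_locals[k] for k in inspect.signature(func).parameters]
    # Raises ForceLiteralArg if any argument named in literals is not a Literal.
    SentryLiteralArgs(literals).for_function(func).bind(*arg_types)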


@njit
def test_f(arg1, arg2, arg3, mode):
    return f(arg1, arg2, arg3, mode)


def f(arg1, arg2, arg3, mode):
    raise NotImplementedError


@overload(f)
def nb_f(arg1, arg2, arg3, mode):

    coerce_literal(nb_f, ["mode"])

    if mode.literal_value == 0:
        def impl(arg1, arg2, arg3, mode):
            return arg1
    elif mode.literal_value == 1:
        def impl(arg1, arg2, arg3, mode):
            return arg2
    else:
        def impl(arg1, arg2, arg3, mode):
            return arg3

    return impl


print(test_f(1, 2, 3, 1))
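
To make the frame lookup inside coerce_literal easier to reason about, here is a minimal sketch of just that step in plain Python, with no Numba involved. The names caller_arguments and example are hypothetical and only for illustration; the sketch assumes the helper is always called directly from the function whose arguments it should collect.

```python
import inspect


def caller_arguments(func):
    """Collect the caller's current value for each parameter of func."""
    # f_back is the frame of whoever called this helper (one frame up).
    caller_locals = inspect.currentframe().f_back.f_locals
    # Read the values in the order the parameters appear in func's signature.
    return [caller_locals[name] for name in inspect.signature(func).parameters]


def example(arg1, arg2, mode):
    # Inside an @overload function these values would be Numba types;
    # here they are ordinary Python objects.
    return caller_arguments(example)


result = example("a", "b", 3)
print(result)
```

Because the lookup is by name in the caller's locals, it relies on the helper being called from the overload body itself and not from a nested function or comprehension.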

My question is whether this is a sensible approach, or whether it may interact with compilation in a way I don’t understand. Thanks in advance!

Hi @JSKenyon

I don’t have an answer to your actual question, but I once faced a somewhat related problem in one of my projects. I needed fine-grained dispatching (often based on literal values) to minimize compilation times. As you said, this can result in very convoluted code. I ended up writing a custom overload decorator that allows you to do this in a more linear fashion. Your example would then look like this:

from numba import njit, literally
from numba import types

import rocket_fft.imputils as iu
import rocket_fft.typutils as tu

def is_literal(arg):
    return isinstance(arg, types.Literal)

@njit
def test_f(arg1, arg2, arg3, mode):
    return f(arg1, arg2, arg3, mode)

@iu.implements_jit(prefer_literal=True, strict=False)
def f(arg1, arg2, arg3, mode):
    ...

@f.impl(mode=tu.is_literal_integer(0))
def _(_1, _2, _3, mode):
    return _1 

@f.impl(mode=tu.is_literal_integer(1))
def _(_1, _2, _3, mode):
    return _2

@f.impl(mode=is_literal)
def _(_1, _2, _3, mode):
    return _3

@f.impl(iu.otherwise)
def _(_1, _2, _3, mode):
    return literally(mode)

test_f(1, 2, 3, 1)

A slightly more complex dispatch pattern can be found, for example, here.

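Also, note that numba.extending.overload itself lets you ask Numba to try literal types first (prefer_literal=True) and to relax the check that the implementation's signature matches the overload's (strict=False), which allows different argument names. The latter in particular may be of interest to you.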

So maybe writing a wrapper for overload is also an option for you.

Thanks! That does look interesting and I would also be very interested in reducing compile times. It isn’t a huge problem in normal use but it does become problematic for testing. I will definitely give this a shot.