How to pass a jitted function to both jitted code and numba.objmode from within a jitted function?

Hi! I’m working on Numba-accelerated simulations and have run into a situation where the user supplies an arbitrary function as an argument to a jitted function. The jitted function must then be able to pass that user-supplied function on to both other jitted code and code inside a numba.objmode block. In my case the non-jitted code consists of integrators such as scipy.integrate.odeint and scipy.integrate.solve_ivp.

Since numba.objmode does not yet support function arguments, my current workaround is to give the jitted function both the function itself and a pointer to it. The jitted code calls the function directly, while the pointer is passed into the numba.objmode block, where it is either converted back into a callable with ctypes.CFUNCTYPE or used as a key to look up the function's implementation in a Python dictionary.

How can I do this so that passing the function once as an argument is sufficient? As far as I’m aware, any one of the following would solve the problem:

  • Being able to convert a function to a pointer within jitted code, as requested in Numba feature request #7974.
  • Being able to convert a pointer to a function within jitted code with ctypes.CFUNCTYPE or something similar.
  • Being able to pass function arguments directly to numba.objmode, which is not yet supported by Numba.
  • Jitting scipy.integrate.odeint and scipy.integrate.solve_ivp. This would be a great feature for numba-scipy.
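
For reference, the pointer round trip that the current workaround relies on can be sketched with ctypes alone. This is only a minimal, Numba-free illustration: add_one and SIGNATURE are names made up for this sketch, and a plain ctypes callback stands in for the address that a numba.cfunc object exposes.

```python
import ctypes

# C signature shared by both directions: int64 f(int64).
# SIGNATURE is a name chosen for this sketch.
SIGNATURE = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64)


def add_one(x):
    """Hypothetical user function, standing in for the compiled one."""
    return x + 1


# Function -> pointer: wrap the Python function as a C callback and read
# its raw address. The callback object must stay referenced for as long
# as the address is in use.
callback = SIGNATURE(add_one)
address = ctypes.cast(callback, ctypes.c_void_p).value

# Pointer -> function: rebuild a callable from the plain integer address.
# This is the step performed inside the numba.objmode block.
rebuilt = SIGNATURE(address)
result = rebuilt(41)
```

Only the integer address has to cross into the object-mode block, which is why this works even though the function itself cannot be passed there.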

Minimal example:

import ctypes

import numba


def gen_func_njit(a):
    """Generates a user-defined function with the given parameters.
    The user should be able to replace this with an arbitrary function.
    """
    @numba.njit
    def func(x):
        return a + x
    return func


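def gen_func_cfunc(a):
    """Same as gen_func_njit, but compiled with an explicit C signature
    so that the result exposes a raw function pointer via .address.
    """
    @numba.cfunc(numba.types.int64(numba.types.int64))
    def func(x):
        return a + x
    return func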


@numba.njit
def jitted(func):
    """Logic that uses the user-defined function"""
    return func(1)


def not_jitted(func):
    """This cannot be jitted, as this could be e.g. scipy.integrate.odeint"""
    return func(1)


@numba.njit
def wrapper(func):
    """This fails with:
    Failed in nopython mode pipeline (step: nopython frontend)
    Does not support function type inputs into with-context for arg 1
    """
    ret1 = jitted(func)
    with numba.objmode(ret2=numba.types.int64):
        ret2 = not_jitted(func)
    return ret1 + ret2


@numba.njit
def wrapper_hack(func, func_ptr):
    """How to get this to work so that the custom function needs to be given only once?"""
    ret1 = jitted(func)
    with numba.objmode(ret2=numba.types.int64):
        # Rebuild a Python-callable function from the raw integer address
        func2 = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64)(func_ptr)
        ret2 = not_jitted(func2)
    return ret1 + ret2


def main():
    func_cfunc = gen_func_cfunc(a=1)
    print("This works")
    wrapper_hack(func_cfunc, func_cfunc.address)
    print("But this does not")
    func_numba = gen_func_njit(a=1)
    wrapper(func_numba)


if __name__ == "__main__":
    main()
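

The dictionary-based variant of the workaround mentioned above can be sketched roughly as follows. This is a Numba-free illustration under stated assumptions: REGISTRY, register and call_by_address are hypothetical names, call_by_address stands in for the body of the numba.objmode block, and with Numba the key would be the .address of the cfunc rather than the address of a ctypes callback.

```python
import ctypes

SIGNATURE = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64)

# Hypothetical registry: raw function address -> Python implementation.
REGISTRY = {}
# Keeps the ctypes callbacks referenced so their addresses stay valid.
_KEEP_ALIVE = []


def register(py_func):
    """Store py_func under the address of a C callback wrapping it."""
    callback = SIGNATURE(py_func)
    _KEEP_ALIVE.append(callback)
    address = ctypes.cast(callback, ctypes.c_void_p).value
    REGISTRY[address] = py_func
    return address


def not_jitted(func):
    """Same role as not_jitted in the minimal example."""
    return func(1)


def call_by_address(address):
    """Stand-in for the object-mode block: look the function up by pointer."""
    return not_jitted(REGISTRY[address])


def add_one(x):
    """Hypothetical user function."""
    return 1 + x


address = register(add_one)
result = call_by_address(address)
```

Compared with ctypes.CFUNCTYPE, the lookup hands the original Python object to the integrator, but the function must be registered before the jitted code runs, so it still has to be supplied twice in effect.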