Hi! I’m working on Numba-accelerated simulations and have run into a situation where the user generates an arbitrary function that is passed as an argument to a jitted function. The jitted function should then be able to pass the user's function both to other jitted code and to code running within numba.objmode. In my case the non-jitted code consists of various integrators such as scipy.integrate.odeint and scipy.integrate.solve_ivp. Since numba.objmode does not yet support function arguments, I have worked around this by passing both the function and a pointer to it to the jitted function. The jitted code calls the function directly, and the pointer is passed into numba.objmode. There, the pointer is either converted back to a function using ctypes.CFUNCTYPE, or used as a key to fetch the function's implementation from a Python dictionary. How can I do this so that the function only needs to be passed as an argument once? As far as I’m aware, any of the following would solve the problem:
- Being able to convert a function to a pointer within jitted code, as requested in Numba feature request #7974.
- Being able to convert a pointer to a function within jitted code with ctypes.CFUNCTYPE or something similar.
- Being able to pass function arguments directly into numba.objmode, which Numba does not yet support.
- Jitted versions of scipy.integrate.odeint and scipy.integrate.solve_ivp. This would be a great feature for numba-scipy.
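For context on the two workarounds described above, here is a minimal, stdlib-only sketch of how a raw integer address can be turned back into something callable. It uses a plain ctypes callback instead of a Numba cfunc so that it runs without Numba; in the real setup the address would come from the cfunc's `.address` attribute. The names `SIG`, `add_one` and `registry` are hypothetical and chosen for illustration only.

```python
import ctypes

# Hypothetical signature of the user's function: int64 -> int64,
# matching the cfunc signature used in the example further down.
SIG = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64)


def add_one(x):
    return 1 + x


# Wrap the Python function as a C callback. The callback object must be
# kept alive for as long as its address is in use.
callback = SIG(add_one)
address = ctypes.cast(callback, ctypes.c_void_p).value  # a plain Python int

# Workaround 1: rebuild a callable from the raw address with CFUNCTYPE.
from_pointer = SIG(address)

# Workaround 2: use the address as a key into a dictionary of implementations.
registry = {address: add_one}
from_registry = registry[address]

print(from_pointer(1), from_registry(1))  # 2 2
```

Either way the address is just an integer, which is why it can cross the numba.objmode boundary while a first-class function argument currently cannot.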
Minimal example:
import ctypes

import numba


def gen_func_njit(a):
    """Generates a user-defined function with the given parameters.
    The user should be able to replace this with an arbitrary function.
    """
    @numba.njit
    def func(x):
        return a + x
    return func


def gen_func_cfunc(a):
    @numba.cfunc(numba.types.int64(numba.types.int64))
    def func(x):
        return a + x
    return func


@numba.njit
def jitted(func):
    """Logic that uses the user-defined function"""
    return func(1)


def not_jitted(func):
    """This cannot be jitted, as this could be e.g. scipy.integrate.odeint"""
    return func(1)


@numba.njit
def wrapper(func):
    """This fails with:
    Failed in nopython mode pipeline (step: nopython frontend)
    Does not support function type inputs into with-context for arg 1
    """
    ret1 = jitted(func)
    with numba.objmode(ret2=numba.types.int64):
        ret2 = not_jitted(func)
    return ret1 + ret2


@numba.njit
def wrapper_hack(func, func_ptr):
    """How to get this to work so that the custom function needs to be given only once?"""
    ret1 = jitted(func)
    with numba.objmode(ret2=numba.types.int64):
        func2 = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64)(func_ptr)
        ret2 = not_jitted(func2)
    return ret1 + ret2


def main():
    func_cfunc = gen_func_cfunc(a=1)
    print("This works")
    wrapper_hack(func_cfunc, func_cfunc.address)
    print("But this does not")
    func_numba = gen_func_njit(a=1)
    wrapper(func_numba)


if __name__ == "__main__":
    main()