I have not tried this with a compiled C function, but I suspect that the following would work:

    from numba.extending import intrinsic
    from numba.core import cgutils
    from numba.types import i8
    from numba import types, cfunc, njit

    @intrinsic
    def _func_from_address(typingctx, func_type_ref, addr):
        '''Recovers a function from its signature and address.'''
        # func_type_ref is a typeref; instance_type is the FunctionType itself
        func_type = func_type_ref.instance_type

        def codegen(context, builder, sig, args):
            _, addr = args
            # Build the struct Numba uses to represent a first-class function value
            sfunc = cgutils.create_struct_proxy(func_type)(context, builder)
            # Convert the integer address to a void* and store it as the function address
            llty = context.get_value_type(types.voidptr)
            addr_ptr = builder.inttoptr(addr, llty)
            sfunc.addr = addr_ptr
            return sfunc._getvalue()

        sig = func_type(func_type_ref, addr)
        return sig, codegen

    my_c_func_type = types.FunctionType(i8(i8, i8))

    @njit(i8(i8, i8, i8), cache=True)
    def njit_forward(a, b, func_addr):
        # Rebuild a callable of type i8(i8, i8) from the raw address, then call it
        f = _func_from_address(my_c_func_type, func_addr)
        return f(a, b)

    @cfunc(i8(i8, i8, i8), nopython=True, cache=True)
    def forward(a, b, func_addr):
        return njit_forward(a, b, func_addr)

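The idea behind `_func_from_address` (rebuilding a callable from a known signature plus a raw integer address) is not specific to Numba. As a minimal sketch using only the standard-library `ctypes` module (a swapped-in technique, not what the Numba code does internally), a Python function wrapped as a C callback stands in for the C function here; `BinOp`, `py_mul` and `call_at_address` are hypothetical names for illustration:

```python
import ctypes

# C signature int64_t f(int64_t, int64_t), the ctypes analogue of i8(i8, i8)
BinOp = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64, ctypes.c_int64)


def py_mul(a, b):
    return a * b


# Stand-in for a C function: a C-callable wrapper around py_mul.
# Keep this object alive, otherwise the address below becomes dangling.
c_mul = BinOp(py_mul)

# The function's raw address as a plain Python int (like func_addr above)
addr = ctypes.cast(c_mul, ctypes.c_void_p).value


def call_at_address(a, b, func_addr):
    # Instantiating a CFUNCTYPE prototype with an integer gives a
    # foreign function at that address with the prototype's signature
    f = BinOp(func_addr)
    return f(a, b)


print(call_at_address(4, 5, addr))  # 20
```

As with the Numba version, nothing checks that the address really points to a function with that signature; a mismatch is undefined behavior.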
It seems that calling a predefined njit function (`njit_forward()`) inside the `@cfunc` implementation of `forward()` is necessary, because `@cfunc` doesn't seem to work if you call an intrinsic directly inside it. I find this odd, and I suspect it is a bug in Numba. Perhaps you can fiddle with this to make something more concise.
Below is an example of how you would use this with the address of a normal jitted function. I suspect any function defined in C would also work, as long as you stick to numerical types. Pointer types would probably work too: if everything is compiled for 64-bit addresses, I suspect a C function wouldn't complain if you passed it an `i8` or `u8` in place of a pointer. Someone more knowledgeable may have a good reason why this is a bad idea.
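
    from numba.experimental.function_type import _get_wrapper_address

    @njit(cache=True)
    def foo(a, b):
        return a * b

    # Internal (underscore-prefixed) Numba helper: address of a wrapper
    # for foo compiled with the signature i8(i8, i8)
    jit_addr = _get_wrapper_address(foo, i8(i8, i8))
    print(forward(4, 5, jit_addr))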
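On the pointer question, here is a small `ctypes` sketch of what passing an `i8` in place of a pointer amounts to. It assumes a 64-bit platform whose calling convention passes a pointer argument and a 64-bit integer argument the same way (the common x86-64 and ARM64 conventions do), but strictly speaking, calling a function through a mismatched prototype is undefined behavior in C. `PtrSig`, `IntSig` and `py_sum_n` are hypothetical names:

```python
import ctypes

# Assumption: 64-bit platform where a pointer and an int64 have the same size
assert ctypes.sizeof(ctypes.c_void_p) == ctypes.sizeof(ctypes.c_int64)

# The "real" signature: int64_t sum_n(const int64_t *data, int64_t n)
PtrSig = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_void_p, ctypes.c_int64)


def py_sum_n(data_addr, n):
    # data_addr arrives as a plain int; view it as an int64 array
    data = ctypes.cast(data_addr, ctypes.POINTER(ctypes.c_int64))
    return sum(data[i] for i in range(n))


c_sum_n = PtrSig(py_sum_n)  # keep alive while its address is in use
addr = ctypes.cast(c_sum_n, ctypes.c_void_p).value

# Call the same function through an i8(i8, i8)-style prototype,
# passing the array's address as a 64-bit integer instead of a pointer
IntSig = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64, ctypes.c_int64)
f = IntSig(addr)

buf = (ctypes.c_int64 * 3)(10, 20, 30)
print(f(ctypes.addressof(buf), 3))  # 60
```

This only shows that the bits survive the trip on such a platform; it does not make the mismatch portable.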