How do I dynamically call a function?

Hello,

I’d like to dynamically call a function based on an integer (fct_selector here).
I have several Python modules, each with its own compute function.

FCT_MODULES = [
    a_first_module,
    a_second_module,
]

@njit
def dispatch(fct_selector: int) -> Any:
    return FCT_MODULES[fct_selector].compute()

How can I do this?

Thank you

You can store first-class jitted functions in a typed list and then index the list with an integer.

Thank you for your response.

I am trying this:
COMPUTE_DOMAINS_FUNCTIONS = List.empty_list(numba.typeof(compute_domains_affine_eq))
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_affine_eq)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_affine_geq)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_affine_leq)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_alldifferent)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_count_eq)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_dummy)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_exactly_eq)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_lexicographic_leq)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_max_eq)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_max_leq)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_min_eq)
COMPUTE_DOMAINS_FUNCTIONS.append(compute_domains_min_geq)

With, for example:
@njit("int64(int32[::1,:], int32[:])", cache=True)
def compute_domains_dummy(domains: NDArray, data: NDArray) -> int:
    """
    A propagator that does nothing.
    """
    return PROP_CONSISTENCY

Of course, all functions have the same signature.

But it does not work:
NumbaTypeSafetyWarning: unsafe cast from type(CPUDispatcher(<function compute_domains_affine_geq at 0x10f62bce0>)) to type(CPUDispatcher(<function compute_domains_affine_eq at 0x10f62bba0>)). Precision may be lost.
l.append(item)

I could not make it work with a global constant for the list.

But I could make it work with:
def init_compute_domains_functions() -> List[Callable]:
    functions = List.empty_list(types.int64(int32[::1, :], int32[:]).as_type())
    functions.append(compute_domains_affine_eq)
    functions.append(compute_domains_affine_geq)
    functions.append(compute_domains_affine_leq)
    functions.append(compute_domains_alldifferent)
    functions.append(compute_domains_count_eq)
    functions.append(compute_domains_dummy)
    functions.append(compute_domains_exactly_eq)
    functions.append(compute_domains_lexicographic_leq)
    functions.append(compute_domains_max_eq)
    functions.append(compute_domains_max_leq)
    functions.append(compute_domains_min_eq)
    functions.append(compute_domains_min_geq)
    return functions

I then pass this list to the jitted function that dispatches to and executes the compute_domains functions.

BUT it is significantly slower.
I profiled the code as pure Python (with NUMBA_DISABLE_JIT=1) and nothing changed.
Could it be that accessing the i-th element of the typed list is slow?

If the work being done in the function is small, the typed list overhead could be significant.

There are other techniques, like storing pointers to the functions in a numpy array or simply using an if-else construct, which may well be fastest.

Note that first-class functions return zero-initialized values if an exception is raised. That is, exceptions are not propagated.

Hello,

Thank you for your answer.
I was already using the if-else approach; I’ll give the numpy array of function pointers a try.

Cheers
Yan

Hello,

How can I call a jitted function through a pointer from inside another jitted function?

Thank you,
Yan

See this post

TL;DR: I usually use a custom intrinsic; I’m not sure whether the devs have added a standard way since that post.

Thank you very much, this is what I was looking for.

My case was slightly more complex, because I have an array of functions that I want to call based on a dynamic index, but I was able to make it work!

Cheers,
Yan

@DannyWeitekamp Unfortunately, your post #24 in “Any numba equivalent for casting a raw pointer to a StructRef, Dict, List etc?” does not seem to work anymore with Numba 0.61.0rc2. Any idea why?

@Yangeorget Perhaps you can provide the error you get when trying this example to help locate the source of the issue?

Hello,

To give some context, I am building a numpy array of jit-compiled function addresses.
I use these addresses to dynamically call the functions.

(venv-3.13) ➜ nucs git:(main) ✗ cat requirements.txt
numba==0.61.0rc2
numpy==2.1.3

(venv-3.13) ➜ nucs git:(main) ✗ cat nucs/numba_helper.py
from typing import List

from numba import types  # type: ignore
from numba.core import cgutils
from numba.experimental.function_type import _get_wrapper_address
from numba.extending import intrinsic


@intrinsic
def function_from_address(typingctx, func_type_ref: types.FunctionType, addr: int):  # type: ignore
    """
    Recovers a function from FunctionType and address.
    """
    func_type = func_type_ref.instance_type

    def codegen(context, builder, sig, args):  # type: ignore
        _, address = args
        sfunc = cgutils.create_struct_proxy(func_type)(context, builder)
        sfunc.addr = builder.inttoptr(address, context.get_value_type(types.voidptr))
        return sfunc._getvalue()

    return func_type(func_type_ref, addr), codegen


def build_function_address_list(functions, signature) -> List[int]:  # type: ignore
    return [_get_wrapper_address(function, signature) for function in functions]

When I run some tests, I get:

nucs/numba_helper.py:38: in build_function_address_list
return [_get_wrapper_address(function, signature) for function in functions]
venv-3.13/lib/python3.13/site-packages/numba/experimental/function_type.py:159: in _get_wrapper_address
cres = func.get_compile_result(sig)
venv-3.13/lib/python3.13/site-packages/numba/core/dispatcher.py:925: in get_compile_result
self.compile(atypes)
venv-3.13/lib/python3.13/site-packages/numba/core/dispatcher.py:904: in compile
cres = self._compiler.compile(args, return_type)
venv-3.13/lib/python3.13/site-packages/numba/core/dispatcher.py:80: in compile
status, retval = self._compile_cached(args, return_type)
venv-3.13/lib/python3.13/site-packages/numba/core/dispatcher.py:94: in _compile_cached
retval = self._compile_core(args, return_type)
venv-3.13/lib/python3.13/site-packages/numba/core/dispatcher.py:107: in _compile_core
cres = compiler.compile_extra(self.targetdescr.typing_context,
venv-3.13/lib/python3.13/site-packages/numba/core/compiler.py:739: in compile_extra
return pipeline.compile_extra(func)
venv-3.13/lib/python3.13/site-packages/numba/core/compiler.py:439: in compile_extra
return self._compile_bytecode()
venv-3.13/lib/python3.13/site-packages/numba/core/compiler.py:505: in _compile_bytecode
return self._compile_core()
venv-3.13/lib/python3.13/site-packages/numba/core/compiler.py:481: in _compile_core
raise e
venv-3.13/lib/python3.13/site-packages/numba/core/compiler.py:473: in _compile_core
pm.run(self.state)
venv-3.13/lib/python3.13/site-packages/numba/core/compiler_machinery.py:363: in run
raise e
venv-3.13/lib/python3.13/site-packages/numba/core/compiler_machinery.py:356: in run
self._runPass(idx, pass_inst, state)
venv-3.13/lib/python3.13/site-packages/numba/core/compiler_lock.py:35: in _acquire_compile_lock
return func(*args, **kwargs)
venv-3.13/lib/python3.13/site-packages/numba/core/compiler_machinery.py:311: in _runPass
mutated |= check(pss.run_pass, internal_state)
venv-3.13/lib/python3.13/site-packages/numba/core/compiler_machinery.py:272: in check
mangled = func(compiler_state)
venv-3.13/lib/python3.13/site-packages/numba/core/typed_passes.py:468: in run_pass
lower.lower()
venv-3.13/lib/python3.13/site-packages/numba/core/lowering.py:193: in lower
self.lower_normal_function(self.fndesc)
venv-3.13/lib/python3.13/site-packages/numba/core/lowering.py:232: in lower_normal_function
entry_block_tail = self.lower_function_body()
venv-3.13/lib/python3.13/site-packages/numba/core/lowering.py:262: in lower_function_body
self.lower_block(block)
venv-3.13/lib/python3.13/site-packages/numba/core/lowering.py:276: in lower_block
self.lower_inst(inst)
venv-3.13/lib/python3.13/site-packages/numba/core/lowering.py:462: in lower_inst
val = self.lower_assign(ty, inst)
venv-3.13/lib/python3.13/site-packages/numba/core/lowering.py:674: in lower_assign
return self.lower_expr(ty, value)
venv-3.13/lib/python3.13/site-packages/numba/core/lowering.py:1268: in lower_expr
res = self.lower_call(resty, expr)
venv-3.13/lib/python3.13/site-packages/numba/core/lowering.py:939: in lower_call
res = self._lower_call_normal(fnty, expr, signature)
venv-3.13/lib/python3.13/site-packages/numba/core/lowering.py:1239: in _lower_call_normal
res = impl(self.builder, argvals, self.loc)
venv-3.13/lib/python3.13/site-packages/numba/core/base.py:1190: in call
res = self._imp(self._context, builder, self._sig, args, loc=loc)
venv-3.13/lib/python3.13/site-packages/numba/core/base.py:1220: in wrapper
return fn(*args, **kwargs)
nucs/numba_helper.py:31: in codegen
sfunc.addr = builder.inttoptr(address, context.get_value_type(types.voidptr))
venv-3.13/lib/python3.13/site-packages/numba/core/cgutils.py:164: in setattr
self[self._datamodel.get_field_position(field)] = value


self = <numba.experimental.function_type.FunctionModel object at 0x104e6ee40>, field = 'addr'

def get_field_position(self, field):
    try:
        return self._fields.index(field)
    except ValueError:
        raise KeyError("%s does not have a field named %r"
                       % (self.__class__.__name__, field))

E KeyError: "FunctionModel does not have a field named 'addr'"

venv-3.13/lib/python3.13/site-packages/numba/core/datamodel/old_models.py:678: KeyError

Any help would be much appreciated !
Thanks,
Yan

From a quick skim: The addr field seems to have been renamed to c_addr. In @DannyWeitekamp 's example, the bit that says sfunc.addr = addr_ptr might need changing to sfunc.c_addr = addr_ptr.

(This is a guess based on a quick skim of the changes to the model, so let me know if there’s something else that goes wrong with this change or if it doesn’t fit in with what you have)


That’s perfect ! Thank you very much.

Just out of curiosity, is this still the “best” way to call a jit-compiled function from its address?

Yan

@gmarkall Many thanks for catching that. Glad it was an easy fix. I’ll go ahead and update the code in the post that I linked to above with this fix (in case someone else ends up there first). (Whoops, looks like it’s too old to do that… perhaps someone with admin access could.) I suppose these kinds of issues are inevitable with workarounds like these that get really deep into the guts of things.

@yangeorget I think the “official” way to do this would be to use numba’s typed List filled with first-class functions (which requires them all to have the same signature). Getting that to work, quite frankly, takes a similar amount of type finessing as what you have, and unfortunately it would introduce a non-negligible amount of overhead, both on the Python side to prepare and pass the List and on the numba side in the performance-critical parts of your code; at least, it’s quite a bit of overhead relative to, let’s say, doing the same thing in C/C++. A while back I had a very detailed discussion here with @esc about the sources of this overhead. If I had time I’d write a PR… but alas.

If you’re using the workaround with numpy arrays, on the other hand, you’re pretty darn close to the fastest way I could think of implementing this in C/C++ (just with fewer compile-time guarantees from type checking). I suspect that this approach, while decidedly fast, will stay in non-standard land for quite a while. Treating an int64 or uint64 as a weak untyped pointer is all well and good if you know what you’re doing, but is nightmare fuel for a dev team trying to make sure numba users never get a segfault in place of a proper error message.


Thank you @DannyWeitekamp for these explanations. I’ll stick with the current solution as it is completely satisfactory.

Thanks again to both of you.