I am encountering what looks like a type-inference problem, reproducible with the minimal example below. It arises in my setting when passing njit-compiled functions as arguments. All calls occur in a controlled manner, and without explicit signatures I see very high JIT-compilation times.
```python
from numba import njit, float64, bool_
from numba.core.types import FunctionType, UniTuple, Tuple

@njit([float64(float64)])
def f(x):
    return x

@njit([float64(float64, bool_)])
def g(x, b):
    return x if b else 0.0

@njit([float64(FunctionType(float64(float64)), float64, Tuple(())),
       float64(FunctionType(float64(float64, bool_)), bool_, UniTuple(bool_, 1))])
def h(fn, x, *args):
    return fn(x, *args)

h(g, 1.0, *(True,))  # will fail
h(f, 1.0, *())       # will fail
```
I followed the calls into Numba, and the problem seems to happen in numba.core.types.functions.Dispatcher.can_convert_to, which at some point tries to compile code for the wrong signature (since type inference did not exclude it) – or, in this case, for a function that no longer allows compilation.
From what I can see, all types and calls are correct, and dispatching would work if the can_convert functions simply returned False instead of raising a RuntimeError.
The problem does not occur when compilation is allowed – but allowing it is exactly what I am trying to avoid, because of the high compilation times. Adding explicit signatures to h instead binds them to very specific functions (i.e. a CPUDispatcher bound to one concrete function), whereas I want them bound to the function type, not the pointer.
Am I using the right signatures? Am I using the signatures feature in the wrong manner? And why, by the way, is the type of an njitted function type(Dispatcher f at mem_addr) rather than the FunctionType?
Any hint would be greatly appreciated!
Amine.