I often want a function to be able to take arguments of different types. As a simple example:
```python
import math
from typing import List, Union

def sqrt(value: Union[float, List[float]]):
    if isinstance(value, float):
        return math.sqrt(value)
    else:
        return [math.sqrt(x) for x in value]
```
(I know that this specific example is best solved using `np.sqrt`.)
This function doesn't work in `@njit`-ed code, because `isinstance` only works in object mode. Is there a good workaround? Currently I've been writing separate functions (e.g. `sqrt_float` and `sqrt_list`), but that doesn't work for certain special methods, such as `__init__` and `__getitem__`. Is there some sort of workaround for the lack of `isinstance`, especially one that would work in `@jitclass`es?
Does `overload` work for your situation?
I thought that `overload` replaces a function entirely. Can I overload a Python function and have both versions work, with the one that runs depending on the argument type?
It looks like `generated_jit` should work. So something like this should work?
```python
import math
from numba import generated_jit, types

@generated_jit
def sqrt(x):
    if isinstance(x, types.Float):
        return lambda x: math.sqrt(x)
    else:
        return lambda x: [sqrt(xv) for xv in x]
```
Will it work in a `@jitclass`?
The use cases for `overload` and `generated_jit` are slightly different.
`overload` keeps a Python-only version that is always used when the function is called from pure Python. The “overloaded” versions are only used when it is called from another jitted function. In that case (when running inside `jit`), the overload function is called at compile time with the argument types and returns the implementation to use for those types.
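```python
import math
from numba import types
from numba.extending import overload

def sqrt(x):  # pure-Python version
    return math.sqrt(x)

@overload(sqrt)  # needs a different name, or it would rebind `sqrt`
def _sqrt_overload(x):
    if isinstance(x, types.Float):
        return lambda x: math.sqrt(x)
    else:
        return lambda x: [sqrt(xv) for xv in x]
```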
`generated_jit` does something similar, but it uses the jitted version even when the function is called from pure Python.
Regarding `jitclass`: yes, both will work, but it depends on how you use them. I don't think you can apply `generated_jit` or `overload` to methods inside the jitclass itself, but you can call those functions from the jitclass's methods:
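```python
from numba.experimental import jitclass

@jitclass(....)  # spec elided in the original post
class MyClass:
    def calc_sqrt(self, x):
        return sqrt(x)  # calls the overloaded sqrt defined above
```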
Luk