I’m trying to write a jitclass that holds a function pointer, which selects the member function to call. I’m stuck on what type annotation to give that attribute, since it points to a member function. A minimal example is shown below:
```python
import numba as nb

@nb.experimental.jitclass
class Foo:
    a: float
    f: <what type to put here?>

    def __init__(self, a):
        self.a = a
        if a > 1e-6:
            self.f = self.func1
        else:
            self.f = self.func2

    def func1(self, x):
        return x * x

    def func2(self, x):
        return x * x + self.a
```
This issue is somewhat similar to numba.discourse.group/t/function-pointer-inside-a-jitclass/903 from 2 years ago. In that case, neither pointed-to function relies on any internal state of the jitclass, so we can work around the problem by moving them outside the jitclass:
```python
import numba as nb

@nb.experimental.jitclass
class Foo:
    a: float
    f: nb.float64(nb.float64).as_type()

    def __init__(self, a):
        self.a = a
        if a > 1e-6:
            self.f = func1
        else:
            self.f = func2

@nb.njit
def func1(x):
    return x * x

@nb.njit
def func2(x):
    return x + x

foo = Foo(0)
print(foo.f(5))
```