Hey @seekiu,
Unfortunately, I can’t answer your question. I don’t know how to define a function pointer within a jitclass. Nevertheless, this is an interesting topic, and I would like to add some thoughts on your question if you don’t mind.
The straightforward way to generate the jitclass would be to implement the Python class with type annotations and simply convert it.
To help infer the data types, we can provide the field specs.
Unfortunately, there seems to be a problem casting the `BoundFunction` type (the bound method stored in `self.f`) into the `FunctionType` declared in the spec.
I would assume that the field spec for `f` has to be adjusted to match the `BoundFunction` type to solve the problem.
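The mismatch can be illustrated in plain Python, without Numba: `self.func1` is not a plain function but a bound method that carries the instance along with it, while the spec `nbt.FunctionType(nbt.f8(nbt.f8))` describes a plain `float -> float` function. The sketch below uses only the standard library (`inspect`) and a trimmed-down copy of `Foo`; it only shows what the attribute actually holds and makes no claim about how Numba types it internally.

```python
import inspect


class Foo:
    def __init__(self, a: float):
        self.a = a
        self.f = self.func1 if a > 1e-6 else self.func2

    def func1(self, x: float) -> float:
        return x * x

    def func2(self, x: float) -> float:
        return x + x + self.a


foo = Foo(a=1e-7)

# `foo.f` is a bound method: it wraps the plain function AND the instance
print(inspect.ismethod(foo.f))          # True
print(foo.f.__self__ is foo)            # True
print(foo.f.__func__ is Foo.func2)      # True

# The underlying function needs the instance passed explicitly,
# so its real signature is (Foo, float) -> float, not float -> float
print(foo.f.__func__(foo, 10.0) == foo.f(10.0))  # True
```

In other words, a `float -> float` field type cannot describe the bound method on its own, because the instance is part of what gets stored.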
```python
# =============================================================================
# Try to convert Python class into jitclass using FunctionType in spec
# Error: Cannot cast BoundFunction[...] to FunctionType[...]
# =============================================================================
from collections.abc import Callable

import numba as nb
import numba.types as nbt


class Foo:
    a: float
    f: Callable[[float], float]

    def __init__(self, a):
        self.a = a
        # Store a bound method as the "function pointer"
        self.f = (self.func1 if a > 1e-6 else self.func2)

    def func1(self, x: float) -> float:
        return x * x

    def func2(self, x: float) -> float:
        return x + x + self.a


foo1 = Foo(a=1.)
foo2 = Foo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# foo1: 100.0
# foo2: 20.0000001

spec = [('a', nbt.f8),
        ('f', nbt.FunctionType(nbt.f8(nbt.f8)))]
JitFoo = nb.experimental.jitclass(Foo, spec)

foo1 = JitFoo(a=1.)
foo2 = JitFoo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# TypingError: Failed in nopython mode pipeline (step: native lowering)
# Cannot cast BoundFunction[...] to FunctionType[...]
```
If we are not able to define the correct function type to make the conversion into a jitclass work, we could use a workaround.
If the class can remain a Python class and we only need Numba’s computational power in specific functions, we could use a function factory to generate a jitted function and assign it to an attribute of the Python class that acts as the function pointer.
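A variant of the factory idea resolves the branch once, when the function is created, instead of on every call. The sketch is plain Python so that it runs without Numba; the name `function_factory_select` is hypothetical, and decorating each inner `func` with `@nb.njit(nbt.f8(nbt.f8))`, as in the `function_factory` example that follows, is an untested assumption.

```python
from collections.abc import Callable


def function_factory_select(a: float) -> Callable[[float], float]:
    """Hypothetical variant: pick the implementation once, at creation time."""
    if a > 1e-6:
        def func(x: float) -> float:
            return x * x
    else:
        def func(x: float) -> float:
            # `a` is captured from the enclosing scope
            return x + x + a
    # With Numba, each `func` above could presumably be decorated with
    # @nb.njit(nbt.f8(nbt.f8)) (assumption, not tested here)
    return func


class Foo:
    def __init__(self, a: float):
        self.a = a
        self.f = function_factory_select(a)


foo1 = Foo(a=1.)
foo2 = Foo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
```

Each generated function contains only one code path; whether that makes a measurable difference compared to the single-expression version is not checked here.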
```python
# =============================================================================
# Keep Python class and generate jitted function
# =============================================================================
from collections.abc import Callable

import numba as nb
import numba.types as nbt


def function_factory(a: float) -> Callable[[float], float]:
    # `a` is a closure variable, which Numba freezes as a constant at compile time
    @nb.njit(nbt.f8(nbt.f8))
    def func(x: float) -> float:
        return x * x if a > 1e-6 else x + x + a
    return func


class Foo:
    def __init__(self, a: float):
        self.a = a
        self.f = function_factory(a)


foo1 = Foo(a=1.)
foo2 = Foo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# foo1: 100.0
# foo2: 20.0000001
```
If it has to be a jitclass, we don’t need a function pointer, and there is no reason not to use external functions, we can follow your suggestion.
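One possible refinement of the external-function approach stores the choice in a field once, in `__init__`, which is perhaps the closest jitclass-friendly stand-in for a function pointer. The sketch is plain Python so it runs without Numba; the `mode` field is hypothetical, and the assumption that the jitclass version would only need `@nb.njit` on the two functions plus an extra spec entry such as `('mode', nbt.i8)` is untested.

```python
def func1(x: float) -> float:
    return x * x


def func2(x: float, a: float) -> float:
    return x + x + a


class Foo:
    a: float
    mode: int  # hypothetical field: 1 -> func1, 2 -> func2

    def __init__(self, a: float):
        self.a = a
        # Decide once which external function to use
        self.mode = 1 if a > 1e-6 else 2

    def f(self, x: float) -> float:
        if self.mode == 1:
            return func1(x)
        return func2(x, self.a)


foo1 = Foo(a=1.)
foo2 = Foo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
```

The comparison `self.a > 1e-6` is cheap, so the gain over re-evaluating it per call is likely small; the main point is that the selection lives in an ordinary numeric field.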
```python
# =============================================================================
# Generate Python class and convert to jitclass using external functions
# =============================================================================
import numba as nb
import numba.types as nbt


@nb.njit(nbt.f8(nbt.f8))
def func1(x: float) -> float:
    return x * x


@nb.njit(nbt.f8(nbt.f8, nbt.f8))
def func2(x: float, a: float) -> float:
    return x * x if a > 1e-6 else x + x + a


class Foo:
    a: float

    def __init__(self, a):
        self.a = a

    def f(self, x: float):
        return (func1(x) if self.a > 1e-6
                else func2(x, self.a))


spec = [('a', nbt.f8)]
JitFoo = nb.experimental.jitclass(Foo, spec)

foo1 = JitFoo(a=1.)
foo2 = JitFoo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# foo1: 100.0
# foo2: 20.0000001
```
If the functions have to be part of the class and we don’t need the function pointer, we could write the Python class, define the functions as static methods, and simply convert it into a jitclass. We only have to specify the field `a` for the conversion.
All methods of the jitclass should be compiled automatically into nopython functions.
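Because the jitclass is built from an ordinary Python class, the undecorated class can serve as a reference to check the converted one against. The helper `check_same_results` below is hypothetical; it is demonstrated with the plain Python `Foo` only, so it runs without Numba, and in practice the second argument would presumably be `JitFoo`.

```python
import math


class Foo:
    a: float

    def __init__(self, a):
        self.a = a

    def f(self, x: float):
        return (self.func1(x) if self.a > 1e-6
                else self.func2(x, self.a))

    @staticmethod
    def func1(x: float) -> float:
        return x * x

    @staticmethod
    def func2(x: float, a: float) -> float:
        return x * x if a > 1e-6 else x + x + a


def check_same_results(ref_cls, test_cls, a_values, x_values, rel_tol=1e-12):
    """Hypothetical helper: compare `f` of two classes on a grid of inputs."""
    for a in a_values:
        ref, test = ref_cls(a=a), test_cls(a=a)
        for x in x_values:
            if not math.isclose(ref.f(x), test.f(x), rel_tol=rel_tol):
                return False
    return True


# In practice something like: check_same_results(Foo, JitFoo, ...)
print(check_same_results(Foo, Foo, [1.0, 1e-7], [0.5, 10.0]))
```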
```python
# =============================================================================
# Generate Python class and convert to jitclass using staticmethods
# =============================================================================
import numba as nb
import numba.types as nbt


class Foo:
    a: float

    def __init__(self, a):
        self.a = a

    def f(self, x: float):
        return (self.func1(x) if self.a > 1e-6
                else self.func2(x, self.a))

    @staticmethod
    def func1(x: float) -> float:
        return x * x

    @staticmethod
    def func2(x: float, a: float) -> float:
        return x * x if a > 1e-6 else x + x + a


spec = [('a', nbt.f8)]
JitFoo = nb.experimental.jitclass(Foo, spec)

foo1 = JitFoo(a=1.)
foo2 = JitFoo(a=1e-7)
print(f'foo1: {foo1.f(10.0)}')
print(f'foo2: {foo2.f(10.0)}')
# foo1: 100.0
# foo2: 20.0000001
```
I am curious to see the solution to your problem.