iperov
September 21, 2023, 10:19am
1
I want to check the type of the argument passed to `__mul__` and, depending on that type at compile time, substitute the appropriate implementation for the JIT.
The following code does not work:
import numba as nb
import numba.experimental as nbexp
import numba.extending as nbex
@nbexp.jitclass
class Test:
    """Field-less jitclass used to try out overloading ``__mul__``."""

    def __init__(self):
        ...

    def __mul__(self, other):
        # Empty placeholder; the post tries to supply the real behaviour
        # through a separately registered overload.
        ...
@nbex.overload_classmethod(Test, "__mul__")
def test_mul(self, other):
    # NOTE(review): this is the attempt the post reports as not working; the
    # logic is left as posted. The body only prints and returns None, so no
    # implementation function is ever handed back. Also, `overload_classmethod`
    # is documented for classmethods of a Numba type, whereas `__mul__` is an
    # instance method and `Test` is the jitclass itself -- presumably why this
    # is never picked up; confirm against the Numba extension docs.
    print('WE ARE HERE')
@nb.njit(nogil=True)
def run_test():
    """Multiply two ``Test`` instances inside njit-compiled code."""
    return Test() * Test()
# Call the jitted function and print whatever it returns.
print( run_test() )
iperov
September 21, 2023, 2:38pm
2
It looks like overloading any method of a jitclass does not work either:
from numba import types as nbt
import numba as nb
import numba.experimental as nbexp
import numba.extending as nbex
@nbexp.jitclass
class Test:
    """Field-less jitclass whose ``some`` method is meant to be overloaded."""

    def __init__(self):
        ...

    def some(self, other) -> None:
        # Empty placeholder body; the ``-> None`` annotation is kept as posted.
        ...
def Test_some_impl1(self, other):
    """Implementation handed out by the overload: return ``other`` unchanged.

    ``self`` is accepted to match the method signature but is not used.
    """
    return other
@nbex.overload_method(nbt.misc.ClassInstanceType, "some")
def over_some(self, other):
    """Select an implementation of ``some`` from the Numba types of the arguments.

    ``self`` and ``other`` are Numba types here, not runtime values. Returns
    ``Test_some_impl1`` when ``self`` is the instance type of ``Test`` and
    ``other`` is in ``nbt.number_domain``; otherwise falls through and
    returns None.
    """
    # NOTE(review): the post reports that this fails during native lowering
    # with ('i64', 'i8*'). Presumably the jitclass's own `some` (which returns
    # None) and `Test_some_impl1` (which returns `other`) disagree on the
    # return type -- confirm. The logic is left exactly as posted.
    if self is Test.class_type.instance_type:
        if other in nbt.number_domain:
            return Test_some_impl1
@nb.njit(nogil=True)
def run_test():
    """Call ``Test().some(2)`` inside njit-compiled code."""
    return Test().some(2)
# Call the jitted function and print whatever it returns.
print( run_test() )
error:
AssertionError: Failed in nopython mode pipeline (step: native lowering)
('i64', 'i8*')
iperov
September 21, 2023, 3:17pm
3
Finally got it working:
import numba as nb
import numba.experimental as nbexp
import numba.extending as nbex
from numba import types as nbt
@nbexp.jitclass([
    ('_x', nbt.float32),
    ('_y', nbt.float32),
])
class Vec2:
    """Two-component vector jitclass with float32 fields ``_x`` and ``_y``."""

    def __init__(self, x: float, y: float):
        self._x = x
        self._y = y

    @property
    def x(self) -> float:
        # Read-only access to the first component.
        return self._x

    @property
    def y(self) -> float:
        # Read-only access to the second component.
        return self._y

    def __mul__(self, other):
        # Placeholder body: the post registers an overload for "__mul__" that
        # supplies the real per-type implementations. It returns a Vec2, as
        # those implementations do.
        return Vec2(0, 0)  # overloaded
# Overload implementations
def Vec2__mul__Vec2(self, other):
    """Return the component-wise product of two ``Vec2`` values as a new ``Vec2``."""
    prod_x = self._x * other._x
    prod_y = self._y * other._y
    return Vec2(prod_x, prod_y)
def Vec2__mul__number(self, other):
    """Return a new ``Vec2`` with both components scaled by the number ``other``."""
    # Convert the factor to float once and reuse it for both components.
    factor = float(other)
    return Vec2(self._x * factor, self._y * factor)
# Overloaders
@nbex.overload_method(nbt.misc.ClassInstanceType, "__mul__")
def over_Vec2__mul__(self, other):
    """Select the ``__mul__`` implementation from the Numba type of ``other``.

    ``self`` and ``other`` are Numba types here, not runtime values.

    Returns:
        ``Vec2__mul__Vec2`` when ``other`` is the ``Vec2`` instance type,
        ``Vec2__mul__number`` when ``other`` is in ``nbt.number_domain``,
        and None (by falling through) in every other case.
    """
    # The overload is registered on ClassInstanceType in general, so only
    # react when `self` is the instance type of Vec2.
    if self is Vec2.class_type.instance_type:
        if other is Vec2.class_type.instance_type:
            return Vec2__mul__Vec2
        if other in nbt.number_domain:
            return Vec2__mul__number
# Tests
@nb.njit(nogil=True)
def run_test1():
    """Multiply a ``Vec2`` by an integer inside njit-compiled code."""
    return Vec2(1, 1) * 2
@nb.njit(nogil=True)
def run_test2():
    """Multiply a ``Vec2`` by another ``Vec2`` inside njit-compiled code."""
    return Vec2(1, 1) * Vec2(3, 3)
# x = 1 * 2 via the Vec2-times-number implementation.
print( run_test1().x ) # outputs 2.0
# x = 1 * 3 via the Vec2-times-Vec2 implementation.
print( run_test2().x ) # outputs 3.0
Can it be simplified?
1 Like