How do I overload `__mul__` (a dunder method) in a jitclass?

I want to check the type of the argument passed to `__mul__` and, depending on that type at compile time, substitute the appropriate implementation for Numba to compile.

This code does not work:

import numba as nb
import numba.experimental as nbexp
import numba.extending as nbex

@nbexp.jitclass
class Test:
    """Minimal jitclass used to reproduce the ``__mul__`` overloading problem."""
    def __init__(self):
        ...
        
    def __mul__(self, other): 
        # Placeholder body (evaluates ``...`` and returns None); the intent is
        # for the ``overload_classmethod`` registration that follows to supply
        # the real, type-dependent implementation.
        ...

# NOTE(review): overload_classmethod appears to be meant for classmethods on a
# Numba *type* (e.g. ``types.Array``). Here it receives the Python-level
# jitclass ``Test``, and ``__mul__`` is an instance method, so this hook is
# presumably never consulted for ``Test() * Test()`` -- TODO confirm against
# the Numba extension docs.
@nbex.overload_classmethod(Test, "__mul__")
def test_mul(self, other):
    """Typing-time hook intended to choose a ``__mul__`` implementation.

    It only prints a marker and implicitly returns ``None``, so even if Numba
    calls it, it never hands back an implementation to compile.
    """
    print('WE ARE HERE')
        
@nb.njit(nogil=True)
def run_test():
    """Multiply two ``Test`` instances in nopython mode (the failing case)."""
    return Test() * Test()

# Compiles and runs ``run_test``; per the text above, this does not work.
print( run_test() )

It looks like overloading any jitclass method doesn't work either:

from numba import types as nbt
import numba as nb
import numba.experimental as nbexp
import numba.extending as nbex

@nbexp.jitclass
class Test:
    """Minimal jitclass whose ``some`` method is meant to be overloaded."""
    def __init__(self):
        ...
    def some(self, other) -> None:
        # Placeholder: its body returns None (and it is annotated ``-> None``),
        # while the overload registered below returns ``other`` (an int for
        # ``some(2)``). NOTE(review): this return-type mismatch is a likely
        # cause of the ('i64', 'i8*') lowering error shown for this example;
        # the working Vec2 example makes its placeholder return the same type
        # as its overload implementations -- TODO confirm.
        ...
    
def Test_some_impl1(self, other):
    """Implementation chosen by ``over_some`` for numeric ``other``; returns ``other`` unchanged."""
    return other

@nbex.overload_method(nbt.misc.ClassInstanceType, "some")
def over_some(self, other):
    """Pick an implementation of ``Test.some`` from the argument *types*.

    Registered on ``ClassInstanceType``, which covers every jitclass instance,
    so it first checks that ``self`` is the ``Test`` instance type. For any
    other combination it falls through and returns ``None`` (no
    implementation offered).
    """
    if self is Test.class_type.instance_type:
        # NOTE(review): nbt.number_domain also includes complex types.
        if other in nbt.number_domain:
            return Test_some_impl1

@nb.njit(nogil=True)
def run_test():
    """Call the overloaded ``some`` with an integer argument in nopython mode."""
    return Test().some(2)

# Compiling ``run_test`` here produces the AssertionError quoted below.
print( run_test() )

Error:

AssertionError: Failed in nopython mode pipeline (step: native lowering)
('i64', 'i8*')

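I finally got it working. The main difference from the previous attempt seems to be that the placeholder method in the class now returns the same type (`Vec2`) as the overload implementations: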

import numba as nb
import numba.experimental as nbexp
import numba.extending as nbex
from numba import types as nbt

@nbexp.jitclass([ ('_x', nbt.float32),
                  ('_y', nbt.float32), ])
class Vec2:
    """2-D vector jitclass whose components are stored as float32 fields."""
    def __init__(self, x : float, y : float):
        self._x = x
        self._y = y

    # Read-only accessors for the stored components.
    @property
    def x(self) -> float: return self._x
    @property
    def y(self) -> float: return self._y

    # Placeholder only: the real behaviour comes from the ``overload_method``
    # registration for ``__mul__`` on ``ClassInstanceType``. It returns a Vec2
    # so that its return type matches the overload implementations.
    # NOTE(review): the earlier ``some`` example, whose placeholder returned
    # None while its overload returned an int, failed in lowering with
    # ('i64', 'i8*'), so keeping these types aligned appears to be required --
    # TODO confirm.
    def __mul__(self, other): return Vec2(0,0) # overloaded


# Overload implementations
def Vec2__mul__Vec2(self, other):
    """Return the component-wise product of two ``Vec2`` instances."""
    prod_x = self._x * other._x
    prod_y = self._y * other._y
    return Vec2(prod_x, prod_y)
def Vec2__mul__number(self, other):
    """Return a ``Vec2`` with both components scaled by the scalar ``other``."""
    factor = float(other)
    return Vec2(self._x * factor, self._y * factor)

# Overloaders
@nbex.overload_method(nbt.misc.ClassInstanceType, "__mul__")
def over_Vec2__mul__(self, other):
    """Select the ``Vec2.__mul__`` implementation from the operand *types*.

    Numba calls this at typing time with the Numba types of ``self`` and
    ``other`` (not their values). Returning a Python function hands that
    function to Numba as the implementation to compile. Returning ``None``
    offers no implementation for that type combination.

    The registration is on ``ClassInstanceType``, which covers every jitclass
    instance, so the first guard restricts this overload to ``Vec2``.
    """
    vec2_type = Vec2.class_type.instance_type
    if self is not vec2_type:
        return None
    if other is vec2_type:
        return Vec2__mul__Vec2
    # Accept only real scalars. ``nbt.number_domain`` also contains
    # complex64/complex128, which Vec2__mul__number cannot compile because it
    # calls float(other). It also omits literal types, which the isinstance
    # check accepts because IntegerLiteral/FloatLiteral subclass
    # Integer/Float.
    if isinstance(other, (nbt.Integer, nbt.Float)):
        return Vec2__mul__number
    return None

# Tests

@nb.njit(nogil=True)
def run_test1():
    """Exercise ``Vec2 * scalar`` inside nopython-compiled code."""
    vec = Vec2(1, 1)
    return vec * 2

@nb.njit(nogil=True)
def run_test2():
    """Exercise ``Vec2 * Vec2`` inside nopython-compiled code."""
    lhs = Vec2(1, 1)
    rhs = Vec2(3, 3)
    return lhs * rhs

# Compile and run each test in turn, printing the x component.
scaled = run_test1()
print(scaled.x)  # outputs 2.0
product = run_test2()
print(product.x)  # outputs 3.0

Can it be simplified?

1 Like