Will the __matmul__ method be implemented in jitclass in a future release?
If you edit the list of supported dunders to include it, does it “just work”?
Thanks a lot, it does indeed work! (It also works for the reflected dunders __r...__.)
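For reference, this is how plain Python dispatches `@` between two objects: it first tries the left operand's `__matmul__`, and if that method is missing or returns `NotImplemented`, it falls back to the right operand's reflected `__rmatmul__`. The sketch below uses ordinary Python classes (`Vec` and `Scale` are hypothetical names, not part of Numba) to show both paths. It illustrates only the Python data model and makes no claim about how jitclass implements these dunders internally.

```python
class Vec:
    """A tiny 2-element vector; @ is defined as the dot product (hypothetical example)."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __matmul__(self, other):
        # Only handle Vec @ Vec; signal "not supported" for anything else
        if isinstance(other, Vec):
            return self.x * other.x + self.y * other.y
        return NotImplemented


class Scale:
    """A scalar wrapper that defines only the reflected operator (hypothetical example)."""

    def __init__(self, k):
        self.k = k

    def __rmatmul__(self, other):
        # Called for `other @ self` after other.__matmul__ returned NotImplemented
        return Vec(other.x * self.k, other.y * self.k)


u = Vec(1, 2)
v = Vec(3, 4)
print(u @ v)        # handled directly by Vec.__matmul__
w = u @ Scale(10)   # Vec.__matmul__ returns NotImplemented, so Scale.__rmatmul__ runs
print(w.x, w.y)
```

Because `Scale` defines no `__matmul__` and `Vec` defines no `__rmatmul__`, the expression `Scale(10) @ u` would raise `TypeError`, which is why the forward and reflected dunders both need to be supported.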
Just out of curiosity (or ignorance), what is needed for dunders to be officially supported?
It would need the above change to the dunders list, plus an addition to the tests to exercise it and an update to the documentation. Would you be keen to make a PR that makes this change and adds tests/docs?
You should be able to just augment the existing jitclass dunder method tests with an extra case for __matmul__, but let me know if you attempt this and have trouble seeing where to make the addition.
What would be needed to do this? I am interested in helping if I can. This functionality would help me a lot, so I'd like to see it implemented.
@jsbryaniv There’s a PR from @louisamand here: “Reflected dunder methods for jitclass” (numba/numba Pull Request #8863). You could have a look and see if you can suggest a way to resolve the issue with it described in this comment.
@louisamand it might also be worth breaking that PR into two, one for __matmul__ and one for the reflected __r*__ methods, to make each part easier to progress and finish.
The issue does not seem to be with the __matmul__ dunder directly, but rather with the ‘@’ operator within Numba. I can create a minimal failing example with __add__ if I redefine addition like this:
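import numpy as np
from numba import float64, types
from numba.experimental import jitclass

spec = {"x": types.Array(float64, 2, 'C')}

@jitclass(spec)
class MinFail:
    def __init__(self, x):
        self.x = x

    def __add__(self, other):
        # "+" is redefined to perform a matrix product via "@" on the arrays
        return self.x @ other.x

testvals = np.arange(4, dtype=np.float64).reshape(2, 2)  # 2x2 C-contiguous float64 array
a = MinFail(testvals)
b = MinFail(testvals)
c = a + b  # this is the operation reported to fail
print(c)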
Additionally, I can create a working __matmul__ by writing the matrix product as explicit loops instead of using ‘@’:
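import numpy as np
from numba import float64, types
from numba.experimental import jitclass

spec = {"x": types.Array(float64, 2, 'C')}

@jitclass(spec)
class MinFail:
    def __init__(self, x):
        self.x = x

    def __matmul__(self, other):
        # Explicit triple loop instead of the "@" operator on the arrays
        output = np.zeros((self.x.shape[0], other.x.shape[1]))
        for i in range(self.x.shape[0]):
            for k in range(self.x.shape[1]):
                for j in range(other.x.shape[1]):
                    output[i, j] += self.x[i, k] * other.x[k, j]
        return output

testvals = np.arange(4, dtype=np.float64).reshape(2, 2)  # 2x2 C-contiguous float64 array
a = MinFail(testvals)
b = MinFail(testvals)
c = a @ b  # invokes the loop-based __matmul__
print(c)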
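The loop nest in the working `__matmul__` is the textbook i-k-j triple loop for C = A·B, where each `output[i, j]` accumulates `A[i, k] * B[k, j]` over k. A minimal pure-Python sketch of the same loop order on nested lists (the helper name `matmul_loops` is hypothetical, not from the thread or from Numba) makes it easy to check the arithmetic outside Numba. For the 2x2 test matrix used above, [[0, 1], [2, 3]], the product with itself is [[2, 3], [6, 11]].

```python
def matmul_loops(a, b):
    """Multiply two matrices given as lists of rows, using the i-k-j loop order."""
    n, m = len(a), len(a[0])
    if len(b) != m:
        raise ValueError("inner dimensions do not match")
    p = len(b[0])
    # Start from a zero matrix, like np.zeros in the jitclass method
    out = [[0.0] * p for _ in range(n)]
    for i in range(n):
        for k in range(m):
            aik = a[i][k]  # reused across the whole inner j loop
            for j in range(p):
                out[i][j] += aik * b[k][j]
    return out


testvals = [[0.0, 1.0], [2.0, 3.0]]  # same values as np.arange(4).reshape(2, 2)
print(matmul_loops(testvals, testvals))
```

Keeping k in the middle loop means the innermost loop walks along a row of `b` and a row of `out`, which is the contiguous direction for C-ordered arrays such as the `'C'` layout declared in `spec`.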