Will the __matmul__ method be implemented in jitclass?

Will the __matmul__ method be implemented in jitclass in a future release?

If you edit the list of supported dunders to include it, does it “just work”?

Thanks a lot, it does indeed work! (For the reflected dunders __r...__ too.)
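
For context on the reflected dunders mentioned here: in plain Python, a @ b first calls a.__matmul__(b), and only if that method is missing or returns NotImplemented does Python fall back to b.__rmatmul__(a). The sketch below is ordinary Python with no Numba or jitclass involved; the Matrix class is a hypothetical illustration of the dispatch rule, not code from Numba or from the PR discussed later.

```python
class Matrix:
    """Hypothetical pure-Python matrix wrapper, used only to show operator dispatch."""

    def __init__(self, rows):
        self.rows = [list(r) for r in rows]

    @staticmethod
    def _product(a, b):
        # result[i][j] = sum over k of a[i][k] * b[k][j]
        return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
                 for j in range(len(b[0]))]
                for i in range(len(a))]

    def __matmul__(self, other):
        # Runs for `self @ other`; returning NotImplemented tells Python
        # to try the right operand's __rmatmul__ instead.
        if isinstance(other, Matrix):
            return Matrix(self._product(self.rows, other.rows))
        return NotImplemented

    def __rmatmul__(self, other):
        # Runs for `other @ self` when `other` has no usable __matmul__,
        # e.g. a plain nested list on the left-hand side.
        if isinstance(other, list):
            return Matrix(self._product(other, self.rows))
        return NotImplemented


m = Matrix([[0, 1], [2, 3]])
print((m @ m).rows)                 # [[2, 3], [6, 11]]
print(([[1, 0], [0, 1]] @ m).rows)  # [[0, 1], [2, 3]]
```

Since list defines no __matmul__, the second expression only works because Matrix supplies __rmatmul__; that is the kind of case reflected-dunder support covers.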

Just out of curiosity (or ignorance): what is needed for dunders to be officially supported?

It would need the above change to the dunders list, plus an addition to the tests to cover it and an update to the documentation. Would you be keen to make a PR that makes this change and adds tests/docs?

You should be able to just augment the existing jitclass dunder method tests with an extra case for __matmul__, but let me know if you attempt this and are having trouble seeing where to make the addition.

What would be needed to do this? I am interested in helping if I can. This functionality would help me a lot, so I'd like to see it implemented.

@jsbryaniv There’s a PR from @louisamand here: Reflected dunder methods for jitclass (numba/numba Pull Request #8863 on GitHub). You could have a look and see whether you can suggest a way to resolve the issue with it described in this comment.

@louisamand it might also be worth splitting that PR in two, one for __matmul__ and one for the __r*__ methods, to make each part easier to finish.

The issue does not seem to be with the __matmul__ dunder itself, but with the ‘@’ operator inside Numba. I can reproduce a minimal failure with __add__ if I redefine addition like this:


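    import numpy as np
    from numba import float64, types
    from numba.experimental import jitclass

    spec = {"x": types.Array(float64, 2, 'C')}

    @jitclass(spec)
    class MinFail:
        def __init__(self, x):
            self.x = x

        def __add__(self, other):
            # Per the report above, this '@' between the array attributes is what fails
            return self.x @ other.x

    testvals = np.arange(4, dtype=np.float64).reshape(2, 2)
    a = MinFail(testvals)
    b = MinFail(testvals)
    c = a + b  # reported to fail
    print(c)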

Additionally, I can create a working __matmul__ by writing the product out as explicit loops instead of using ‘@’ on the arrays:

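    import numpy as np
    from numba import float64, types
    from numba.experimental import jitclass

    spec = {"x": types.Array(float64, 2, 'C')}

    @jitclass(spec)
    class MinFail:
        def __init__(self, x):
            self.x = x

        def __matmul__(self, other):
            # Explicit triple loop instead of '@' on the arrays:
            # output[i, j] = sum over k of self.x[i, k] * other.x[k, j]
            output = np.zeros((self.x.shape[0], other.x.shape[1]))
            for i in range(self.x.shape[0]):
                for k in range(self.x.shape[1]):
                    for j in range(other.x.shape[1]):
                        output[i, j] += self.x[i, k] * other.x[k, j]
            return output

    testvals = np.arange(4, dtype=np.float64).reshape(2, 2)
    a = MinFail(testvals)
    b = MinFail(testvals)
    c = a @ b  # needs __matmul__ in jitclass's supported dunders list, as discussed above
    print(c)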
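
The loop order in that __matmul__ (i, then k, then j) differs from the textbook inner-product form (i, then j, summing over k), but it computes the same product. The pure-Python sketch below, with no Numba involved and hypothetical helper names matmul_ikj and matmul_ijk, checks the equivalence on the same 2x2 values used in testvals:

```python
def matmul_ikj(a, b):
    """Same loop order as the jitclass __matmul__ above: i, then k, then j."""
    n, m, p = len(a), len(b), len(b[0])
    out = [[0.0] * p for _ in range(n)]
    for i in range(n):
        for k in range(m):
            aik = a[i][k]  # reused across the whole inner j loop
            for j in range(p):
                out[i][j] += aik * b[k][j]
    return out


def matmul_ijk(a, b):
    """Textbook order: each out[i][j] is one inner product over k."""
    n, m, p = len(a), len(b), len(b[0])
    return [[sum(a[i][k] * b[k][j] for k in range(m)) for j in range(p)]
            for i in range(n)]


x = [[0.0, 1.0], [2.0, 3.0]]  # same values as np.arange(4).reshape(2, 2)
print(matmul_ikj(x, x))                      # [[2.0, 3.0], [6.0, 11.0]]
print(matmul_ijk(x, x) == matmul_ikj(x, x))  # True
```

One plausible reason to prefer the i-k-j order is that the innermost loop then walks other.x[k, j] along a row, which is contiguous memory for the C-ordered arrays declared in spec; this is a general observation, not something measured in this thread.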