Help needed to re-implement np.matmul for 4D and 5D matrix

Hey again!

Still a complete newbie in Numba, I need to reimplement 4D and 5D matmul; but the 5D version are slower than classic numpy.

So far it is what I achieved

import numpy as np
from numba import njit, prange
from numba import cuda
import timeit

@njit
def matmul_4d(a, b):
    s1, s2, s3, s4 = a.shape
    t4 = b.shape[-1]

    c = np.zeros((s1, s2, s3, t4))

    for i in range(s1):
        for j in range(s2):
            for k in range(s3):
                for l in range(t4):
                    for m in range(s4):
                        c[i, j, k, l] += a[i, j, k, m] * b[m, l]

    return c

@njit(fastmath=True, parallel=True)
def matmul_check_collisions(a, b):
    s1, _, s3, s4, _ = a.shape
    _, t2, _, _, t5 = b.shape

    c = np.zeros((s1, t2, s3, s4, t5))

    for i in prange(s1):
        for k in range(s3):
            for j in range(t2):
                for l in range(s4):
                    a_row = a[i, 0, k, l]
                    for m in range(t5):
                        b_col = b[i, j, k, :, m]
                        sum_ = np.dot(a_row, b_col)
                        c[i, j, k, l, m] = sum_

    return c


@njit
def matmul_check_drivable(a, b):
    s1, s2, _, s4, _ = a.shape
    _, _, t3, _, t5 = b.shape

    c = np.zeros((s1, s2, t3, s4, t5))

    # Loop over the dimensions
    for i in range(s1):
        for j in range(s2):
            for k in range(t3):
                for l in range(s4):
                    a_contig = np.ascontiguousarray(a[i, j, l, 0, :])
                    b_contig = np.ascontiguousarray(b[i, j, k, :, 0])
                    c[i, j, k, l, 0] = np.dot(a_contig, b_contig)

    return c


if __name__ == "__main__":
    # Matmul 4D
    a = np.random.rand(1, 10, 2, 2)
    b = np.random.rand(2, 4)

    c = np.matmul(a, b)
    d = matmul_4d(a, b)

    print("Matmul 4D")
    print(np.allclose(c, d), c.shape == d.shape)
    print("Numpy: ", timeit.timeit(lambda: np.matmul(a, b), number=1000))
    print("Numba: ", timeit.timeit(lambda: matmul_4d(a, b), number=1000))

    # Matmul check collisions
    a = np.random.rand(1, 1, 10, 2, 2)
    b = np.random.rand(1, 23, 10, 2, 1)

    c = np.matmul(a, b)
    d = matmul_check_collisions(a, b)

    print("Matmul check collisions")
    print(np.allclose(c, d), c.shape == d.shape)
    print("Numpy: ", timeit.timeit(lambda: np.matmul(a, b), number=1000))
    print("Numba: ", timeit.timeit(lambda: matmul_check_collisions(a, b), number=1000))

    # Matmul check drivable
    a = np.random.rand(1, 83, 1, 1, 2)
    b = np.random.rand(1, 83, 41, 2, 1)

    c = np.matmul(a, b)
    d = matmul_check_drivable(a, b)

    print("Matmul check drivable")
    print(np.allclose(c, d), c.shape == d.shape)
    print("Numpy: ", timeit.timeit(lambda: np.matmul(a, b), number=1000))
    print("Numba: ", timeit.timeit(lambda: matmul_check_drivable(a, b), number=1000))

But I feel like I’m butchering Numba and the implementation could be a lot smarter. So I would be eager to learn how to make it better.

Here are the results so far:

Matmul 4D
True True
Numpy:  0.0028127990003667946
Numba:  0.0011642999998002779
NumbaPerformanceWarning: np.dot() is faster on contiguous arrays, called on (Array(float64, 1, 'C', False, aligned=True), Array(float64, 1, 'A', False, aligned=True))
  sum_ = np.dot(a_row, b_col)
Matmul check collisions
True True
Numpy:  0.010754999000255339
Numba:  0.0338912980000714
Matmul check drivable
True True
Numpy:  0.04238159699980315
Numba:  0.25276338699995904

Thank you

Hi @charraut

From your code I assume that your goal is not to write a fast np.matmul in Numba, but only to increase the speed for your specific application. Is that correct?

If so, you can take the completely naive approach and write out all the loops. Using np.dot for small 1D arrays is generally not a good idea, since there is not much to optimize in terms of memory access patterns. So you are just paying for the overhead without any speed advantage.

With these changes, you already achieve the speed of Numpy. If you always have the same shapes for your inputs, you can implicitly pass this information to the compiler, which enables remarkable optimizations. You can see what I mean in the code below.

import numpy as np
import numba as nb 


def generate_matmul_4d(A=None, B=None):
    def impl(a, b):
        (s1, s2, s3, s4) = a.shape if A is None else A.shape
        (         _, t4) = b.shape if B is None else B.shape
        c = np.zeros((s1, s2, s3, t4))
        for i in range(s1):
            for j in range(s2):
                for k in range(s3):
                    for l in range(t4):
                        for m in range(s4):
                            c[i, j, k, l] += a[i, j, k, m] * b[m, l]
        return c
    return nb.njit(impl)


def generate_matmul_check_collisions(A=None, B=None):
    def impl(a, b):
        (s1, _, s3, s4, _) = a.shape if A is None else A.shape
        (_, t2, _, t4, t5) = b.shape if B is None else B.shape
        c = np.zeros((s1, t2, s3, s4, t5))
        for i in range(s1):
            for k in range(s3):
                for j in range(t2):
                    for l in range(s4):
                        for m in range(t5):
                            for n in range(t4):
                                c[i, j, k, l, m] += a[i, 0, k, l, n] * b[i, j, k, n, m] 
        return c
    return nb.njit(impl)


def generate_matmul_check_drivable(A=None, B=None):
    def impl(a, b):
        (s1, s2, _, s4, _) = a.shape if A is None else A.shape
        (_, _, t3, t4, t5) = b.shape if B is None else B.shape
        c = np.zeros((s1, s2, t3, s4, t5))
        for i in range(s1):
            for j in range(s2):
                for k in range(t3):
                    for l in range(s4):
                        for m in range(t4):
                            c[i, j, k, l, 0] += a[i, j, l, 0, m] * b[i, j, k, m, 0]
        return c
    return nb.njit(impl)


print(" matmul_4d ".center(60, "-"))
a = np.random.rand(1, 10, 2, 2)
b = np.random.rand(2, 4)

matmul_4d_1 = generate_matmul_4d()
matmul_4d_2 = generate_matmul_4d(a, b)

assert np.allclose(np.matmul(a, b), matmul_4d_1(a, b))
assert np.allclose(np.matmul(a, b), matmul_4d_2(a, b))

%timeit np.matmul(a, b)
%timeit matmul_4d_1(a, b)
%timeit matmul_4d_2(a, b)


print(" matmul_check_collisions ".center(60, "-"))
a = np.random.rand(1, 1, 10, 2, 2)
b = np.random.rand(1, 23, 10, 2, 1)

matmul_check_collisions_1 = generate_matmul_check_collisions()
matmul_check_collisions_2 = generate_matmul_check_collisions(a, b)

assert np.allclose(np.matmul(a, b), matmul_check_collisions_1(a, b))
assert np.allclose(np.matmul(a, b), matmul_check_collisions_2(a, b))

%timeit np.matmul(a, b)
%timeit matmul_check_collisions_1(a, b)
%timeit matmul_check_collisions_2(a, b)


print(" matmul_check_drivable ".center(60, "-"))
a = np.random.rand(1, 83, 1, 1, 2)
b = np.random.rand(1, 83, 41, 2, 1)

matmul_check_drivable_1 = generate_matmul_check_drivable()
matmul_check_drivable_2 = generate_matmul_check_drivable(a, b)

assert np.allclose(np.matmul(a, b), matmul_check_drivable_1(a, b))
assert np.allclose(np.matmul(a, b), matmul_check_drivable_2(a, b))

%timeit np.matmul(a, b)
%timeit matmul_check_drivable_1(a, b)
%timeit matmul_check_drivable_2(a, b)

# ------------------------ matmul_4d -------------------------
# 2.59 µs ± 46.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 1.58 µs ± 23.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
# 1.11 µs ± 14.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
# ----------------- matmul_check_collisions ------------------
# 11.5 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 5.75 µs ± 124 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 1.91 µs ± 22.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
# ------------------ matmul_check_drivable -------------------
# 46.3 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# 34.5 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# 9.15 µs ± 170 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
2 Likes

Hey @sschaer

Thanks a lot for your answer and taking the time to explain. It is very clear and I can get it running like I want.

Hope you have a good day