# Help needed to re-implement np.matmul for 4D and 5D arrays

From your code I assume that your goal is not to write a fast `np.matmul` in Numba, but only to increase the speed for your specific application. Is that correct?

If so, you can take the completely naive approach and write out all the loops. Using `np.dot` for small 1D arrays is generally not a good idea, since there is not much to optimize in terms of memory access patterns. So you are just paying for the overhead without any speed advantage.

With these changes, you already match the speed of NumPy. If your inputs always have the same shapes, you can implicitly pass this information to the compiler, which enables remarkable optimizations. You can see what I mean in the code below.

``````python
import numpy as np
import numba as nb

def generate_matmul_4d(A=None, B=None):
    """Build a Numba-compiled product of a 4D array with a 2D matrix.

    The returned function computes, for every leading index ``(i, j, k)``,
    the row-vector/matrix product ``a[i, j, k, :] @ b`` with explicit loops
    and stores it in a freshly allocated ``np.zeros`` array.

    Args:
        A: Optional template for the first operand. When given, its shape
            supplies the loop bounds instead of the shape of the runtime
            argument ``a``.
        B: Optional template for the second operand, used the same way.

    Returns:
        ``nb.njit`` applied to the loop implementation; call it as
        ``f(a, b)``.
    """

    def impl(a, b):
        # Bounds come from the templates when they were supplied, otherwise
        # from the arrays actually passed in.
        (n0, n1, n2, n_inner) = a.shape if A is None else A.shape
        (_, n_cols) = b.shape if B is None else B.shape
        out = np.zeros((n0, n1, n2, n_cols))
        for i0 in range(n0):
            for i1 in range(n1):
                for i2 in range(n2):
                    for col in range(n_cols):
                        # Sum over the contracted axis in a scalar, then
                        # store once; the summation order is unchanged.
                        acc = 0.0
                        for p in range(n_inner):
                            acc += a[i0, i1, i2, p] * b[p, col]
                        out[i0, i1, i2, col] = acc
        return out

    return nb.njit(impl)

def generate_matmul_check_collisions(A=None, B=None):
    """Build a Numba-compiled 5D batched matrix product with ``a`` broadcast.

    The returned function fills
    ``out[i, j, k, r, c] = sum_p a[i, 0, k, r, p] * b[i, j, k, p, c]``.
    ``a`` is always read at index 0 along its second axis, so it is reused
    for every ``j`` of ``b``.

    Args:
        A: Optional template for the first operand. When given, its shape
            supplies the loop bounds instead of the shape of the runtime
            argument ``a``.
        B: Optional template for the second operand, used the same way.

    Returns:
        ``nb.njit`` applied to the loop implementation; call it as
        ``f(a, b)``.
    """

    def impl(a, b):
        # Axes 0, 2 and 3 of the output are sized from ``a``; axes 1 and 4
        # and the contraction length are sized from ``b``.
        (n_outer, _, n_mid, n_rows, _) = a.shape if A is None else A.shape
        (_, n_stack, _, n_inner, n_cols) = b.shape if B is None else B.shape
        out = np.zeros((n_outer, n_stack, n_mid, n_rows, n_cols))
        for i in range(n_outer):
            for k in range(n_mid):
                for j in range(n_stack):
                    for row in range(n_rows):
                        for col in range(n_cols):
                            # Scalar accumulator; same summation order as
                            # accumulating directly into the output.
                            acc = 0.0
                            for p in range(n_inner):
                                acc += a[i, 0, k, row, p] * b[i, j, k, p, col]
                            out[i, j, k, row, col] = acc
        return out

    return nb.njit(impl)

def generate_matmul_check_drivable(A=None, B=None):
    """Build a Numba-compiled 5D batched matrix product with ``a`` broadcast.

    The returned function fills
    ``c[i, j, k, r, n] = sum_m a[i, j, 0, r, m] * b[i, j, k, m, n]``.
    ``a`` is always read at index 0 along its third axis, so it is reused
    for every ``k`` of ``b``.

    Compared with the earlier version, the row index now addresses axis 3 of
    ``a`` (the axis its loop bound is taken from) rather than axis 2, and
    every output column is computed rather than only column 0. For inputs
    whose last two output axes both have length 1 the result is unchanged.

    Args:
        A: Optional template for the first operand. When given, its shape
            supplies the loop bounds instead of the shape of the runtime
            argument ``a``.
        B: Optional template for the second operand, used the same way.

    Returns:
        ``nb.njit`` applied to the loop implementation; call it as
        ``f(a, b)``.
    """

    def impl(a, b):
        # Axes 0, 1 and 3 of the output are sized from ``a``; axes 2 and 4
        # and the contraction length are sized from ``b``.
        (s1, s2, _, s4, _) = a.shape if A is None else A.shape
        (_, _, t3, t4, t5) = b.shape if B is None else B.shape
        c = np.zeros((s1, s2, t3, s4, t5))
        for i in range(s1):
            for j in range(s2):
                for k in range(t3):
                    for row in range(s4):
                        for col in range(t5):
                            acc = 0.0
                            for m in range(t4):
                                # ``row`` indexes axis 3 of ``a``; axis 2 is
                                # the broadcast axis and stays at 0.
                                acc += a[i, j, 0, row, m] * b[i, j, k, m, col]
                            c[i, j, k, row, col] = acc
        return c

    return nb.njit(impl)

# --- Benchmark 1: 4D array times 2D matrix --------------------------------
print(" matmul_4d ".center(60, "-"))
a = np.random.rand(1, 10, 2, 2)
b = np.random.rand(2, 4)

# *_1 takes its loop bounds from the arguments at call time; *_2 is built
# with a and b as shape templates.
matmul_4d_1 = generate_matmul_4d()
matmul_4d_2 = generate_matmul_4d(a, b)

# Both variants must agree with np.matmul.
# NOTE(review): these first calls presumably also trigger Numba's
# compilation, keeping it out of the timings below - confirm.
assert np.allclose(np.matmul(a, b), matmul_4d_1(a, b))
assert np.allclose(np.matmul(a, b), matmul_4d_2(a, b))

# %timeit is an IPython magic; run this script in IPython/Jupyter.
%timeit np.matmul(a, b)
%timeit matmul_4d_1(a, b)
%timeit matmul_4d_2(a, b)

# --- Benchmark 2: 5D product, a has length 1 on axis 1 --------------------
print(" matmul_check_collisions ".center(60, "-"))
# a and b are rebound here; the functions built above keep the earlier arrays
# as their templates.
a = np.random.rand(1, 1, 10, 2, 2)
b = np.random.rand(1, 23, 10, 2, 1)

matmul_check_collisions_1 = generate_matmul_check_collisions()
matmul_check_collisions_2 = generate_matmul_check_collisions(a, b)

# Both variants must agree with np.matmul (which broadcasts a over axis 1).
assert np.allclose(np.matmul(a, b), matmul_check_collisions_1(a, b))
assert np.allclose(np.matmul(a, b), matmul_check_collisions_2(a, b))

%timeit np.matmul(a, b)
%timeit matmul_check_collisions_1(a, b)
%timeit matmul_check_collisions_2(a, b)

# --- Benchmark 3: 5D product, a has length 1 on axis 2 --------------------
print(" matmul_check_drivable ".center(60, "-"))
a = np.random.rand(1, 83, 1, 1, 2)
b = np.random.rand(1, 83, 41, 2, 1)

matmul_check_drivable_1 = generate_matmul_check_drivable()
matmul_check_drivable_2 = generate_matmul_check_drivable(a, b)

# Both variants must agree with np.matmul (which broadcasts a over axis 2).
assert np.allclose(np.matmul(a, b), matmul_check_drivable_1(a, b))
assert np.allclose(np.matmul(a, b), matmul_check_drivable_2(a, b))

%timeit np.matmul(a, b)
%timeit matmul_check_drivable_1(a, b)
%timeit matmul_check_drivable_2(a, b)

# ------------------------ matmul_4d -------------------------
# 2.59 µs ± 46.1 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 1.58 µs ± 23.1 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
# 1.11 µs ± 14.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
# ----------------- matmul_check_collisions ------------------
# 11.5 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 5.75 µs ± 124 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
# 1.91 µs ± 22.4 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
# ------------------ matmul_check_drivable -------------------
# 46.3 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# 34.5 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
# 9.15 µs ± 170 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
``````
2 Likes