Hey @fishbotics,
it seems that numba has no implementation of the np.dot function for arrays with more than two dimensions.
# from numba's linalg.py
# https://github.com/numba/numba/blob/main/numba/np/linalg.py
# NOTE: this is an excerpt; the original indentation was lost and the quoted
# function is cut off after the code generator shown below.
# Only the case where both operands are array types is shown here.
def dot_2_impl(name, left, right):
if isinstance(left, types.Array) and isinstance(right, types.Array):
@intrinsic
def _impl(typingcontext, left, right):
# Number of dimensions of each operand; the code generator below
# dispatches on this pair.
ndims = (left.ndim, right.ndim)
def _dot2_codegen(context, builder, sig, args):
ensure_blas()
with make_contiguous(context, builder, sig, args) as (sig, args):
# Only the four combinations of 1-D and 2-D operands have a branch.
# 2-D x 2-D
if ndims == (2, 2):
return dot_2_mm(context, builder, sig, args)
# 2-D x 1-D
elif ndims == (2, 1):
return dot_2_mv(context, builder, sig, args)
# 1-D x 2-D
elif ndims == (1, 2):
return dot_2_vm(context, builder, sig, args)
# 1-D x 1-D
elif ndims == (1, 1):
return dot_2_vv(context, builder, sig, args)
else:
# Any other ndim pair (e.g. a 3-D operand) has no code path here.
raise AssertionError('unreachable')
As a workaround, you could iterate over the batches and apply the supported 2D operations to each one, provided you know the dimensions of the input arrays.
import numba as nb
import numpy as np
def multiply(a, b):
    """Return ``np.dot(a, b)``; the pure-NumPy reference result."""
    product = np.dot(a, b)
    return product
@nb.jit(nopython=True)
def batch_multiply(a, b):
    """Multiply every matrix of a 3D stack by the same vector.

    Works around ``np.dot`` not being usable on 3D inputs inside jitted
    code by looping over the first axis and doing one 2D @ 1D product
    per batch entry.

    Args:
        a: 3D array of shape ``(batches, rows, cols)``.
        b: 1D array whose length matches the last axis of ``a``.

    Returns:
        Array of shape ``(batches, rows)`` with the dtype of ``a``, where
        ``res[i]`` is ``a[i] @ b``.

    Raises:
        ValueError: If ``a`` is not 3D or ``b`` is not 1D.
    """
    if a.ndim != 3:
        raise ValueError("Input 'a' must be a 3D array")
    # Reject everything that is not exactly 1D (including 0-D), so the
    # check agrees with the error message.
    if b.ndim != 1:
        raise ValueError("Input 'b' must be a 1D array")

    # The last axis is contracted away by the product, so it is not needed
    # for the output shape.
    n_batches, n_rows, _ = a.shape
    res = np.empty((n_batches, n_rows), dtype=a.dtype)
    for i in range(n_batches):
        # 2D @ 1D is one of the supported dot-product combinations.
        res[i] = a[i] @ b
    return res
# Sanity check: the batched kernel should agree with the NumPy reference.
b = np.ones((4,))
a = np.random.uniform(low=0.0, high=1.0, size=(10, 3, 4))
res_nb = batch_multiply(a, b)
res_np = multiply(a, b)
np.allclose(res_np, res_nb)
# Benchmark (IPython/Jupyter only: `%timeit` is a magic, not plain Python).
# Problem size: BATCHES matrices of shape (ROWS, COLS), all multiplied by
# one vector of length COLS.
BATCHES = 1000
ROWS = 1000
COLS = 50
# 1000 * 1000 * 50 float64 values, i.e. roughly 400 MB of input data.
a = np.random.uniform(size=(BATCHES, ROWS, COLS))
b = np.ones(COLS)
%timeit np.dot(a, b)
%timeit batch_multiply(a, b)
# Reported timings, presumably in the order of the two %timeit calls above
# (np.dot first, then batch_multiply):
# 49 ms ± 566 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 24.4 ms ± 731 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)