Batched Dot Products

I’m wondering what the best way is to do a batched matrix-vector multiply in Numba. I’m getting errors with the following code.

import numba as nb
import numpy as np

def multiply(a, b):
    return np.dot(a, b)

@nb.jit(nopython=True)
def nb_multiply(a, b):
    return np.dot(a, b)

a = np.random.uniform(size=(10, 4, 4))
b = np.ones(4)

print(multiply(a, b))
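# The jitted call below fails to compile with a TypingError:
# numba's np.dot does not support arrays with more than two dimensions.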
print(nb_multiply(a, b))

Thanks!

Hey @fishbotics,

It seems that Numba has no implementation of np.dot for arrays with more than two dimensions.

# from numba's linalg.py
# https://github.com/numba/numba/blob/main/numba/np/linalg.py
def dot_2_impl(name, left, right):
    if isinstance(left, types.Array) and isinstance(right, types.Array):
        @intrinsic
        def _impl(typingcontext, left, right):
            ndims = (left.ndim, right.ndim)

            def _dot2_codegen(context, builder, sig, args):
                ensure_blas()

                with make_contiguous(context, builder, sig, args) as (sig, args):
                    if ndims == (2, 2):
                        return dot_2_mm(context, builder, sig, args)
                    elif ndims == (2, 1):
                        return dot_2_mv(context, builder, sig, args)
                    elif ndims == (1, 2):
                        return dot_2_vm(context, builder, sig, args)
                    elif ndims == (1, 1):
                        return dot_2_vv(context, builder, sig, args)
                    else:
                        raise AssertionError('unreachable')
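
To see the same thing from the caller's side, here is a quick sketch (the function name mv is mine): the dimension pairs the excerpt dispatches on compile fine, and only the 3-D input is rejected.

import numba as nb
import numpy as np

@nb.njit
def mv(a, b):
    return np.dot(a, b)

# ndims == (2, 1): supported, dispatches to dot_2_mv
print(mv(np.ones((4, 4)), np.ones(4)))
# ndims == (3, 1): rejected during type inference
# mv(np.ones((2, 4, 4)), np.ones(4))  # raises TypingError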

As a workaround, you could iterate over the batch dimension and apply the supported 2-D operations, provided you know the dimensionality of the input arrays.

import numba as nb
import numpy as np

def multiply(a, b):
    return np.dot(a, b)

@nb.jit(nopython=True)
def batch_multiply(a, b):
    if a.ndim != 3:
        raise ValueError("Input 'a' must be a 3D array")
    if b.ndim != 1:
        raise ValueError("Input 'b' must be a 1D array")
    batches, rows, _ = a.shape
    res = np.empty((batches, rows), dtype=a.dtype)
    for i in range(batches):
        res[i] = a[i] @ b
    return res

a = np.random.uniform(size=(10, 3, 4))
b = np.ones(4)

res_np = multiply(a, b)
res_nb = batch_multiply(a, b)
np.allclose(res_np, res_nb)

BATCHES = 1000
ROWS = 1000
COLS = 50

a = np.random.uniform(size=(BATCHES, ROWS, COLS))
b = np.ones(COLS)

%timeit np.dot(a, b)
%timeit batch_multiply(a, b)
# 49 ms ± 566 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# 24.4 ms ± 731 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
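
Another workaround along the same lines (my own sketch, not from the thread above): since the batch and row axes can be merged, you can flatten to 2-D and let the supported np.dot do all the work in a single BLAS call.

@nb.njit
def batch_multiply_2d(a, b):
    n, m, l = a.shape
    # Merge batch and row axes so the operation becomes a single
    # supported 2-D x 1-D dot, then split the result back.
    a2 = np.ascontiguousarray(a).reshape((n * m, l))
    return np.dot(a2, b).reshape((n, m))

np.allclose(np.dot(a, b), batch_multiply_2d(a, b))  # True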

Hi @fishbotics

Matrix-vector multiplication can always be done the naive way like so:

@nb.njit(fastmath=True, parallel=False)
def nb_multiply(a, b):
    n, m, l = a.shape
    out = np.empty((n, m))
    # With parallel=False the prange below behaves like an ordinary range;
    # set parallel=True to distribute the batch loop across threads.
    for i in nb.prange(n):
        for j in range(m):
            val = 0.0
            for k in range(l):
                val += a[i, j, k] * b[k]
            out[i, j] = val
    return out

It will almost always (perhaps even always) perform best, because you have full control over parallelization and there is no dispatch overhead.
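
If you do want threads, the same kernel can be recompiled with parallel=True, and you can pick the thread count explicitly. A minimal sketch reusing the function above (the name nb_multiply_par is mine):

# Recompile the pure-Python version of the kernel with parallel=True
# so the prange over the batch axis runs on multiple threads.
nb_multiply_par = nb.njit(fastmath=True, parallel=True)(nb_multiply.py_func)

nb.set_num_threads(4)  # explicit control over the thread count
out = nb_multiply_par(a, b)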

Here is also something related that you might find interesting: Help needed to re-implement np.matmul for 4D and 5D matrix - #2 by sschaer


Thanks for the tips @Oyibo and @sschaer!

@sschaer is there a reason you did not use prange in both the outer and inner loops? Is it invalid to use it multiple times?

It is not invalid to do it, but nested prange loops are serialized: only the outermost prange runs in parallel, and any inner one is treated as an ordinary range loop. See the part about loop serialization in the doc: Automatic parallelization with @jit — Numba 0.50.1 documentation
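
A minimal illustration (function name mine): with parallel=True, only the outermost prange is parallelized, and the inner one behaves like a plain range.

import numba as nb
import numpy as np

@nb.njit(parallel=True)
def row_sums(a):
    n, m = a.shape
    out = np.empty(n)
    for i in nb.prange(n):       # runs in parallel
        s = 0.0
        for j in nb.prange(m):   # serialized: behaves like range(m)
            s += a[i, j]
        out[i] = s
    return out

print(row_sums(np.ones((4, 5))))  # [5. 5. 5. 5.]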