How to solve a batch of linear systems using numpy.linalg.solve()?

I have a batch of M linear systems where the coefficient matrices are stored as a 3-D array A (with shape (M,N,N)) and the right-hand side vectors are stored as a 2-D array b (with shape (M,N)). In plain NumPy, I can solve for the solution vectors with x = np.linalg.solve(A, b), where x has shape (M,N). I tried to do the same in a numba.jit-ed function, i.e.,

import numpy as np
from numba import jit

@jit
def jit_linsolve(A, b):
    return np.linalg.solve(A, b)

However, when running jit_linsolve(A, b), I get the following:

There are 2 candidate implementations:
           - Of which 2 did not match due to:
           Overload in function `solve_impl`: File: numba\np\linalg.py: Line 1697.
                With argument(s): `(array(float64, 3d, C), array(float64, 2d, C))`:
              Rejected as the implementation raised a specific error:
                  TypingError: np.linalg.solve() only supported on 2-D arrays.

When running jit_linsolve(A[0], b[0]), I get an array with shape (N,).
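For reference, here is a plain-NumPy sketch (no Numba involved) of the shape bookkeeping: one system A[i], b[i] gives a solution of shape (N,), and the batched call gives (M,N), which is the same as looping over the 2-D systems. Note that since NumPy 2.0, np.linalg.solve treats b as a vector only when b is exactly 1-D, so a (M,N) right-hand side is no longer read as M stacked vectors; adding a trailing axis with b[..., None] and dropping it again with [..., 0] behaves the same on old and new NumPy. The sizes and random data are arbitrary.

```python
import numpy as np

M, N = 4, 3
np.random.seed(0)
A = np.random.random((M, N, N))  # stack of M coefficient matrices
b = np.random.random((M, N))     # one right-hand side vector per system

# Batched solve: b[..., None] turns each RHS into an (N, 1) matrix,
# and [..., 0] removes that extra axis from the result.
x_batched = np.linalg.solve(A, b[..., None])[..., 0]

# The same computation as an explicit loop over 2-D systems.
x_loop = np.empty_like(b)
for i in range(M):
    x_loop[i] = np.linalg.solve(A[i], b[i])  # each call returns shape (N,)

print(x_batched.shape)  # (4, 3)
```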

If I understand correctly, numba.jit simply does not support numpy.linalg.solve() for a batch of linear systems yet. My current idea is to loop over the batch, although I do not know how to make that efficient. What are other options/workarounds for accomplishing this?

Hi @christian-cahig,

You are correct: Numba’s np.linalg support is limited to arrays with at most two dimensions. Support for higher-dimensional (stacked) inputs just hasn’t been implemented in Numba yet; IIRC this is partly because there is no performance gain in doing so, it’s only a convenience. In both Numba and NumPy the core computational work is done by LAPACK, whose routines also only operate on 2-D arrays. I think NumPy’s np.linalg.solve uses a generalized ufunc (gufunc) to loop over the stacked 2-D systems and pass each one to LAPACK, which is what provides the higher-dimensional support, but fundamentally it is just a loop. To do something similar in Numba, this sort of thing should work for your use case:

import numpy as np
from numba import njit

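@njit
def jit_linsolve(A, b):
    # Read the sizes from the input instead of relying on the globals m and n,
    # which Numba would freeze as compile-time constants.
    m, n = b.shape
    ret = np.empty((m, n))
    for i in range(m):
        # Each A[i], b[i] pair is a plain 2-D system, which Numba supports.
        ret[i, :] = np.linalg.solve(A[i], b[i])
    return ret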

m = 5
n = 3

np.random.seed(0)
A = np.random.random((m, n, n))
b = np.random.random((m, n))

X = jit_linsolve(A, b)
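# b[..., None] makes each RHS an (n, 1) matrix so the call means the same
# thing before and after NumPy 2.0; [..., 0] drops that extra axis again.
Y = np.linalg.solve(A, b[..., None])[..., 0]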
np.testing.assert_allclose(X, Y)

I’m not sure that there’s anything trivially available to make it more efficient. Each one of the n x n systems in A needs to be solved against the corresponding RHS in b, and the bulk of the work will be done in LAPACK whether NumPy or Numba drives the loop.
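Because every system is independent, a batched result (whether it comes from the Numba loop or from NumPy) can also be checked without relying on NumPy's batched solve semantics: compute the residual A[i] @ x[i] - b[i] for all i in one vectorized call. A minimal sketch, where batched_residuals is a hypothetical helper name and np.einsum does the per-system matrix-vector products; here x is built with a plain Python loop, but jit_linsolve(A, b) could be substituted.

```python
import numpy as np

def batched_residuals(A, x, b):
    """Hypothetical helper: return r[i] = A[i] @ x[i] - b[i], shape (M, N)."""
    # 'mij,mj->mi': for each batch index m, multiply matrix A[m] by vector x[m]
    return np.einsum("mij,mj->mi", A, x) - b

np.random.seed(0)
m, n = 5, 3
A = np.random.random((m, n, n))
b = np.random.random((m, n))

# Solve each 2-D system separately and stack the (n,) solutions into (m, n).
x = np.stack([np.linalg.solve(A[i], b[i]) for i in range(m)])

r = batched_residuals(A, x, b)
print(np.abs(r).max())  # should be at round-off level
```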

Hope this helps?


Yes, that helps a lot! Thank you!

No problem, glad you’ve got something working.