# How to solve a batch of linear systems using numpy.linalg.solve()?

I have a batch of `M` linear systems where the coefficient matrices are stored as a 3-D array `A` (with shape `(M,N,N)`) and the right-hand side vectors are stored as a 2-D array `b` (with shape `(M,N)`). In plain NumPy, I can solve for the solution vectors by `x = np.linalg.solve(A, b)`, where `x` is of the shape `(M,N)`. I tried to use this in a `numba.jit`-ed function, i.e.,

```python
import numpy as np
from numba import jit

@jit
def jit_linsolve(A, b):
    return np.linalg.solve(A, b)
```

However, when running `jit_linsolve(A, b)`, I get the following:

```
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function `solve_impl`: File: numba\np\linalg.py: Line 1697.
With argument(s): `(array(float64, 3d, C), array(float64, 2d, C))`:
Rejected as the implementation raised a specific error:
TypingError: np.linalg.solve() only supported on 2-D arrays.
```

Running it on a single system instead, e.g. `jit_linsolve(A[0], b[0])`, works and returns an array with shape `(N,)`.

If I understand correctly, `numba.jit` simply does not support `numpy.linalg.solve()` for a batch of linear systems yet. My current idea is to simply loop through the batches, although I do not know how to make it efficient. What are other options/workarounds for accomplishing this?

You are correct: Numba's `np.linalg` support only handles arrays with at most two dimensions. Support for higher-dimensional (stacked) inputs has not been implemented in Numba yet; IIRC this is partly because it would bring no performance gain, it would just be a convenience. In both Numba and NumPy the core computational work is done by LAPACK, which itself only works on single 2-D systems. I think NumPy's `np.linalg.solve` uses a generalized ufunc (`GUFUNC`) to loop over the leading dimensions and push each 2-D system into the LAPACK routine, which is what gives it the higher-dimensional support, but fundamentally it is just a loop. To do something similar in Numba, this sort of thing should work for your use case:

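```python
import numpy as np
from numba import njit

@njit
def jit_linsolve(A, b):
    # Take the sizes from the inputs rather than from the globals m and n,
    # which Numba would freeze as compile-time constants.
    m = b.shape[0]
    ret = np.empty_like(b)
    # Solve each 2-D system A[i] @ x = b[i] in turn; each call goes to LAPACK.
    for i in range(m):
        ret[i, :] = np.linalg.solve(A[i], b[i])
    return ret

m = 5
n = 3

np.random.seed(0)
A = np.random.random((m, n, n))
b = np.random.random((m, n))

X = jit_linsolve(A, b)
# b[..., None] keeps the stacked NumPy call valid on both NumPy 1.x and 2.x.
Y = np.linalg.solve(A, b[..., None])[..., 0]
np.testing.assert_allclose(X, Y)
```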
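For comparison, here is a plain-NumPy sketch of the point that the stacked call is fundamentally a loop over 2-D systems. It assumes NumPy 2.0 semantics, where `np.linalg.solve` only treats `b` as a vector when `b` is exactly 1-D; an `(M, N)` `b` is no longer read as a stack of vectors, so the right-hand sides are passed as `(N, 1)` columns via `b[..., None]` (this form also works on NumPy 1.x). The names `x_batched` and `x_loop` are just for illustration.

```python
import numpy as np

np.random.seed(0)
m, n = 5, 3
A = np.random.random((m, n, n))
b = np.random.random((m, n))

# Stacked call: b[..., None] turns each right-hand side into an (n, 1) column,
# which NumPy 1.x and 2.x both read as a stack of matrices; [..., 0] drops it again.
x_batched = np.linalg.solve(A, b[..., None])[..., 0]

# The same result from an explicit loop over the individual 2-D systems.
x_loop = np.empty_like(b)
for i in range(m):
    x_loop[i] = np.linalg.solve(A[i], b[i])

np.testing.assert_allclose(x_batched, x_loop)
```

Either way, each system ends up in its own LAPACK call; the stacked form only moves the loop into compiled code.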

I’m not sure that there’s anything trivially available to make it more efficient. Each one of the `n x n` systems in `A` needs to be solved against the corresponding RHS in `b`, and the bulk of the work will be done in LAPACK whether NumPy or Numba drives the loop.

Hope this helps?


Yes, that helps a lot! Thank you!

No problem, glad you’ve got something working.