I have a batch of M linear systems where the coefficient matrices are stored as a 3-D array A (with shape (M,N,N)) and the right-hand side vectors are stored as a 2-D array b (with shape (M,N)). In plain NumPy, I can solve for the solution vectors by x = np.linalg.solve(A, b), where x is of the shape (M,N). I tried to use this in a numba.jit-ed function, i.e.,
import numpy as np
from numba import jit

@jit
def jit_linsolve(A, b):
    return np.linalg.solve(A, b)
However, when running jit_linsolve(A, b), I get the following:
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function `solve_impl`: File: numba\np\linalg.py: Line 1697.
With argument(s): `(array(float64, 3d, C), array(float64, 2d, C))`:
Rejected as the implementation raised a specific error:
TypingError: np.linalg.solve() only supported on 2-D arrays.
When running jit_linsolve(A[0], b[0]), I get an array with shape (N,).
If I understand correctly, numba.jit simply does not support numpy.linalg.solve() for a batch of linear systems yet. My current idea is to loop over the M systems and solve each one individually, although I do not know how to make that efficient. What other options or workarounds are there for accomplishing this?
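For concreteness, the loop-based idea might look roughly like the sketch below. It assumes floating-point inputs with matching dtypes (as in the error message above). The name batched_solve is hypothetical, and using parallel=True with prange is only a guess at what could make the loop efficient, not something that has been benchmarked here. The try/except fallback is only there so the snippet still runs, without compilation, when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback (assumption, for illustration only): without Numba the
    # function below runs as ordinary, uncompiled Python.
    prange = range

    def njit(*args, **kwargs):
        def wrap(func):
            return func
        return wrap


@njit(parallel=True)
def batched_solve(A, b):
    """Solve A[i] @ x[i] = b[i] for every system i in the batch.

    A has shape (M, N, N) and b has shape (M, N); the result has shape (M, N).
    """
    x = np.empty_like(b)
    # Each iteration is an independent 2-D solve, which Numba does support.
    for i in prange(b.shape[0]):
        x[i] = np.linalg.solve(A[i], b[i])
    return x


# Small usage example with well-conditioned random systems.
M, N = 8, 4
rng = np.random.default_rng(0)
A = rng.random((M, N, N)) + N * np.eye(N)
b = rng.random((M, N))
x = batched_solve(A, b)
print(x.shape)
```

Whether the parallel loop actually helps probably depends on M and N: for small N, the per-call overhead of the solver may dominate, and the underlying LAPACK library may already be using several threads on its own.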