I am trying out different ways to solve a linear system using numba. As of now I am using scipy.linalg.cho_factor and cho_solve, because they are faster than the njit-ed version of np.linalg.solve. I have had success using numba on a bunch of other parts of my code, and would love to improve the solve as well.
(Apologies if this has already been discussed elsewhere on the Discourse; I could not find a good post on it after some searching.)
How can I improve the solve examples below using numba?
Context
I want to run solve(K, data) fast, approximately 2e9 times, where K.shape = (n, n), data.shape = (n,), and n varies between 1 and 400. K and data have different values in every run. I have a cluster of remote workers available, each with 1 core, that I distribute the calculations to; with scipy.linalg this takes more than 2 hours. I would love to speed this up.
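To time the solvers over the whole size range rather than a single matrix, test problems could be generated along these lines. This is only a sketch: it assumes K is symmetric positive definite (which the Cholesky route requires anyway), and the helper name make_spd_system is made up for illustration.

```python
import numpy as np


def make_spd_system(n, rng):
    """Return a random symmetric positive definite K (n, n) and data (n,)."""
    A = rng.standard_normal((n, n))
    # A @ A.T is positive semidefinite; adding n * I makes K strictly positive definite.
    K = A @ A.T + n * np.eye(n)
    data = rng.standard_normal(n)
    return K, data


rng = np.random.default_rng(0)
sizes = rng.integers(1, 401, size=5)  # n drawn from 1..400 inclusive
systems = [make_spd_system(int(n), rng) for n in sizes]
```

Each (K, data) pair in systems can then be passed to any of the solve functions being compared.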
Example
import numpy as np
from numba import njit
from scipy.linalg import cho_factor, cho_solve

# Usually I use the options below, and cache functions on the remote worker-side,
# ahead of the computation.
def options():
    return dict(
        parallel=False,  # since only 1 core per worker
        fastmath=True,
        cache=True,
        nogil=True,
        error_model="numpy",
    )

@njit(**options())
def numpy_solve(K, data):
    return np.linalg.solve(K, data)

@njit(**options())
def solve_cholesky(K, data):
    # Factor K = L @ L^H, then solve L y = data and L^H x = y.
    LNumba = cholesky_numba(K)
    y = np.linalg.solve(LNumba, data)
    return np.linalg.solve(LNumba.T.conj(), y)
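
@njit(**options())
def cholesky_numba(K):
    # Hand-written Cholesky factorization; returns the lower-triangular factor L.
    n = K.shape[0]
    L = np.zeros_like(K)
    for i in range(n):
        for j in range(i + 1):
            s = 0.0  # float accumulator, so the type does not change inside the loop
            for k in range(j):
                s += L[i, k] * L[j, k]
            if i == j:
                L[i, j] = (K[i, i] - s) ** 0.5
            else:
                L[i, j] = (K[i, j] - s) / L[j, j]
    return L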

def scipysolve(K, data):
    c_and_lower = cho_factor(K, lower=True, overwrite_a=True, check_finite=False)
    return cho_solve(c_and_lower, data, check_finite=False)

%timeit f0 = numpy_solve(K, data)
%timeit f1 = solve_cholesky(K, data)
%timeit f2 = scipysolve(K, data)
23.2 µs ± 3.07 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
49.6 µs ± 2.98 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
10.1 µs ± 111 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
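One variation that might be worth timing: solve_cholesky above calls np.linalg.solve twice on triangular factors, and np.linalg.solve runs a general LU-based solve that does not exploit the triangular structure. The sketch below replaces both calls with explicit forward and back substitution. It has not been benchmarked here, it assumes a real float64, symmetric positive definite K, and the name cholesky_solve_subst is made up. The small njit fallback is only there so the sketch also runs where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback shim: lets the sketch run (slowly) as plain Python without numba.
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func


@njit(nogil=True)
def cholesky_solve_subst(K, data):
    """Solve K x = data, assuming K is real, symmetric and positive definite."""
    n = K.shape[0]
    L = np.zeros((n, n), dtype=np.float64)
    # Cholesky factorization: K = L @ L.T with L lower triangular.
    for i in range(n):
        for j in range(i + 1):
            s = K[i, j]
            for k in range(j):
                s -= L[i, k] * L[j, k]
            if i == j:
                L[i, j] = np.sqrt(s)
            else:
                L[i, j] = s / L[j, j]
    # Forward substitution: solve L y = data.
    y = np.empty(n, dtype=np.float64)
    for i in range(n):
        s = data[i]
        for k in range(i):
            s -= L[i, k] * y[k]
        y[i] = s / L[i, i]
    # Back substitution: solve L.T x = y.
    x = np.empty(n, dtype=np.float64)
    for i in range(n - 1, -1, -1):
        s = y[i]
        for k in range(i + 1, n):
            s -= L[k, i] * x[k]
        x[i] = s / L[i, i]
    return x
```

It can be timed the same way as the others, e.g. %timeit f3 = cholesky_solve_subst(K, data), after one warm-up call so that compilation time is excluded.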
Arrays used, for reproducibility
# Just copy-pasting an example in here, to make it reproducible.
data = np.array([-1.83533745e-02, -2.62924084e-02, -4.46773166e-03, 2.48877247e-02,
3.36381269e-02, -1.34949865e-02, -2.57439008e-02, 1.41652167e-02,
3.41128634e-03, -2.29770657e-02, -1.70633180e-02, -3.16427878e-02,
-1.29482122e-02, 4.46694804e-02, -1.40248402e-06, -1.10435528e-03,
1.49894427e-02, 2.48672250e-02])
K = np.array([[ 1.88679408e-03, 1.59684967e-03, 1.33447234e-03,
1.15446881e-03, 1.07275824e-03, 1.70223626e-04,
1.61994314e-04, 1.51883781e-04, 1.44565016e-04,
0.00000000e+00, 1.12002441e-05, 3.66292098e-05,
4.10665050e-05, 4.95428828e-06, 5.06540388e-05,
4.52545932e-05, 3.00509432e-05, 2.60088220e-05],
[ 1.59684967e-03, 1.87990962e-03, 1.56633850e-03,
1.34747409e-03, 1.23710933e-03, 1.91337105e-04,
1.83811650e-04, 1.77607485e-04, 1.69435316e-04,
1.11976194e-05, 0.00000000e+00, 3.14312422e-05,
3.15812479e-05, -8.26493526e-07, 5.02506560e-05,
4.57856347e-05, 3.38848683e-05, 2.92687498e-05],
[ 1.33447234e-03, 1.56633850e-03, 1.85988039e-03,
1.58418356e-03, 1.38911730e-03, 2.23152978e-04,
2.15480031e-04, 2.08326339e-04, 1.99154811e-04,
3.65952793e-05, 3.14094873e-05, 0.00000000e+00,
9.65482459e-07, 8.13753464e-06, 4.64415663e-05,
4.35303679e-05, 3.39330403e-05, 2.90651199e-05],
[ 1.15446881e-03, 1.34747409e-03, 1.58418356e-03,
1.83985842e-03, 1.51937560e-03, 2.57253488e-04,
2.49319809e-04, 2.28152631e-04, 2.19356268e-04,
4.09994186e-05, 3.15370472e-05, 9.64798959e-07,
0.00000000e+00, 4.92499208e-05, 5.45643572e-05,
5.08831811e-05, 3.11473273e-05, 2.65390476e-05],
[ 1.07275824e-03, 1.23710933e-03, 1.38911730e-03,
1.51937560e-03, 1.86880224e-03, 2.41340451e-04,
2.40780764e-04, 2.24390938e-04, 2.18382847e-04,
4.95123547e-06, -8.26177854e-07, 8.14006066e-06,
4.93001100e-05, 0.00000000e+00, 8.97243250e-05,
8.14007318e-05, 4.37606403e-05, 3.86739545e-05],
[ 1.70223626e-04, 1.91337105e-04, 2.23152978e-04,
2.57253488e-04, 2.41340451e-04, 1.53351813e-03,
1.32960595e-03, 3.00014346e-04, 3.05442260e-04,
4.99018110e-05, 4.95160225e-05, 4.57943154e-05,
5.38420166e-05, 8.84463921e-05, 0.00000000e+00,
-1.08927502e-05, -1.86346107e-04, -1.94558012e-04],
[ 1.61994314e-04, 1.83811650e-04, 2.15480031e-04,
2.49319809e-04, 2.40780764e-04, 1.32960595e-03,
1.54340857e-03, 3.19948671e-04, 3.29223469e-04,
4.46057170e-05, 4.51397226e-05, 4.29459961e-05,
5.02356654e-05, 8.02830498e-05, -1.08984108e-05,
0.00000000e+00, -1.85542909e-04, -1.97926982e-04],
[ 1.51883781e-04, 1.77607485e-04, 2.08326339e-04,
2.28152631e-04, 2.24390938e-04, 3.00014346e-04,
3.19948671e-04, 1.75840993e-03, 1.53005684e-03,
2.99102281e-05, 3.37341060e-05, 3.38054618e-05,
3.10522053e-05, 4.35825842e-05, -1.88269389e-04,
-1.87360536e-04, 0.00000000e+00, 5.74482718e-06],
[ 1.44565016e-04, 1.69435316e-04, 1.99154811e-04,
2.19356268e-04, 2.18382847e-04, 3.05442260e-04,
3.29223469e-04, 1.53005684e-03, 1.76364514e-03,
2.58923593e-05, 2.91445196e-05, 2.89617996e-05,
2.64634414e-05, 3.85245183e-05, -1.96606482e-04,
-1.99907038e-04, 5.74600888e-06, 0.00000000e+00],
[ 0.00000000e+00, 1.11976194e-05, 3.65952793e-05,
4.09994186e-05, 4.95123547e-06, 4.99018110e-05,
4.46057170e-05, 2.99102281e-05, 2.58923593e-05,
1.79289224e-03, 1.44846528e-03, 1.16324154e-03,
1.04458060e-03, 1.01210499e-03, 2.25202867e-04,
1.98037880e-04, -5.25430505e-05, -3.61875293e-05],
[ 1.12002441e-05, 0.00000000e+00, 3.14094873e-05,
3.15370472e-05, -8.26177854e-07, 4.95160225e-05,
4.51397226e-05, 3.37341060e-05, 2.91445196e-05,
1.44846528e-03, 1.78718795e-03, 1.43743698e-03,
1.27650446e-03, 1.17742592e-03, 2.69368408e-04,
2.37871575e-04, -5.39067389e-05, -3.51288038e-05],
[ 3.66292098e-05, 3.14312422e-05, 0.00000000e+00,
9.64798959e-07, 8.14006066e-06, 4.57943154e-05,
4.29459961e-05, 3.38054618e-05, 2.89617996e-05,
1.16324154e-03, 1.43743698e-03, 1.77059676e-03,
1.52637970e-03, 1.27899609e-03, 3.22252555e-04,
2.85694021e-04, -5.48020161e-05, -3.31806425e-05],
[ 4.10665050e-05, 3.15812479e-05, 9.65482459e-07,
0.00000000e+00, 4.93001100e-05, 5.38420166e-05,
5.02356654e-05, 3.10522053e-05, 2.64634414e-05,
1.04458060e-03, 1.27650446e-03, 1.52637970e-03,
1.75401853e-03, 1.39040946e-03, 3.66558052e-04,
3.25109958e-04, -7.11544451e-05, -4.70177726e-05],
[ 4.95428828e-06, -8.26493526e-07, 8.13753464e-06,
4.92499208e-05, 0.00000000e+00, 8.84463921e-05,
8.02830498e-05, 4.35825842e-05, 3.85245183e-05,
1.01210499e-03, 1.17742592e-03, 1.27899609e-03,
1.39040946e-03, 1.77798632e-03, 3.70092071e-04,
3.28740223e-04, -1.05561515e-04, -8.11476293e-05],
[ 5.06540388e-05, 5.02506560e-05, 4.64415663e-05,
5.45643572e-05, 8.97243250e-05, 0.00000000e+00,
-1.08984108e-05, -1.88269389e-04, -1.96606482e-04,
2.25202867e-04, 2.69368408e-04, 3.22252555e-04,
3.66558052e-04, 3.70092071e-04, 1.50146129e-03,
1.28998516e-03, -9.96000586e-05, -8.10580453e-05],
[ 4.52545932e-05, 4.57856347e-05, 4.35303679e-05,
5.08831811e-05, 8.14007318e-05, -1.08927502e-05,
0.00000000e+00, -1.87360536e-04, -1.99907038e-04,
1.98037880e-04, 2.37871575e-04, 2.85694021e-04,
3.25109958e-04, 3.28740223e-04, 1.28998516e-03,
1.50957561e-03, -8.80310616e-05, -7.29915038e-05],
[ 3.00509432e-05, 3.38848683e-05, 3.39330403e-05,
3.11473273e-05, 4.37606403e-05, -1.86346107e-04,
-1.85542909e-04, 0.00000000e+00, 5.74600888e-06,
-5.25430505e-05, -5.39067389e-05, -5.48020161e-05,
-7.11544451e-05, -1.05561515e-04, -9.96000586e-05,
-8.80310616e-05, 1.68665623e-03, 1.44540841e-03],
[ 2.60088220e-05, 2.92687498e-05, 2.90651199e-05,
2.65390476e-05, 3.86739545e-05, -1.94558012e-04,
-1.97926982e-04, 5.74482718e-06, 0.00000000e+00,
-3.61875293e-05, -3.51288038e-05, -3.31806425e-05,
-4.70177726e-05, -8.11476293e-05, -8.10580453e-05,
-7.29915038e-05, 1.44540841e-03, 1.69098207e-03]])