Minimal reproducible example (MRE):
import numpy as np
from numba import jit, f8
def function( A1, A2, A3, B1, B2, B3, const1, const2, const3 ):
    """Sum of squared per-component differences, each scaled by its own constant.

    Plain NumPy baseline: computes ((A1-B1)/const1)^2 + ((A2-B2)/const2)^2
    + ((A3-B3)/const3)^2 elementwise and returns the resulting array.
    """
    d1 = (A1 - B1) / const1
    d2 = (A2 - B2) / const2
    d3 = (A3 - B3) / const3
    return d1 ** 2 + d2 ** 2 + d3 ** 2
# Eagerly-compiled numba version of `function`: the explicit signature
# (2-D float64 arrays in/out, scalar float64 constants) forces compilation
# at decoration time instead of on first call.
@jit(f8[:,:]( f8[:,:], f8[:,:], f8[:,:], f8[:,:], f8[:,:], f8[:,:], f8, f8, f8 ), nopython=True, parallel=False)
def function_jit( A1, A2, A3, B1, B2, B3, const1, const2, const3 ):
    """Numba-jitted copy of `function`; same whole-array NumPy expression."""
    # Identical arithmetic to the plain-NumPy baseline so results can be
    # compared with np.allclose below.
    x = ((A1-B1)/const1)**2 + ((A2-B2)/const2)**2 + ((A3-B3)/const3)**2
    return x
# Explicit-loop numba version with fastmath=True (allows reassociation of
# float ops, enabling better vectorization at the cost of strict IEEE order).
@jit(f8[:,:]( f8[:,:], f8[:,:], f8[:,:], f8[:,:], f8[:,:], f8[:,:], f8, f8, f8 ), nopython=True, fastmath=True, parallel=False)
def function_jit_loop_fastmath( A1, A2, A3, B1, B2, B3, const1, const2, const3 ):
    """Element-wise loop version of `function`, jitted with fastmath.

    Bug fix: the original used ``nb.prange``, but only ``jit`` and ``f8``
    were imported from numba, so the name ``nb`` is undefined and the
    function fails to compile.  Since the decorator sets parallel=False,
    ``prange`` would behave exactly like ``range`` anyway, so plain
    ``range`` is used here.  (To actually parallelize, set parallel=True
    and import/use ``numba.prange``.)
    """
    x = np.empty_like(A1)
    for i in range(A1.shape[0]):
        for j in range(A1.shape[1]):
            x[i, j] = (((A1[i, j] - B1[i, j]) / const1) ** 2
                       + ((A2[i, j] - B2[i, j]) / const2) ** 2
                       + ((A3[i, j] - B3[i, j]) / const3) ** 2)
    return x
# --- Benchmark fixtures -----------------------------------------------------
# n x n float64 test arrays; A2/A3 are random scalings of A1, the B arrays
# are the transposes, and const1..3 are random positive scalars.
n = 200
A1 = np.arange(n**2, dtype=np.float64).reshape(n, n)
A2 = A1 * np.random.uniform()  # NOTE: unseeded RNG — values differ per run
A3 = A1 * np.random.uniform()
B1 = A1.T  # transposes give non-trivial (A - B) differences off the diagonal
B2 = A2.T
B3 = A3.T
const1 = np.random.uniform()*10
const2 = np.random.uniform()*10
const3 = np.random.uniform()*100
# Time each variant (IPython %timeit magic), then verify all three produce
# numerically identical results on the same inputs.
%timeit function( A1, A2, A3, B1, B2, B3, const1, const2, const3 )
%timeit function_jit( A1, A2, A3, B1, B2, B3, const1, const2, const3 )
%timeit function_jit_loop_fastmath( A1, A2, A3, B1, B2, B3, const1, const2, const3 )
# Correctness checks: jitted versions must match the NumPy baseline.
print(np.allclose( function( A1, A2, A3, B1, B2, B3, const1, const2, const3 ), function_jit( A1, A2, A3, B1, B2, B3, const1, const2, const3 )))
print(np.allclose( function( A1, A2, A3, B1, B2, B3, const1, const2, const3 ), function_jit_loop_fastmath( A1, A2, A3, B1, B2, B3, const1, const2, const3 )))
416 µs ± 2.29 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
624 µs ± 3.11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
261 µs ± 4.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
True
True
After reading this answer, I rewrote the function as an explicit element-wise loop. Enabling fastmath=True
also seems to help. Together these changes cut about 38% off the original runtime. Is it possible to reduce the runtime further?