I’m currently developing a numerical solver package for a specific class of PDEs. My initial approach used SciPy’s ODE solvers, but that became too slow for 2D/3D problems with fine discretizations. I decided to try Numba and numbakit-ode, but unfortunately, my code ended up running even slower. I’m new to Numba and suspect it’s an implementation issue on my end. I’d appreciate any guidance or suggestions on how to get better performance out of Numba (and numbakit-ode).
Here is my code:
import cmath
import numpy as np
import numba as nb
import nbkode
###########
# FFT and IFFT
###########
@nb.njit
def ilog2(n):
    """Return floor(log2(|n|)) for a nonzero integer n, and -1 for n == 0.

    Compiled with ``njit`` (nopython mode) like the other helpers in this
    file; it is called from the nopython-compiled FFT routines below.
    """
    result = -1
    if n < 0:
        n = -n
    # Count how many right shifts it takes to reach zero.
    while n > 0:
        n >>= 1
        result += 1
    return result
@nb.njit(fastmath=True)
def reverse_bits(val, width):
    """Return the lowest `width` bits of `val` in reversed order.

    Bits of `val` at positions >= `width` are ignored; a non-positive
    `width` yields 0.
    """
    reversed_val = 0
    bits_left = width
    while bits_left > 0:
        # Shift the accumulator left and append val's current lowest bit.
        reversed_val = (reversed_val << 1) | (val & 1)
        val >>= 1
        bits_left -= 1
    return reversed_val
@nb.njit(fastmath=True)
def fft_1d_radix2_rbi(arr, direct=True):
    """Radix-2 decimation-in-time FFT with bit-reversed input reordering.

    Parameters
    ----------
    arr : 1D array-like
        Input samples, converted to complex128. The length must be a power
        of two (lengths 0 and 1 are accepted as trivial cases).
    direct : bool
        True for the forward transform (kernel exp(-2*pi*i*j*k/n)); False for
        the inverse kernel exp(+2*pi*i*j*k/n) WITHOUT the 1/n normalization.

    Returns
    -------
    numpy.ndarray
        New complex128 array of length n with the transform.

    Raises
    ------
    ValueError
        If the input length is not a power of two (the butterflies below
        would otherwise silently produce a wrong result).
    """
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    if n & (n - 1):
        raise ValueError("fft_1d_radix2_rbi: input length must be a power of two")
    levels = ilog2(n)
    # Twiddle factors. The butterflies only read e_arr[k] with
    # k = (j - i) * step <= (half_size - 1) * (n // size) < n // 2,
    # so only the first n // 2 factors are needed.
    half_n = n // 2
    coeff = (-2j if direct else 2j) * cmath.pi / n
    e_arr = np.empty(half_n, dtype=np.complex128)
    for i in range(half_n):
        e_arr[i] = cmath.exp(coeff * i)
    # Load the input in bit-reversed index order so the in-place butterflies
    # leave the output in natural order.
    result = np.empty_like(arr)
    for i in range(n):
        result[i] = arr[reverse_bits(i, levels)]
    # Radix-2 decimation-in-time butterflies, doubling the sub-transform size.
    size = 2
    while size <= n:
        half_size = size // 2
        step = n // size
        for i in range(0, n, size):
            k = 0
            for j in range(i, i + half_size):
                temp = result[j + half_size] * e_arr[k]
                result[j + half_size] = result[j] - temp
                result[j] += temp
                k += step
        size *= 2
    return result
@nb.njit(fastmath=True)
def fft_1d_arb(arr, fft_1d_r2=fft_1d_radix2_rbi):
    """1D FFT for arbitrary input lengths using the chirp z-transform (Bluestein).

    The length-n DFT is rewritten as a convolution with the chirp
    w_j = exp(-i*pi*j**2/n). The convolution is zero-padded to a power-of-two
    length m and evaluated with the radix-2 kernel `fft_1d_r2`.

    Parameters
    ----------
    arr : 1D array-like
        Input samples, converted to complex128.
    fft_1d_r2 : callable
        Power-of-two FFT kernel called as ``fft_1d_r2(a)`` (forward) and
        ``fft_1d_r2(a, False)`` (unnormalized inverse); defaults to
        `fft_1d_radix2_rbi`.

    Returns
    -------
    numpy.ndarray
        complex128 array of length n holding the forward DFT of `arr`.
    """
    arr = np.asarray(arr, dtype=np.complex128)
    n = len(arr)
    if n <= 1:
        # The DFT of length 0 or 1 is the identity. Handle it up front: for
        # n <= 1 the `coeff[-n + 1:]` assignment below has mismatched lengths.
        return arr.copy()
    # m = 4 * 2**floor(log2(n)) > 2*n, so the circular convolution of length m
    # equals the linear convolution on the first n outputs.
    m = 1 << (ilog2(n) + 2)
    # Chirp exp(-i*pi*j**2/n). It is periodic in j**2 with period 2*n, so
    # reduce j**2 modulo 2*n first: the phase argument stays small and no
    # precision is lost for large n.
    two_n = 2 * n
    e_arr = np.empty(n, dtype=np.complex128)
    for i in range(n):
        e_arr[i] = cmath.exp(-1j * cmath.pi * ((i * i) % two_n) / n)
    result = np.zeros(m, dtype=np.complex128)
    result[:n] = arr * e_arr
    # Conjugate chirp laid out for circular convolution. Slots 0..n-1 hold
    # conj(w_0..w_{n-1}); the last n-1 slots hold the negative offsets
    # -(n-1)..-1, i.e. conj(w_{n-1}..w_1).
    coeff = np.zeros_like(result)
    coeff[:n] = e_arr.conjugate()
    coeff[-n + 1:] = e_arr[:0:-1].conjugate()
    # The inverse transform inside fft_convolve is unnormalized with the
    # default kernel, hence the explicit / m.
    return fft_convolve(result, coeff, fft_1d_r2)[:n] * e_arr / m
@nb.njit(fastmath=True)
def fft_convolve(a_arr, b_arr, fft_1d_r2=fft_1d_radix2_rbi):
    """Circular convolution of two equal-length arrays via the FFT kernel `fft_1d_r2`.

    Multiplies the forward transforms pointwise and applies the kernel's
    inverse (``direct=False``). With the default radix-2 kernel that inverse
    is unnormalized, so the result is len(a_arr) times the circular
    convolution; callers divide by the length themselves.
    """
    a_spec = fft_1d_r2(a_arr)
    b_spec = fft_1d_r2(b_arr)
    product = a_spec * b_spec
    return fft_1d_r2(product, False)
@nb.njit(fastmath=True)
def fft_1d(arr):
    """Forward DFT of `arr` after flattening it to 1D.

    Power-of-two lengths (including 0 and 1) go to the radix-2 kernel; any
    other length goes through the chirp z-transform in `fft_1d_arb`.
    """
    flat = np.ravel(arr)
    size = flat.size
    if size & (size - 1):
        # Not a power of two.
        return fft_1d_arb(flat)
    return fft_1d_radix2_rbi(flat)
@nb.njit(fastmath=True)
def ifft_1d(arr):
    """Normalized inverse DFT of `arr` after flattening it to 1D.

    Uses the identity ifft(x) = conj(fft(conj(x))) / n, so only the forward
    transform `fft_1d` is needed.
    """
    flat = np.ravel(arr)
    spectrum = fft_1d(np.conjugate(flat))
    return np.conjugate(spectrum) / flat.size
##########
# Equation
##########
def fourier_initial_condition(x, **kwargs):
    """Sample a random sum of sine modes on the 1D grid `x`.

    Returns sum_i A_i * sin(2*pi*l_i*x / (x[-1] - x[0]) + p_i). Each mode
    completes an integer number of periods between x[0] and x[-1].

    Keyword arguments (all optional):
        N       -- number of modes (default 10).
        A_range -- [low, high] bounds for the uniform amplitudes A_i
                   (default [-0.5, 0.5]).
        l_range -- [low, high) for the integer wavenumbers l_i. `high` is
                   exclusive, as in np.random.randint (default [1, 3], so
                   l_i is 1 or 2).
        p_range -- [low, high] bounds for the uniform phases p_i
                   (default [0, 2*pi]).

    Draws from NumPy's global random state in the order amplitudes,
    wavenumbers, phases; seed np.random for reproducible output.

    NOTE(review): the period used is x[-1] - x[0], which fits a grid that
    includes both endpoints. An FFT-based solver that treats the len(x)
    samples as one period assumes a period of len(x) * (x[1] - x[0])
    instead. For such a solver the first and last samples then duplicate
    each other, and the field is not exactly periodic.
    """
    n_modes = kwargs.get('N', 10)
    amp_bounds = kwargs.get('A_range', [-0.5, 0.5])
    wave_bounds = kwargs.get('l_range', [1, 3])
    phase_bounds = kwargs.get('p_range', [0, 2*np.pi])

    # Draw order (A, l, p) determines the output under a fixed seed.
    amps = np.random.uniform(low=amp_bounds[0], high=amp_bounds[1], size=n_modes)
    waves = np.random.randint(low=wave_bounds[0], high=wave_bounds[1], size=n_modes)
    phases = np.random.uniform(low=phase_bounds[0], high=phase_bounds[1], size=n_modes)

    span = x[-1] - x[0]
    # One row per mode, one column per grid point; summed over modes.
    modes = amps[:, np.newaxis] * np.sin(
        2 * np.pi * waves[:, np.newaxis] * x / span + phases[:, np.newaxis]
    )
    return np.sum(modes, axis=0)
class BurgersEquation():
    """Viscous Burgers' equation u_t = -u*u_x + nu*u_xx on a 1D grid.

    Spatial derivatives are computed pseudo-spectrally with fft_1d/ifft_1d,
    and time integration uses nbkode.RungeKutta23.
    """

    ## Constants for solving the Burgers' equation.
    # NOTE(review): SOLVER_METHOD is not read by this class; solve() always
    # constructs nbkode.RungeKutta23.
    SOLVER_METHOD = 'RK23'
    RTOL = 1e-8
    ATOL = 1e-9

    @staticmethod
    def domain_params():
        """Return the time window, spatial extent, output time step and grid size."""
        return {
            'ts': 0.0,
            'te': 10.0,
            'xs': 0,
            'xe': 2*np.pi,
            'dt': 0.2,
            'nx': 2048
        }

    @staticmethod
    def extra_params():
        """Return the extra equation parameters as a tuple (nu,), where nu is the viscosity."""
        return (0.1, )

    @staticmethod
    def ic(x):
        """Return a random 20-mode Fourier-series initial condition sampled on grid `x`."""
        return fourier_initial_condition(x, N=20, A_range=[-0.5, 0.5], l_range=[3, 6], p_range=[0, 2*np.pi])

    @staticmethod
    def rhs(t, u, k, nu=0.1):
        """Right-hand side du/dt = -u*u_x + nu*u_xx.

        Parameters
        ----------
        t : float
            Time. Unused, because the equation is autonomous.
        u : 1D real array
            Solution values on the grid.
        k : 1D real array
            Angular wavenumbers in FFT output order, as built in `solve`.
        nu : float
            Viscosity. Defaults to 0.1, matching `extra_params()`, so
            three-argument calls behave as before.
        """
        ## FFT of u.
        u_fft = fft_1d(u)
        ## First derivative of u with respect to x.
        u_x = ifft_1d(1j * k * u_fft).real
        ## Second derivative of u with respect to x.
        u_xx = ifft_1d(-k**2 * u_fft).real
        return -u * u_x + nu * u_xx

    @staticmethod
    def solve(T, X, ic, p, traj=True):
        """Integrate Burgers' equation from `ic` and return the sampled solution.

        Parameters
        ----------
        T : 1D array
            Output times; T[0] is used as the initial time.
        X : 1D array
            Uniform spatial grid (at least two points).
        ic : 1D array
            Initial values of u on X.
        p : sequence
            Extra parameters as returned by `extra_params()`. p[0] is used as
            the viscosity nu; if p is None or empty, nu defaults to 0.1.
        traj : bool
            If False, only T[0] and T[-1] are requested from the solver.

        Returns
        -------
        The transpose of the ``ys`` array returned by nbkode's ``run``.
        Presumably its shape is (len(X), len(T)); TODO confirm nbkode's
        output layout.
        """
        if not traj:
            # Keep an ndarray, like the full T, rather than a Python list.
            T = np.array([T[0], T[-1]])
        ## Wavenumbers.
        # NOTE(review): fftfreq treats the len(X) samples as exactly one period
        # of length len(X) * (X[1] - X[0]). A grid built with
        # np.linspace(..., endpoint=True) therefore implies a period slightly
        # longer than X[-1] - X[0].
        k = 2 * np.pi * np.fft.fftfreq(len(X), d=(X[1] - X[0]))
        nu = float(p[0]) if p is not None and len(p) > 0 else 0.1
        # NOTE(review): the diffusion term contributes eigenvalues of roughly
        # -nu*k**2, so the system becomes stiff as len(X) grows. An explicit
        # method such as RK23 is then likely limited by stability rather than
        # by RTOL/ATOL, which may dominate the run time. Consider an implicit
        # or integrating-factor scheme.
        solver = nbkode.RungeKutta23(BurgersEquation.rhs,
                                     t0=T[0],
                                     y0=ic,
                                     params=(k, nu),
                                     rtol=BurgersEquation.RTOL,
                                     atol=BurgersEquation.ATOL)
        _ts, ys = solver.run(T)
        return ys.T
import time

# Benchmark: for each grid size, run 10 solves from fresh random initial
# conditions. The first solve is a warm-up (presumably absorbing JIT
# compilation) and is excluded from the reported average.
burgers_equation = BurgersEquation()
for nx in [64, 128, 256, 512, 1024]:#, 2048]:
    total_time = 0
    n_timed = 0
    ## Set domain.
    domain = burgers_equation.domain_params()
    T = np.arange(domain['ts'], domain['te'] + domain['dt'], domain['dt'])
    # NOTE(review): linspace includes the endpoint xe. For a periodic spectral
    # grid, endpoint=False is usually intended.
    X = np.linspace(domain['xs'], domain['xe'], nx)
    ## Get diffusion parameter and initial condition.
    p = burgers_equation.extra_params()
    for run in range(10):
        ic = burgers_equation.ic(X)
        ## Solve.
        if run == 0:
            # Warm-up run: not timed.
            u = burgers_equation.solve(T, X, ic, p)
            continue
        # perf_counter is monotonic and high-resolution, which suits timing.
        start = time.perf_counter()
        u = burgers_equation.solve(T, X, ic, p)
        end = time.perf_counter()
        total_time += end - start
        n_timed += 1
    # Average over the timed runs only (9 of the 10).
    print(f'nx: {nx}, time taken: {total_time/n_timed} seconds.')