Slower Performance with `numbakit-ode`

I’m currently developing a numerical solver package for a specific class of PDEs. My initial approach used SciPy’s ODE solvers, but that became too slow for 2D/3D problems with fine discretizations. I decided to try Numba and numbakit-ode, but unfortunately, my code ended up running even slower. I’m new to Numba and suspect it’s an implementation issue on my end. I’d appreciate any guidance or suggestions on how to get better performance out of Numba (and numbakit-ode).

Here is my code:

import cmath
import numpy as np
import numba as nb
import nbkode

###########
# FFT and IFFT
###########
# Compiled with ``njit`` like every other kernel in this file: it is called
# from nopython-compiled FFT routines, so it must never rely on an
# object-mode fallback.
@nb.njit
def ilog2(n):
    """Return the integer base-2 logarithm of ``|n|``.

    The value is computed by counting how many right shifts are needed to
    bring ``|n|`` down to zero, i.e. ``floor(log2(|n|))`` for non-zero input.

    Parameters
    ----------
    n : int
        Integer whose magnitude is inspected. Negative values are negated
        first.

    Returns
    -------
    int
        ``floor(log2(|n|))``, or ``-1`` when ``n`` is ``0`` (the loop body
        never runs in that case).
    """
    result = -1
    if n < 0:
        n = -n
    while n > 0:
        n >>= 1
        result += 1
    return result

@nb.njit(fastmath=True)
def reverse_bits(val, width):
    """Return ``val`` with its lowest ``width`` bits in reversed order.

    Bits are peeled off the low end of ``val`` one at a time and pushed onto
    the low end of the accumulator, so the first bit read ends up as the most
    significant of the ``width`` output bits. Bits of ``val`` above ``width``
    are ignored.
    """
    remaining = val
    mirrored = 0
    for _ in range(width):
        lowest_bit = remaining & 1
        mirrored = (mirrored << 1) | lowest_bit
        remaining >>= 1
    return mirrored

@nb.njit(fastmath=True)
def fft_1d_radix2_rbi(arr, direct=True):
    """Iterative radix-2 FFT using a bit-reversed input permutation.

    Parameters
    ----------
    arr : array_like
        Input sequence; it is converted to ``complex128``.
        NOTE: the length is not validated here. The butterfly indexing
        assumes it is a power of two.
    direct : bool, optional
        ``True`` builds the twiddle factors with a negative exponent,
        ``False`` with a positive one. No ``1/n`` scaling is applied in
        either case.

    Returns
    -------
    numpy.ndarray
        New ``complex128`` array of the same length holding the transform.
    """
    data = np.asarray(arr, dtype=np.complex128)
    n = len(data)
    n_bits = ilog2(n)

    # Table of exp(+-2*pi*i*idx/n) for idx = 0..n-1.
    twiddles = np.empty_like(data)
    base = (-2j if direct else 2j) * cmath.pi / n
    for idx in range(n):
        twiddles[idx] = cmath.exp(base * idx)

    # Copy the input in bit-reversed order so the butterflies can run in place.
    out = np.empty_like(data)
    for idx in range(n):
        out[idx] = data[reverse_bits(idx, n_bits)]

    # Radix-2 decimation-in-time butterflies, doubling the block size each pass.
    span = 2
    while span <= n:
        half = span // 2
        stride = n // span
        for start in range(0, n, span):
            for offset in range(half):
                lo = start + offset
                hi = lo + half
                scaled = out[hi] * twiddles[offset * stride]
                upper = out[lo]
                out[hi] = upper - scaled
                out[lo] = upper + scaled
        span *= 2
    return out

@nb.njit(fastmath=True)
def fft_1d_arb(arr, fft_1d_r2=fft_1d_radix2_rbi):
    """1D FFT for arbitrary-length input using the chirp z-transform.

    The input is multiplied by a chirp, zero-padded to a power-of-two length
    ``m``, convolved with the conjugate chirp through ``fft_convolve`` and
    finally multiplied by the chirp again. The division by ``m`` supplies the
    scaling that the inverse pass inside ``fft_convolve`` does not apply.

    Parameters
    ----------
    arr : array_like
        Input sequence; it is converted to ``complex128``.
    fft_1d_r2 : callable, optional
        Radix-2 FFT kernel forwarded to ``fft_convolve``.

    Returns
    -------
    numpy.ndarray
        ``complex128`` array with the same length as the input.
    """
    signal = np.asarray(arr, dtype=np.complex128)
    n = len(signal)
    padded_len = 1 << (ilog2(n) + 2)

    # Chirp sequence exp(-i*pi*idx^2/n).
    chirp = np.empty(n, dtype=np.complex128)
    for idx in range(n):
        chirp[idx] = cmath.exp(-1j * cmath.pi * (idx * idx) / n)

    # Chirp-modulated input, zero-padded to the convolution length.
    modulated = np.zeros(padded_len, dtype=np.complex128)
    modulated[:n] = signal * chirp

    # Conjugate chirp laid out with wrap-around so the convolution is circular.
    kernel = np.zeros_like(modulated)
    kernel[:n] = chirp.conjugate()
    kernel[-n + 1:] = chirp[:0:-1].conjugate()

    convolved = fft_convolve(modulated, kernel, fft_1d_r2)
    return convolved[:n] * chirp / padded_len

@nb.njit(fastmath=True)
def fft_convolve(a_arr, b_arr, fft_1d_r2=fft_1d_radix2_rbi):
    """Convolve two sequences by multiplying their transforms.

    Both inputs are transformed with ``fft_1d_r2``, multiplied element-wise,
    and the product is passed back through ``fft_1d_r2`` with ``False`` as the
    second argument. No normalization is applied here; the caller is
    responsible for any scaling.
    """
    spectrum_a = fft_1d_r2(a_arr)
    spectrum_b = fft_1d_r2(b_arr)
    product = spectrum_a * spectrum_b
    return fft_1d_r2(product, False)

@nb.njit(fastmath=True)
def fft_1d(arr):
    """Forward FFT of ``arr`` treated as a flat 1D sequence.

    The input is flattened with ``np.ravel``. Lengths ``n`` for which
    ``n & (n - 1)`` is zero go to the radix-2 kernel; every other length goes
    through the chirp z-transform routine.
    """
    # Flatten first so inputs with extra dimensions are handled as one sequence.
    flat = np.ravel(arr)
    n = flat.size

    is_power_of_two = (n & (n - 1)) == 0
    if is_power_of_two:
        return fft_1d_radix2_rbi(flat)
    return fft_1d_arb(flat)

@nb.njit(fastmath=True)
def ifft_1d(arr):
    """Inverse FFT of ``arr`` treated as a flat 1D sequence.

    Implemented with the forward transform: conjugate the input, apply
    ``fft_1d``, conjugate the result and divide by the number of samples.
    """
    flat = np.ravel(arr)
    forward_of_conj = fft_1d(np.conjugate(flat))
    return np.conjugate(forward_of_conj) / flat.size

##########
# Equation
##########
def fourier_initial_condition(x, **kwargs):
    """Build a random sum of sine modes sampled on the grid ``x``.

    Parameters
    ----------
    x : numpy.ndarray
        1D grid; ``x[-1] - x[0]`` is used as the reference length of the modes.
    **kwargs
        Optional settings, each read with a default:
        ``N`` (10) number of modes, ``A_range`` ([-0.5, 0.5]) amplitude
        bounds, ``l_range`` ([1, 3]) integer mode-number bounds passed to
        ``np.random.randint``, ``p_range`` ([0, 2*pi]) phase bounds.
        Other keys are ignored.

    Returns
    -------
    numpy.ndarray
        Array with the same length as ``x`` holding the summed modes.
    """
    n_modes = kwargs.get('N', 10)
    amp_bounds = kwargs.get('A_range', [-0.5, 0.5])
    mode_bounds = kwargs.get('l_range', [1, 3])
    phase_bounds = kwargs.get('p_range', [0, 2*np.pi])

    ## Sample parameters (amplitudes, then mode numbers, then phases).
    amplitudes = np.random.uniform(low=amp_bounds[0], high=amp_bounds[1], size=n_modes)
    mode_numbers = np.random.randint(low=mode_bounds[0], high=mode_bounds[1], size=n_modes)
    phases = np.random.uniform(low=phase_bounds[0], high=phase_bounds[1], size=n_modes)

    # One row per mode, one column per grid point; rows are summed at the end.
    span = x[-1] - x[0]
    angles = 2 * np.pi * mode_numbers[:, np.newaxis] * x / span + phases[:, np.newaxis]
    modes = amplitudes[:, np.newaxis] * np.sin(angles)

    return modes.sum(axis=0)

class BurgersEquation():
    """Pseudo-spectral setup for Burgers' equation solved with nbkode.

    All members are static: the class only groups the domain description,
    the initial condition, the right-hand side and the solver call.
    """

    ## Constants for solving the Burgers' equation.
    # SOLVER_METHOD is not read anywhere in this class; `solve` always
    # instantiates nbkode.RungeKutta23.
    SOLVER_METHOD = 'RK23'
    RTOL = 1e-8
    ATOL = 1e-9

    @staticmethod
    def domain_params():
        """Return the default time/space domain description as a dict."""
        return dict(
            ts=0.0,
            te=10.0,
            xs=0,
            xe=2*np.pi,
            dt=0.2,
            nx=2048,
        )

    @staticmethod
    def extra_params():
        """Return the extra equation parameters as a one-element tuple."""
        return (0.1, )

    @staticmethod
    def ic(x):
        """Return a random Fourier-series initial condition on the grid ``x``."""
        return fourier_initial_condition(
            x,
            N=20,
            A_range=[-0.5, 0.5],
            l_range=[3, 6],
            p_range=[0, 2*np.pi],
        )

    @staticmethod
    def rhs(t, u, k):
        """Evaluate ``-u * u_x + 0.1 * u_xx`` with spectral derivatives.

        ``t`` is unused. ``k`` holds the wavenumbers that multiply the
        transform of ``u``. The coefficient 0.1 is hard-coded here; it is the
        same value that ``extra_params`` returns.
        """
        ## FFT of u.
        spectrum = fft_1d(u)

        ## First derivative of u with respect to x.
        u_x = ifft_1d(1j * k * spectrum).real

        ## Second derivative of u with respect to x.
        u_xx = ifft_1d(-k**2 * spectrum).real

        return -u * u_x + 0.1 * u_xx

    @staticmethod
    def solve(T, X, ic, p, traj=True):
        """Integrate from ``T[0]`` and return the solution transposed.

        Parameters
        ----------
        T : sequence of float
            Output times. When ``traj`` is false only the first and last
            entries are requested from the solver.
        X : numpy.ndarray
            Spatial grid; its length and first spacing define the wavenumbers.
        ic : numpy.ndarray
            Initial state passed to the solver as ``y0``.
        p : tuple
            Accepted but not used by this method.
        traj : bool, optional
            Whether to request every time in ``T`` (default) or only the
            end points.

        Returns
        -------
        numpy.ndarray
            Transpose of the state array returned by ``solver.run``.
        """
        ## Solve Burgers' equation.
        times = T if traj else [T[0], T[-1]]

        ## Wavenumbers.
        spacing = X[1] - X[0]
        k = 2 * np.pi * np.fft.fftfreq(len(X), d=spacing)

        # NOTE(review): a new solver is built on every call; if nbkode compiles
        # the right-hand side per solver instance, that cost recurs each time
        # -- confirm by profiling.
        solver = nbkode.RungeKutta23(
            BurgersEquation.rhs,
            t0=times[0],
            y0=ic,
            params=(k,),
            rtol=BurgersEquation.RTOL,
            atol=BurgersEquation.ATOL,
        )
        _, states = solver.run(times)

        return states.T

import time

# Benchmark: for each grid size, solve the problem several times with fresh
# random initial conditions and report the mean wall-clock time per solve.
burgers_equation = BurgersEquation()

# The domain description does not change between iterations, so read it once.
domain = burgers_equation.domain_params()

# Solves per grid size. The first one is an untimed warm-up, so the average
# must be taken over the remaining runs only.
N_RUNS = 10
N_TIMED_RUNS = N_RUNS - 1

for nx in [64, 128, 256, 512, 1024]:  # , 2048]:
    total_time = 0.0

    ## Set domain.
    T = np.arange(domain['ts'], domain['te'] + domain['dt'], domain['dt'])
    X = np.linspace(domain['xs'], domain['xe'], nx)

    ## Get diffusion parameter and initial condition.
    p = burgers_equation.extra_params()
    for run in range(N_RUNS):
        ic = burgers_equation.ic(X)

        ## Solve.
        if run == 0:
            # Warm-up call for this grid size; deliberately not timed.
            u = burgers_equation.solve(T, X, ic, p)
            continue

        # perf_counter is monotonic and has higher resolution than time.time,
        # which makes it the appropriate clock for measuring short intervals.
        start = time.perf_counter()
        u = burgers_equation.solve(T, X, ic, p)
        total_time += time.perf_counter() - start

    # NOTE(review): each solve builds a new nbkode solver, so any per-solver
    # compilation cost is still included in the timed runs -- confirm by
    # profiling before comparing against SciPy.
    print(f'nx: {nx}, time taken: {total_time / N_TIMED_RUNS} seconds.')