Parallelizing execution of numba-typed functions

I am trying to speed up a solver framework for some physics problems which have to be solved in large quantities.

I managed to obtain a certain level of abstraction using numba types, where one can define functions and solvers using static signatures, aiming a bit for robustness but most importantly for eager compilation to avoid a large overhead every time something is imported.

This is a dummy version of what I am trying to achieve:

import multiprocessing
import time

from enum import IntEnum
from typing import Callable

import numpy as np
# Doing this to get the typed subpackage without importing some spurious stub file.
import numba
import numba.typed
import numba as nb


class Spec(IntEnum):
    """Closed set of solver variants; an IntEnum so numba can type it via IntEnumMember."""
    foo = 0
    bar = 1


### Creating templates using cfunc to obtain correct signatures.
# This enables me to define f, df and the solver "dynamically" and the multi-solvers will recognize them.

@nb.cfunc(nb.f8[:](nb.f8[:]), cache=True)
def f_dummy(x: np.ndarray) -> np.ndarray:
    """Identity template for f; only nb.typeof(f_dummy) is used, to build signatures."""
    return x


@nb.cfunc(nb.f8[:, :](nb.f8[:]), cache=True)
def df_dummy(x: np.ndarray) -> np.ndarray:
    """Jacobian template for df (identity matrix); only its nb.typeof is used."""
    return np.eye(x.size)


# Typed solver-parameter dict (str -> float); its nb.typeof feeds the signatures below.
params: dict[str, float] = numba.typed.Dict.empty(
    key_type=numba.types.unicode_type, value_type=numba.types.float64
)
params["a"] = 0.0

# Signature of a single-problem solver:
#   (x0, F, DF, params, spec) -> (solution, int, int)
# The two ints appear to be (exit code, iteration count) judging by how
# sequential_solver unpacks them — confirm against the real solvers.
SOLVER_SIGNATURE = numba.types.Tuple((numba.f8[:], numba.int_, numba.int_))(
    numba.f8[:],
    nb.typeof(f_dummy),
    nb.typeof(df_dummy),
    nb.typeof(params),
    nb.types.IntEnumMember(Spec, nb.int_),
)


@nb.cfunc(SOLVER_SIGNATURE, cache=True)
def solver_dummy(
    x: np.ndarray,
    F: Callable[[np.ndarray], np.ndarray],
    DF: Callable[[np.ndarray], np.ndarray],
    solver_params: dict[str, float],
    spec: Spec,
) -> tuple[np.ndarray, int, int]:
    """Template solver; only nb.typeof(solver_dummy) is used, to build MULTI_SOLVER_SIGNATURE."""
    return x, 1, 2


### Sequential and parallel execution of a solver.

# Signature of a batch solver:
#   (X0, F, DF, solver, params, spec) -> (results, exit codes, iteration counts)
# Return arrays are declared C-contiguous (`::1`); X0 may be any f8 2-D layout.
MULTI_SOLVER_SIGNATURE = numba.types.Tuple(
    (numba.f8[:, ::1], numba.int_[::1], numba.int_[::1])
)(
    numba.f8[:, :],
    nb.typeof(f_dummy),
    nb.typeof(df_dummy),
    nb.typeof(solver_dummy),
    nb.typeof(params),
    nb.types.IntEnumMember(Spec, nb.int_),
)

@nb.njit(MULTI_SOLVER_SIGNATURE, cache=True)
def sequential_solver(
    X0: np.ndarray,
    F: Callable[[np.ndarray], np.ndarray],
    DF: Callable[[np.ndarray], np.ndarray],
    solver: Callable[
        [
            np.ndarray,
            Callable[[np.ndarray], np.ndarray],
            Callable[[np.ndarray], np.ndarray],
            dict[str, float],
            Spec,
        ],
        tuple[np.ndarray, int, int],
    ],
    solver_params: dict[str, float],
    spec: Spec,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Apply `solver` to every row of X0, one problem at a time.

    Returns (results, exitcodes, num_iter). A row whose solve raises is
    reported with exit code 5, -1 iterations and NaN results.
    """
    n = X0.shape[0]
    result = np.zeros_like(X0)
    num_iter = np.zeros(n, dtype=np.int_)
    # np.full avoids the temporary array that `np.ones(n) * 5` allocates.
    exitcodes = np.full(n, 5, dtype=np.int_)

    for i in range(n):
        try:
            res_i, e_i, n_i = solver(X0[i], F, DF, solver_params, spec)
        except Exception:
            # exitcodes[i] already holds the failure code 5 from initialization.
            num_iter[i] = -1
            result[i, :] = np.nan
        else:
            exitcodes[i] = e_i
            num_iter[i] = n_i
            result[i] = res_i

    return result, exitcodes, num_iter


@nb.njit(MULTI_SOLVER_SIGNATURE, cache=True, parallel=True, nogil=True)
def parallel_solver(
    X0: np.ndarray,
    F: Callable[[np.ndarray], np.ndarray],
    DF: Callable[[np.ndarray], np.ndarray],
    solver: Callable[
        [
            np.ndarray,
            Callable[[np.ndarray], np.ndarray],
            Callable[[np.ndarray], np.ndarray],
            dict[str, float],
            Spec,
        ],
        tuple[np.ndarray, int, int],
    ],
    solver_params: dict[str, float],
    spec: Spec,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Apply `solver` to every row of X0 using a numba prange over rows.

    Unlike sequential_solver there is no per-row exception handling: each row
    writes to a disjoint slice, so the loop is embarrassingly parallel only
    while its body stays exception-free.
    """
    n = X0.shape[0]
    result = np.empty_like(X0)
    num_iter = np.empty(n, dtype=np.int_)
    # np.full avoids the temporary array that `np.ones(n) * 5` allocates.
    exitcodes = np.full(n, 5, dtype=np.int_)

    # NOTE(review): kept for reference — wrapping the prange in try/except was
    # reported to destroy the parallel speed-up (presumably by defeating the
    # parallel transform); keep exception handling out of the prange region.
    # try:
    for i in numba.prange(n):
        res_i, e_i, n_i = solver(X0[i], F, DF, solver_params, spec)
        exitcodes[i] = e_i
        num_iter[i] = n_i
        result[i] = res_i
    # except:
    #     print('Parallel solver threw exception, falling back to sequential solver.')
    #     return sequential_solver(X0, F, DF, solver, solver_params, spec)

    return result, exitcodes, num_iter


### Concrete functions and solvers.

@nb.njit(nb.f8[:](nb.f8[:]))
def f(x):
    # Concrete f (identity); same numba signature as the f_dummy template.
    # No cache=True: stands in for a dynamically created function.
    return x


@nb.njit(nb.f8[:, :](nb.f8[:]))
def df(x):
    # Concrete df (identity Jacobian); same numba signature as df_dummy.
    # No cache=True: stands in for a dynamically created function.
    return np.eye(x.size)


@nb.njit(SOLVER_SIGNATURE, cache=True)
def solver(X0, F, DF, params, spec):
    """Dummy workload solver: evaluates F and DF 200 times, returns (X0, 1, 2)."""
    a = params["a"]
    for i in range(200):
        # fi/dfi are discarded — the loop exists only to generate work.
        if spec == Spec.foo:
            fi = F(X0) + a
        elif spec == Spec.bar:
            fi = F(X0) - a
        else:
            raise ValueError(f"Unknown specification {spec}")
        dfi = DF(X0)
    return X0, 1, 2


# Benchmark: solve N random problems sequentially, then in parallel, and
# report wall times and the resulting speed-up.

N = 1000000
X = np.random.random((N, 10))


print("Starting sequential solver")
# time.perf_counter() is the monotonic, high-resolution timer recommended for
# benchmarking; time.time() can have coarse resolution (notably on Windows).
s_seq = time.perf_counter()
sequential_solver(X, f, df, solver, params, Spec.foo)
s_seq = time.perf_counter() - s_seq
print("Starting parallel solver")
s_par = time.perf_counter()
parallel_solver(X, f, df, solver, params, Spec.bar)
s_par = time.perf_counter() - s_par

print(f"Sequential time: {s_seq:.4f} s")
print(f"Parallel time: {s_par:.4f} s")
print(f"Speed-up: {s_seq / s_par:.2f}x")
print("Number of cores:", multiprocessing.cpu_count())
print("---eof---")

All works as intended. My issue is that the speed-up obtained by parallelization is a bit disappointing. I do not reach a significant speed-up reflecting the number of available CPU cores, while the OS performance monitor indicates full CPU load.

On my machine, I get a speed-up between 2x and 6x (different f, df and solver), while having 16 cores.

My questions:

  1. Am I missing some obvious limitation?
  2. Why is the speed-up so volatile?

Some things to mention:

  • I am using cache=True to mark the parts which I want to be precompiled and available. f and df (no caching) represent functions created dynamically.
  • I noticed the performance degradation only recently, some time after I updated numba and numpy. Unfortunately, I forgot which versions I was originally on.
  • Previously, the try-except-clause inside the parallel_solver did not cause a total performance loss. But now, the parallel execution is as slow as the sequential one.
  • parallel_solver.parallel_diagnostics() throws an error where something that is not supposed to be None gets passed (that issue is already reported on GH).

Versions:
numba 0.61.0
numpy: 2.1.3
OS: Ubuntu and Windows