Thanks for the quick and detailed reply!
As you suggested, I have tried switching between parallel and serial execution simply by setting the number of threads. It works really well, and there is apparently no overhead when the parallel-compiled function is run with a single thread, as the following test shows:
import numpy as np
import numba as nb
from numba import jit, njit, prange
# Original number of Numba threads (in my case 8 with 4 actual CPU cores).
nb_threads = nb.config.NUMBA_NUM_THREADS
# Function that is not jitted.
def foo(x):
    # Do something ...
    y = bar(x)
    # Do something ...
    return y
# Jit parallel version.
@njit(parallel=True)
def bar(x):
    n = len(x)
    y = np.zeros(n)
    # Parallel loop.
    for i in prange(n):
        # Inner-loop.
        for j in range(i, i + 100):
            # Some "heavy" computation.
            y[i] += np.cos(j + x[i])
    return y
# Jit serial version.
# This takes the underlying Python function from the jitted function bar
# and wraps it with Numba jit again, this time with parallel mode disabled.
bar_ser = jit(bar.py_func, parallel=False)
# Test array.
x = np.arange(10000)
# Ensure the jit functions have been compiled before timing.
bar(x=x)
bar_ser(x=x)
# Timing tests below:
# foo() with 8 threads (4 actual cores).
nb.set_num_threads(nb_threads)
%timeit -n 200 foo(x=x)
# 5.78 ms ± 377 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)
# bar() with 8 threads (4 actual cores).
nb.set_num_threads(nb_threads)
%timeit -n 200 bar(x=x)
# 5.59 ms ± 172 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)
# bar() with 1 thread.
nb.set_num_threads(1)
%timeit -n 200 bar(x=x)
# 24.3 ms ± 315 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)
# bar() non-parallel jit version.
%timeit -n 200 bar_ser(x=x)
# 25.7 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)
I have not run a statistical test on the time difference between bar() with 1 thread and bar_ser() to check whether it is significant, but I forced 200 loops in timeit, which probably makes even this small difference statistically significant. So it looks like bar_ser() might be a tiny bit slower than bar() run with 1 thread, which is quite interesting. I would have expected the opposite, with the parallel machinery adding a small runtime penalty when only 1 thread is being used.
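For what it's worth, here is a minimal sketch of how such a test could be done, collecting the raw per-run timings with timeit.repeat and comparing them with Welch's t-test (the scipy dependency, the repeat/number counts, and the use of a t-test on timing data are just assumptions for illustration, since timing distributions are not really normal):
import timeit
from scipy import stats
# Raw timings (seconds per 200 calls) for the parallel version run with 1 thread.
nb.set_num_threads(1)
t_par1 = timeit.repeat(lambda: bar(x=x), number=200, repeat=20)
# Raw timings for the non-parallel jit version.
t_ser = timeit.repeat(lambda: bar_ser(x=x), number=200, repeat=20)
# Welch's t-test, which does not assume equal variances.
t_stat, p_value = stats.ttest_ind(t_par1, t_ser, equal_var=False)
print(t_stat, p_value)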
Switching Jit Functions
We can also replace the original bar() function with one compiled using different Numba jit parameters. For example, if the original function was jitted with parallel=True as in the example above, but we really want parallel=False, we can do the following, and the new version will then be used by foo() as well:
# Update the function bar with non-parallel jit version.
bar = jit(bar.py_func, parallel=False)
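This works because Python resolves the global name bar inside foo() at call time, so foo() automatically picks up the new dispatcher. A quick sanity check, reusing the setup from the timing test above, is to time foo() again and compare with the bar_ser() result:
# foo() now calls the non-parallel version of bar, without any change to foo itself.
%timeit -n 200 foo(x=x)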
If the function bar() is instead located in some external package named ping, in a sub-module pong, then we can replace it as follows. This should be run before any other imports from the ping package, especially before importing the bar function itself (e.g. from ping.pong import bar), because a name bound by such an import keeps pointing at the original dispatcher and will not see the replacement:
import ping.pong
# Get the original Python function for bar.
py_func = ping.pong.bar.py_func
# Update the bar function in ping.pong with a non-parallel jit version.
ping.pong.bar = jit(py_func, parallel=False)
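To make the import-order caveat concrete, here is a small hypothetical sketch (ping/pong are placeholder names as above): a name bound by a direct import before the patch still refers to the original parallel dispatcher and is unaffected by the reassignment:
# Somewhere, BEFORE the patch is applied:
from ping.pong import bar as bar_direct
# Apply the patch afterwards.
import ping.pong
ping.pong.bar = jit(ping.pong.bar.py_func, parallel=False)
# Lookups through the module attribute see the patched, non-parallel version.
ping.pong.bar(x)
# But the previously bound name still calls the original parallel version.
bar_direct(x)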
This is a hacky solution that seems to work, but it is probably considered bad form to replace functions in external packages like this.