Switching between parallel / serial mode

Thanks again for making Numba! I am still a big fan, and since my last post I have gotten my entire body covered in Numba-inspired tattoos! (OK, that’s also not true, but it could be!)

My question this time is whether it is possible to switch a Numba JIT function between parallel and serial mode.

The reason this is important to me is that I am using a function foo() from a Python package, which indirectly calls another function bar() in the same package that has been wrapped with Numba's JIT, like this:

from numba import njit, prange

def foo(x):
    # Do something ...
    y = bar(x)
    # Do something ...
    return y
    
@njit(parallel=False)
def bar(x):
    for i in prange(len(x)):
        x[i] += i
    return x

I can call both foo() and bar() from my own code. Sometimes I would like to run bar() in parallel mode and sometimes in serial mode, and I would also like the chosen mode to be used when foo() calls bar().

Is there a way to change the parallel / serial mode of a Jit function after it has been defined?

Perhaps something like this:

# Switch to parallel mode for the bar-function.
bar.parallel = True

# Switch to serial mode for the bar-function.
bar.parallel = False

If this is currently not possible, would it be possible to implement this? Or is there a work-around / hack to make this work?

Thanks!

Numba provides the function set_num_threads for changing the number of threads at runtime. The following demonstrates computing the Mandelbrot set with 1 thread and with many threads.

Demonstration

import numpy as np

from timeit import default_timer as timer
from numba import config, njit, prange, set_num_threads

def create_fractal(min_x, max_x, min_y, max_y, image, iters, parallel=False):
    # Use all configured threads in parallel mode, otherwise a single thread.
    set_num_threads(config.NUMBA_NUM_THREADS if parallel else 1)
    return _fractal(min_x, max_x, min_y, max_y, image, iters)

@njit('u1(f8, f8, i4)')
def mandel(x, y, max_iters):
    """
    http://numba.pydata.org/numba-doc/latest/user/examples.html#mandelbrot
    Given the real and imaginary parts of a complex number,
    determine if it is a candidate for membership in the Mandelbrot
    set given a fixed number of iterations.
    """
    i = 0
    c = complex(x,y)
    z = 0.0j

    for i in range(max_iters):
        z = z * z + c
        if (z.real * z.real + z.imag * z.imag) >= 4:
            return i % 256

    return 255


@njit('u1[:,:](f8, f8, f8, f8, u1[:,:], i4)', parallel=True)
def _fractal(min_x, max_x, min_y, max_y, image, iters):
    height, width = image.shape[:2]
    pixel_size_x = (max_x - min_x) / width
    pixel_size_y = (max_y - min_y) / height

    for y in prange(height):
        imag = min_y + y * pixel_size_y
        for x in range(width):
            real = min_x + x * pixel_size_x
            color = mandel(real, imag, iters)
            image[y,x] = color

    return image


if __name__ == '__main__':
    image = np.empty((720, 1280), dtype=np.uint8)
    min_x, max_x = -0.4391276296940602, -0.4391263235442111
    min_y, max_y = 0.5745835354759210, 0.5745828012134179

    s = timer()
    create_fractal(min_x, max_x, min_y, max_y, image, 2000, False)
    print("1: {:.3f} seconds".format(timer() - s))

    s = timer()
    create_fractal(min_x, max_x, min_y, max_y, image, 2000, True)
    print("m: {:.3f} seconds".format(timer() - s))

Results

$ NUMBA_NUM_THREADS=4 python3 mandel.py
1: 1.578 seconds
m: 0.408 seconds

Thanks for the quick and detailed reply!

As you suggested, I have tried switching between parallel and serial mode simply by setting the number of execution threads. It works really well, and there is apparently no overhead for the 1-thread case compared to a serially compiled version, as the following test shows:

import numpy as np
import numba as nb
from numba import jit, njit, prange

# Original number of Numba threads (in my case 8 with 4 actual CPU cores).
nb_threads = nb.config.NUMBA_NUM_THREADS

# Function that is not jitted.
def foo(x):
    # Do something ...
    y = bar(x)
    # Do something ...
    return y

# Jit parallel version.
@njit(parallel=True)
def bar(x):
    n = len(x)
    y = np.zeros(n)
    
    # Parallel loop.
    for i in prange(n):
        
        # Inner-loop.
        for j in range(i, i + 100):
            # Some "heavy" computation.
            y[i] += np.cos(j + x[i])

    return y

# Jit serial version.
# This takes the underlying Python function from bar
# and wraps it in Numba Jit again while disabling parallel mode.
bar_ser = jit(bar.py_func, parallel=False)

# Test array.
x = np.arange(10000)

# Ensure the jit functions have been compiled before timing.
bar(x=x)
bar_ser(x=x)

# Timing tests below:

# foo() with 8 threads (4 actual cores).
nb.set_num_threads(nb_threads)
%timeit -n 200 foo(x=x)
# 5.78 ms ± 377 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)

# bar() with 8 threads (4 actual cores).
nb.set_num_threads(nb_threads)
%timeit -n 200 bar(x=x)
# 5.59 ms ± 172 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)

# bar() with 1 thread.
nb.set_num_threads(1)
%timeit -n 200 bar(x=x)
# 24.3 ms ± 315 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)

# bar() non-parallel jit version.
%timeit -n 200 bar_ser(x=x)
# 25.7 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 200 loops each)

I have not run a statistical test on the time difference between bar() with 1 thread and bar_ser(), but I forced 200 loops in timeit, which probably makes even this small difference statistically significant. So bar_ser() looks like it might be a tiny bit slower than bar() with 1 thread, which is quite interesting. I would have expected the opposite, as the overhead from the parallel machinery should impose a small runtime penalty when only 1 thread is being used.
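For completeness, such a test could be sketched as follows. This is my own sketch, not part of Numba: it computes Welch's t-statistic on two sets of hypothetical timing samples (in practice, timeit.repeat would supply the real per-run timings):

```python
import math
from statistics import mean, variance

def welch_t(a, b):
    """Welch's t-statistic for two independent samples with unequal variances."""
    return (mean(a) - mean(b)) / math.sqrt(variance(a) / len(a) + variance(b) / len(b))

# Hypothetical per-run timings in seconds, e.g. collected with timeit.repeat.
bar_1thread = [0.0243, 0.0241, 0.0246, 0.0244, 0.0242, 0.0245, 0.0243]
bar_serial  = [0.0257, 0.0256, 0.0258, 0.0255, 0.0257, 0.0258, 0.0256]

t = welch_t(bar_serial, bar_1thread)
print(f"t = {t:.1f}")  # a |t| far above ~2 suggests the gap is not just noise
```

With the small per-run variance that timeit typically reports, even a ~1.4 ms gap yields a large t-statistic, so the observed difference is unlikely to be noise.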

Switching Jit Functions

We can also replace the original bar() function with one compiled using different Numba JIT parameters. For example, if the original function was compiled with parallel=True as in the example above, but we really want parallel=False, we can do the following, and foo() will then use the replacement as well:

# Update the function bar with non-parallel jit version.
bar = jit(bar.py_func, parallel=False)

If the function bar() instead lives in an external package named ping, in the sub-module pong, we can replace it as follows. This may have to run before any other imports from the ping package, especially before importing the bar function itself, or the replacement will not take effect for code that already holds a direct reference to the original:

import ping.pong

# Get the original Python function for bar.
py_func = ping.pong.bar.py_func

# Update the bar function in ping.pong with a non-parallel jit version.
ping.pong.bar = jit(py_func, parallel=False)

This is a hacky solution that seems to work, but it is probably considered bad form to replace functions in external packages like this.