Limiting number of Numba threads in MPI hybrid program

Hi everyone,

I’m trying to write a hybrid MPI+OpenMP Python program that uses the OpenMP threading layer of Numba together with mpi4py/schwimmbad.

As an example, I wrote the code below, which computes the Euclidean pairwise distances between two matrices XA and XB, parallelized using numba.njit(parallel=True) and numba.prange.
I then create a pool of N workers that compute N different pairwise distance matrices. I would like each of the N workers to use 4 threads.

I tried this code:

import numba as nb
import numpy as np
import timeit
from numba import config
from schwimmbad import MPIPool
from mpi4py import MPI
import os

config.THREADING_LAYER = 'omp'

@nb.njit(parallel=True)
def euclidean_cdist(XA, XB):
    n_samples_A, n_samples_B = XA.shape[0], XB.shape[0]
    cdist                    = np.empty((n_samples_A, n_samples_B), dtype=np.float64)

    for i in nb.prange(n_samples_A):
        for j in range(n_samples_B):
            dist = 0.0
            for k in range(XA.shape[1]):
                diff = XA[i, k] - XB[j, k]
                dist += diff * diff
            cdist[i, j] = np.sqrt(dist)
    return cdist

def worker(task):
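    # Ask Numba for 4 threads in this worker (the call discussed below)
    nb.set_num_threads(4)
    # Random input data
    XA = np.random.rand(5000, 3)
    XB = np.random.rand(5000, 3)
    # Warm-up call so that JIT compilation is excluded from the timings
    euclidean_cdist(XA, XB)
    times = timeit.repeat(lambda: euclidean_cdist(XA, XB), number=1, repeat=100)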
    std   = np.std(times) * 1000  # Convert to milliseconds
    mean  = np.mean(times) * 1000 # Convert to milliseconds
    return mean, std

if __name__ == "__main__":
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    pool = MPIPool()

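    # Worker ranks wait for tasks from the master and exit once it is done
    if not pool.is_master():
        pool.wait()
        raise SystemExit(0)

    tasks = [123 + rank for _ in range(size)]  # Create a list of tasks
    results = pool.map(worker, tasks)
    pool.close()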

    if rank == 0:
        # Print the mean and std of KDE profiling
        for i, (mean, std) in enumerate(results):
            print("Task {}: Mean: {:.2f} ms, Std: {:.2f} ms".format(i, mean, std))

However, if I launch it with the command mpirun -n 2 python I get the error
ValueError: The number of threads must be between 1 and 2.

The API reference says that the argument of set_num_threads must be between 1 and NUMBA_NUM_THREADS, and that by default NUMBA_NUM_THREADS is equal to multiprocessing.cpu_count().
So it seems that cpu_count() sees only 2 CPUs, one for each process launched by MPI.
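
One way to check this guess is to have each rank print how many CPUs exist on the machine and how many it is actually allowed to run on. The following is only a minimal sketch using the standard library: visible_cpus is a hypothetical helper name, os.sched_getaffinity exists only on Linux, and the idea that the MPI launcher pins each rank to a subset of cores is an assumption to be confirmed by the output, not an established fact.

```python
import multiprocessing
import os


def visible_cpus():
    """Return (total logical CPUs, CPUs this process may run on)."""
    total = multiprocessing.cpu_count()
    # sched_getaffinity is Linux-only; elsewhere fall back to the total.
    if hasattr(os, "sched_getaffinity"):
        usable = len(os.sched_getaffinity(0))
    else:
        usable = total
    return total, usable


if __name__ == "__main__":
    total, usable = visible_cpus()
    print(f"pid {os.getpid()}: total CPUs = {total}, usable CPUs = {usable}")
```

Launched under mpirun with the same options as the real program, a usable count smaller than the total would suggest that the limit comes from process binding rather than from Numba itself.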

The only workaround I found is to replace the line nb.set_num_threads(4) with os.environ["NUMBA_NUM_THREADS"] = str(4), but the code does not scale as expected.
The average time for each worker is the same for nthreads = 1, 2, 4, 8, and equals the time taken using only two threads.
So it seems that os.environ["NUMBA_NUM_THREADS"] = str(4) has no effect, and I have to explicitly export the number of Numba threads in the shell.

Even in this case, there is no speedup in the MPI program, while in a serial version of the program there is a speedup when increasing the number of exported threads.

Does anyone know how to assign a specific number of Numba threads to each MPI process?