Numba JIT becoming slower with List

Using Numba's JIT makes my function slower, not faster.

import numpy as onp
import numba
import numba.typed as nbt
import time
import multiprocessing as mp


def f_pyo(mats, exps, masks):
    # One group per (matrix, exponent, mask) triple.
    num_groups = max((len(mats), len(exps), len(masks)))
    results = nbt.List()
    for i in range(num_groups):
        # Preallocate one output buffer per input matrix.
        results.append(onp.zeros_like(mats[i]))
    for i in numba.prange(num_groups):
        # If the mask is set, raise the matrix to its exponent;
        # otherwise power 0 yields the identity.
        if masks[i]:
            results[i][:] = onp.linalg.matrix_power(mats[i], onp.sum(exps[i]))
        else:
            results[i][:] = onp.linalg.matrix_power(mats[i], 0)
    return results
f_jit1 = numba.njit(f_pyo, parallel=False)
f_jit2 = numba.njit(f_pyo, parallel=True)


nprng = onp.random.RandomState(47)
A = nbt.List()
exp = nbt.List()
mask = nbt.List()
for _ in range(30):
    # Random square matrix size, exponent and on/off mask per group.
    size = nprng.randint(300, 500 + 1)
    p = nprng.randint(1, 16 + 1)
    m = nprng.randint(0, 1 + 1)
    A.append(nprng.uniform(0.0, 1.0, (size, size)).astype(onp.float64))
    exp.append(onp.array(p).astype(onp.int64))
    mask.append(onp.array(m).astype(onp.bool_))


f_pyo(A, exp, mask)  # warm-up call
start = time.time()
f_pyo(A, exp, mask)
elapsed = time.time() - start
print("PyObject:", elapsed)


f_jit1(A, exp, mask)  # warm-up call (triggers compilation)
start = time.time()
f_jit1(A, exp, mask)
elapsed = time.time() - start
print("JIT     :", elapsed)


assert mp.cpu_count() > 4
numba.set_num_threads(4)
f_jit2(A, exp, mask)  # warm-up call (triggers compilation)
start = time.time()
f_jit2(A, exp, mask)
elapsed = time.time() - start
print("JIT-4   :", elapsed)

Here is the SLURM script that allocates resources for the job:

#!/bin/bash

#SBATCH --job-name=benchmark
#SBATCH --output=benchmark.stdout.txt
#SBATCH --error=benchmark.stderr.txt
#SBATCH --partition=...
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
#SBATCH --gres=gpu:0
#SBATCH --cpus-per-task=5
#SBATCH --gpus-per-task=0

python benchmark.py

I got:

PyObject: 0.06118154525756836
JIT     : 0.0636894702911377
JIT-4   : 0.09839320182800293

Clearly, the JIT version is slower, and the parallel JIT version is slower still.
I have tried several times, and the results are similar.

However, I am not sure why this is happening.
(It seems that numba.typed.List, the if/else branching, and numpy.linalg.matrix_power are all possible culprits.)

Hi,

sorry right away for what is going to be a slightly sketchy answer; I just stumbled over your post and wanted to give a few suggestions:

  1. Avoid lists if you can; at the very least, your exponents and masks can be stored in single linear arrays.
  2. I think that, depending on your installation, the linalg library already uses multithreading internally. So if you parallelise on top of that, it could mess with thread scheduling.
  3. A significant amount of time is spent on memory allocation: the runtime dropped by some 25% for me when I reused the same results list over and over (by passing preallocated arrays as an argument to the function).
  4. If memory is abundant, you could overallocate results and turn it into a single multidimensional array, then only use the rows and columns you need. (Memory contiguity could also be important here.) A rough sketch combining these ideas follows below.
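
For illustration, a minimal sketch of points 1, 3 and 4 combined could look like the following. This is an assumption on my part, not your original code: the name f_prealloc, the plain 1-D exps/masks arrays, and the caller-owned padded 3-D results buffer are all illustrative choices, and A, exp, mask are the typed lists from your script (the matrices stay in a typed list because they have different shapes).

import numpy as onp
import numba

@numba.njit(parallel=True)
def f_prealloc(mats, exps, masks, results):
    # exps and masks are plain 1-D arrays (point 1); results is a single
    # preallocated, padded 3-D array that the caller reuses (points 3 and 4).
    for i in numba.prange(len(mats)):
        n = mats[i].shape[0]
        if masks[i]:
            results[i, :n, :n] = onp.linalg.matrix_power(mats[i], exps[i])
        else:
            # Power 0 yields the identity matrix.
            results[i, :n, :n] = onp.linalg.matrix_power(mats[i], 0)
    return results

# Flatten the typed lists from the original script into arrays, allocate
# the padded result buffer once, and reuse it on every call:
exps_arr = onp.array([int(e) for e in exp], dtype=onp.int64)
masks_arr = onp.array([bool(m) for m in mask], dtype=onp.bool_)
max_n = max(m.shape[0] for m in A)
results = onp.zeros((len(A), max_n, max_n))
f_prealloc(A, exps_arr, masks_arr, results)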

Hope some of these can help you solve your performance issues :slight_smile: Best of luck!

Thank you for your advice.
Following it, I first tried to simplify my benchmark.
But I still think there must be some genuine problem (or the parallel documentation is missing some requirement), rather than improper usage on my side.
I took the example from the parallel documentation, put it into the aforementioned skeleton, and tried again.
Again, I find that parallel JIT clearly makes things worse.

from numba import njit, prange, set_num_threads
import numpy as onp
import multiprocessing as mp
import time

def f_init(n, d, *args):
    # Fresh random data for each run; extra args (iterations) pass through.
    nprng = onp.random.RandomState(42)
    Y = nprng.uniform(0.0, 1.0, (n,))
    X = nprng.uniform(0.0, 1.0, (n, d))
    w = nprng.uniform(0.0, 1.0, (d,))
    return Y, X, w, *args

def f_pyo(Y, X, w, iterations):
    # Logistic-regression gradient steps, taken from the Numba parallel docs.
    for i in range(iterations):
        w -= onp.dot(((1.0 / (1.0 + onp.exp(-Y * onp.dot(X, w))) - 1.0) * Y), X)
    return w

@njit
def f_jit1(Y, X, w, iterations):
    for i in range(iterations):
        w -= onp.dot(((1.0 / (1.0 + onp.exp(-Y * onp.dot(X, w))) - 1.0) * Y), X)
    return w
# \\ f_jit1 = njit(f_pyo)

@njit(parallel=True)
def f_jitn(Y, X, w, iterations):
    for i in range(iterations):
        w -= onp.dot(((1.0 / (1.0 + onp.exp(-Y * onp.dot(X, w))) - 1.0) * Y), X)
    return w
# \\ f_jitn = njit(f_pyo, parallel=True)

assert mp.cpu_count() > 4
set_num_threads(4)

def bench(f, name, log, *args):
    # Time one call of f on freshly initialised data.
    # (Renamed from `timeit` to avoid shadowing the stdlib module.)
    supp = f_init(*args)
    time_start = time.time()
    f(*supp)
    time_end = time.time()
    elapsed = time_end - time_start
    if log:
        print("{:>12s}: {:>12s}".format(name, "{:.6f}".format(elapsed)))

bench(f_pyo, "PyObject", False, 2, 3, 2)  # warm-up / compilation on tiny inputs
bench(f_jit1, "JIT", False, 2, 3, 2)
bench(f_jitn, "JIT-4", False, 2, 3, 2)
# \\ bench(f_pyo, "PyObject", False, 50, 256, 100)
# \\ bench(f_jit1, "JIT", False, 50, 256, 100)
# \\ bench(f_jitn, "JIT-4", False, 50, 256, 100)

for t in range(1, 10 + 1):
    print(t)
    bench(f_pyo, "PyObject", True, 50, 256, 100)
    bench(f_jit1, "JIT", True, 50, 256, 100)
    bench(f_jitn, "JIT-4", True, 50, 256, 100)

Using the same allocation (5 CPUs) and running several times: JIT without parallel is mostly worse and sometimes better (possibly within random noise), while JIT with parallel is clearly worse.

1
    PyObject:     0.003073
         JIT:     0.018109
       JIT-4:     1.758944
2
    PyObject:     0.003056
         JIT:     0.020413
       JIT-4:     2.116485
3
    PyObject:     0.053066
         JIT:     0.004492
       JIT-4:     1.834853
4
    PyObject:     0.010350
         JIT:     0.017284
       JIT-4:     2.033234
5
    PyObject:     0.010561
         JIT:     0.011928
       JIT-4:     1.907281
6
    PyObject:     0.026609
         JIT:     0.012688
       JIT-4:     1.770410
7
    PyObject:     0.023375
         JIT:     0.008984
       JIT-4:     1.569419
8
    PyObject:     0.049018
         JIT:     0.009801
       JIT-4:     2.246718
9
    PyObject:     0.003440
         JIT:     0.039175
       JIT-4:     1.967270
10
    PyObject:     0.010617
         JIT:     0.009853
       JIT-4:     1.670893

If I do not limit resources (72 CPUs) and run directly, JIT without parallel is always better, while JIT with parallel is still worse.

1
    PyObject:     0.039851
         JIT:     0.009891
       JIT-4:     0.123545
2
    PyObject:     0.018576
         JIT:     0.012999
       JIT-4:     0.117343
3
    PyObject:     0.056102
         JIT:     0.017828
       JIT-4:     0.108394
4
    PyObject:     0.071355
         JIT:     0.020207
       JIT-4:     0.110594
5
    PyObject:     0.133202
         JIT:     0.012661
       JIT-4:     0.131159
6
    PyObject:     0.065640
         JIT:     0.031049
       JIT-4:     0.101577
7
    PyObject:     0.043489
         JIT:     0.007528
       JIT-4:     0.128778
8
    PyObject:     0.065448
         JIT:     0.006988
       JIT-4:     0.129540
9
    PyObject:     0.046704
         JIT:     0.013894
       JIT-4:     0.127429
10
    PyObject:     0.071391
         JIT:     0.007116
       JIT-4:     0.162058

I think I have now excluded the typed list and memory allocation as factors, and it is very likely that the problem is silent thread contention with NumPy.
But if Numba's threads conflict with NumPy's even for such simple operations, doesn't that make the parallel functionality meaningless?
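
For reference, one way to inspect which thread pools are actually loaded in the process is the third-party threadpoolctl package. This is just a sketch, not something I ran above:

import threadpoolctl

# Lists every BLAS/OpenMP pool loaded in the process, with its size;
# useful for spotting a BLAS pool competing with Numba's threads.
for pool in threadpoolctl.threadpool_info():
    print(pool["user_api"], pool.get("internal_api"), pool["num_threads"])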

Heja,

sorry, I am a bit pressed for time at the moment, so I cannot look at your example in detail and can only hurl suggestions your way:

  1. Don’t use time.time for benchmarking. Use something like time.perf_counter or Python's timeit module from the standard library; they are more appropriate. (A sketch follows after this list.)
  2. Something about the times you get feels dodgy. There are enormous run-time differences for what should (at first glance) be the exact same computation, and some of the timings feel inversely proportional to each other.
  3. Without having double-checked, I still think your standard Python function might run in parallel anyway, since it uses the heavily optimised NumPy (which in turn calls even more optimised LAPACK and BLAS in places, iirc).
  4. Worst case, your Python function then uses all threads while you limit Numba to only 4.
  5. Except for the iteration loop, your functions only call vectorised NumPy expressions. Since those are already optimised and compiled, no big speed-up is generally to be expected here (but that is not to say it is okay for Numba to be significantly slower; that should not happen, I think).
  6. Somewhat contradicting your initial example, this function is not embarrassingly parallel. The iterations have to run sequentially, in order (as you already intuitively noticed when you did not use prange). The only parallelisation that can happen here is within the array / linear algebra functions.
  7. I am a bit alarmed by your “import numpy as onp”: are you using some more obscure library that alters numpy, or something like jax (which I know nothing about, other than having the feeling that it could be quite invasive)?
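
To make points 1, 3 and 4 concrete, a benchmark could look roughly like the sketch below. The helper time_best is hypothetical, and threadpoolctl is an assumption on my part; it is one common way to cap NumPy's BLAS/LAPACK thread pools (setting OMP_NUM_THREADS before starting Python is another). f_init and f_pyo are taken from your script.

import time
from threadpoolctl import threadpool_limits  # pip install threadpoolctl

def time_best(f, *args, repeats=10):
    # Best-of-N timing with a monotonic, high-resolution clock (point 1).
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        f(*args)
        best = min(best, time.perf_counter() - t0)
    return best

Y, X, w, iters = f_init(50, 256, 100)
# Cap NumPy's BLAS/LAPACK pools so they cannot fight Numba's threads
# for the same cores (points 3 and 4). f_pyo mutates w in place, but
# each call does the same amount of work, so the timing stays fair.
with threadpool_limits(limits=1, user_api="blas"):
    print("PyObject, 1 BLAS thread:", time_best(f_pyo, Y, X, w, iters))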

Hopefully I can take a closer look later, but for now this is all I can offer. It might also be worth having a look around this forum for benchmarking practices; there have been some discussions in the past.

Cheers!