Numba Warm-up Speed With Cached Functions

I have the following reproducible example with an njit Numba function that uses cache=True:

import numpy as np
from numba import njit, prange
import time


@njit(cache=True, parallel=True, fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _rolling_isconstant(a, w):
    # Number of rolling windows of width w that fit in a
    l = a.shape[0] - w + 1
    out = np.empty(l)
    for i in prange(l):
        # Peak-to-peak (max - min) of each window; zero means the window is constant
        out[i] = np.ptp(a[i : i + w])

    return np.where(out == 0.0, True, False)


def rolling_isconstant(a, w):
    # Apply the JIT-compiled kernel along the last axis so 1-D and 2-D inputs both work
    axis = a.ndim - 1
    return np.apply_along_axis(
        lambda a_row, w: _rolling_isconstant(a_row, w), axis=axis, arr=a, w=w
    )


np.random.seed(0)
T = np.random.rand(1_000)
m = 50

start = time.time()
out = rolling_isconstant(T, m)
print(time.time() - start)

start = time.time()
out = rolling_isconstant(T, m)
print(time.time() - start)

As expected, the first call to rolling_isconstant (which includes compilation time) takes approximately 0.6659 seconds, while the second call is much faster at only 0.0003 seconds. The function should also be cached to disk at this point. However, when I execute the same script again, the first call takes 0.1004 seconds and the second call takes 0.0003 seconds. I would've expected the first run to be much faster, but perhaps this is consistent with the time needed for Numba to warm up (i.e., get loaded), a cost that the second call simply doesn't incur?
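For reference, the cost of simply getting Numba loaded can be measured separately from the function-call timings with something like the following rough sketch (exact numbers will vary by machine and environment):

```python
import time

start = time.time()
import numpy   # Numba depends on NumPy
import numba   # importing numba also pulls in llvmlite and loads LLVM
print("numba import time:", time.time() - start)
```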

If so, when a user wants to call the same function multiple times across multiple independent processes, is there also a way to minimize/reduce this warm-up time for the user?


Hi @seanlaw, I checked the profiler and it's quite interesting: 76% of the time is spent in importlib, with over 500 calls to import modules. I'm not sure whether that is right or wrong, but that seems to be where the time goes.
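In case anyone wants to reproduce this, one way to get a similar breakdown is to profile the numba import itself in a fresh interpreter (just a sketch; running `python -X importtime -c "import numba"` from the command line gives a per-module view as well):

```python
import cProfile
import pstats

# Profile what happens when numba is imported for the first time
# (run this in a fresh interpreter where numba has not been imported yet)
cProfile.run("import numba", "import_profile.out")

stats = pstats.Stats("import_profile.out")
stats.sort_stats("cumulative").print_stats(20)  # top 20 entries by cumulative time
```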

Is there a reason why the multi-process case is important, as opposed to the single-process case? In my experience with multiprocessing it's very hard to avoid repeating some work in each process, and it often doesn't matter anyway because that work happens in parallel.

Luk


Thanks, @luk-f-a! Honestly, I'm not sure what the multi-process use case is, except that a STUMPY user asked me to add support for cache=True to all of our njit functions in STUMPY (really, via the wonderful `enable_cache()` suggestion that you had provided me a few months ago!). I agree with your assessment that it's hard to avoid repeating some of that work, and I even suggested that the user simply create a RESTful endpoint wrapped around the desired function, which would incur the warm-up and compilation time only once; data could then be sent to the endpoint via an HTTP POST request. I really only use STUMPY for single long-running processes, so I'm not quite sure how to support the multi-process case in a scalable/meaningful way. Do you happen to have any thoughts on this subject beyond what you've mentioned thus far? Alternatively, is there a way to reduce the importlib calls?
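To make that endpoint idea a bit more concrete, here is a minimal sketch (assuming Flask and a hypothetical my_module containing the rolling_isconstant code from above); the point is just that the import, warm-up, and cache loading are paid once when the process starts, and every POST after that reuses the already-loaded function:

```python
# Hypothetical sketch: a long-running service that pays Numba's warm-up cost once.
import numpy as np
from flask import Flask, jsonify, request

from my_module import rolling_isconstant  # hypothetical module with the function above

app = Flask(__name__)

# Warm up once at startup: this triggers the import, cache lookup,
# and (if needed) compilation for this type signature.
rolling_isconstant(np.random.rand(100), 10)

@app.route("/isconstant", methods=["POST"])
def isconstant():
    payload = request.get_json()
    a = np.asarray(payload["a"], dtype=np.float64)
    w = int(payload["w"])
    # The function is already compiled/loaded, so only the computation cost is paid here.
    return jsonify(rolling_isconstant(a, w).tolist())

if __name__ == "__main__":
    app.run(port=8000)
```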

Hi @seanlaw, @luk-f-a,

Thanks for profiling this, @luk-f-a. I think this makes sense: there's some cost to getting a new Python process to the point where it can start doing the work requested, which means importing modules etc. to reach that state.

Some specific details about Numba…

  1. The Numba runtime (NRT) is also compiled at the point a Numba @njit-decorated function is encountered; this has to happen before Numba can execute anything. (Its compilation is being delayed until JIT compilation is requested in Numba PR #8438, which should speed up pure import time, but the NRT compilation still has to occur before execution.)
  2. To get Numba to a point where it can compile something, it has to import NumPy (and probably SciPy if it's in the environment), which takes a while. Then it has to import llvmlite, which triggers loading LLVM (which is large), and then it has to initialise LLVM before it can compile anything.
  3. Functions that are cached on disk still have a cost. First, the argument types have to be checked to work out the type signature the function is being called with. Then the data on disk has to be loaded and checked to see whether there's a suitable cached version (i.e., whether the signature, CPU, and a few other things match), and finally the binary data that is the compiled function has to be wired in so that it can be executed. Once this is done, subsequent executions with the same types (type signature) are much quicker, as they mostly involve some type checking and dictionary lookups to get to the point of running the compiled function (see the sketch below).
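To illustrate point 3, here is a minimal sketch (the function and names are just for illustration) showing the one-off cost of the first call with a given type signature versus the dispatch-only cost of subsequent calls in the same process:

```python
import time

import numpy as np
from numba import njit


@njit(cache=True)
def total(a):
    # Trivial function, just to have something cached on disk
    s = 0.0
    for x in a:
        s += x
    return s


a = np.random.rand(1_000)

# First call with this type signature: check the argument types, load the
# cached binary from disk (or compile if no suitable cache entry exists),
# and wire it in so it can be executed.
start = time.time()
total(a)
print("first call :", time.time() - start)

# Second call with the same signature: mostly type checking and a
# dictionary lookup before jumping straight to the compiled code.
start = time.time()
total(a)
print("second call:", time.time() - start)
```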