Understanding Initial Execution Delay in Numba Cached Functions

I have a question about the cache loading mechanism in Numba.

Here’s a piece of test code I wrote:

import numba as nb

@nb.njit(cache=True)
def test_jit():
    a = 0
    for i in range(1000000):
        a += i
    return a

If I execute this code from scratch, the execution time is 12.8 ms:

%timeit -r 1 -n 1 test_jit()
12.8 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

On the second execution, the time drops to 2.2 μs, which is understandable since the function has been JIT-compiled into native code:

%timeit -r 1 -n 1 test_jit()
2.2 μs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

Next, I closed the Python process and reran it, but this time I called CPUDispatcher.compile to load the cache before calling test_jit. The first execution then took 2.74 ms:

%timeit -r 1 -n 1 test_jit()
2.74 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

On the second execution, the time dropped again to 1.8 μs:

%timeit -r 1 -n 1 test_jit()
1.8 μs ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

My questions are:

  1. After loading the cache, why does the first execution of test_jit still take around 2.74 ms instead of something closer to 2 μs? What exactly is happening during this 2.74 ms, and which parts of the code are being executed?

  2. What additional work do I need to do so that, after loading the cache, the first run already takes around 2 μs?

Thanks a lot.

Does this thread provide some insight: Numba Warm-up Speed With Cached Functions - #4 by stuartarchibald ?

Thank you for your reply.

My situation differs slightly from the one discussed in this thread. I loaded cached code as follows:

def njit_func(*args, **kws):
    disp = org_njit(*args, **kws)
    if type(disp) is types.FunctionType:
        def disp_fn(fn):
            d = disp(fn)
            if 'cache' in kws and d._can_compile:
                overloads = _load_index(d._cache._cache_file)
                for tp_k, v in overloads.items():
                    sig, _, _ = tp_k
                    d.compile(sig)
            return d

        return disp_fn
    return disp

However, when the test_jit function is executed for the first time, it’s still significantly slower than on subsequent runs.

I’m finding it hard to picture the exact scenario - can you post the complete code you’re running please?

import numba as nb
import types

org_njit = nb.njit


def _load_index(self):
    # Adapted from Numba's internal cache-file index loader;
    # ``self`` here is the dispatcher's cache file object
    # (d._cache._cache_file).
    import pickle
    try:
        with open(self._index_path, "rb") as f:
            version = pickle.load(f)
            data = f.read()
    except FileNotFoundError:
        # Index doesn't exist yet?
        return {}
    if version != self._version:
        # This is another version.  Avoid trying to unpickle the
        # rest of the stream, as that may fail.
        return {}
    stamp, overloads = pickle.loads(data)
    return overloads


def njit_func(*args, **kws):
    disp = org_njit(*args, **kws)
    if type(disp) is types.FunctionType:
        def disp_fn(fn):
            d = disp(fn)
            if 'cache' in kws and d._can_compile:
                overloads = _load_index(d._cache._cache_file)
                for tp_k, v in overloads.items():
                    sig, _, _ = tp_k
                    d.compile(sig)
            return d

        return disp_fn
    return disp


nb.njit = njit_func


@nb.njit(cache=True)
def test_jit(utc: int):
    a = 0
    for i in range(utc):
        a += i
    return a


if __name__ == "__main__":
    test_jit(100)
    test_jit(100)

Here is the full code.

If I modify your code to include timing, I’m not seeing such a large penalty for the first run. With

import numba as nb
import types

org_njit = nb.njit


def _load_index(self):
    import pickle
    try:
        with open(self._index_path, "rb") as f:
            version = pickle.load(f)
            data = f.read()
    except FileNotFoundError:
        # Index doesn't exist yet?
        return {}
    if version != self._version:
        # This is another version.  Avoid trying to unpickle the
        # rest of the stream, as that may fail.
        return {}
    stamp, overloads = pickle.loads(data)
    return overloads


def njit_func(*args, **kws):
    disp = org_njit(*args, **kws)
    if type(disp) is types.FunctionType:
        def disp_fn(fn):
            d = disp(fn)
            if 'cache' in kws and d._can_compile:
                overloads = _load_index(d._cache._cache_file)
                for tp_k, v in overloads.items():
                    sig, _, _ = tp_k
                    d.compile(sig)
            return d

        return disp_fn
    return disp


nb.njit = njit_func

utc = 100


@nb.njit(cache=True)
def test_jit():
    a = 0
    for i in range(utc):
        a += i
    return a


if __name__ == "__main__":
    import timeit
    first = timeit.timeit(test_jit, number=1)
    second = timeit.timeit(test_jit, number=1)
    print(first)
    print(second)

I get:

$ python preload.py 
1.3259996194392443e-06
3.5999983083456755e-07

The slowdown seems to be around 3.7x for the first call, which feels in line with my expectations for the first execution of a piece of code that is probably optimised to “return a constant” by the compilation pipeline.

Do you still see a much larger time for the first run with the above example executed in Python (and not IPython or a Jupyter notebook)?

I reran it, and the result is similar to yours. The second run is 3-4 times faster than the first, so I would like to make the first run faster.

I came across this response on the Numba forum (Numba Warm-up Speed With Cached Functions). It explains that, even with cached functions, there is an initial overhead due to type checking, signature matching, and loading binary data. Once these steps are completed, subsequent executions with the same type signature become much faster.

Functions that are cached on disk still have a cost. First types have to be checked to work out the type signature the function is being called with. Then the data on disk has to be loaded and checked to see if there’s a suitable cached version (check if signature, CPU and some other things match) and then the binary data which is the compiled function has to be wired in so that it can be executed. Once this is done, subsequent executions with the same types (type signature) will be much quicker as they are mostly just some type checking and dictionary look ups to get to the point of running the compiled function.

What steps can I take to ensure that native code is generated for optimized performance?

Thank you very much.

Native code is generated for optimized performance. Normally the execution time of the first run is irrelevant (up to a reasonable point), since after the first run there are many subsequent runs. Why is the execution time of the first run of particular concern?