Reduce Compilation Time

Hi Numba Community!

I have a few njit functions used inside the STUMPY Python package, and I am noticing that it takes around 3-4 seconds to compile the function in the following example:

$ conda install -c conda-forge stumpy

import numpy as np
import stumpy
import time


if __name__ == "__main__":
    T = np.random.rand(100)
    m = 50

    start = time.time()
    mp = stumpy.stump(T, m)
    print(time.time() - start)  # Around 3.94 s compilation + execution time

    start = time.time()
    mp = stumpy.stump(T, m)
    print(time.time() - start)  # Around 0.001s execution time only

The relevant njit code is documented here. I was wondering whether there are any general tips or advice on how to reduce this compilation time, perhaps by providing better hints to the compiler? Are there any Numba tools that can help me diagnose where the compilation hot spots are? Any help or guidance would be greatly appreciated!

I came across this post by @stuartarchibald that describes how to determine where compile time is spent in the njit function:

Timings:

0_translate_bytecode                    :0.000001            0.027676            0.000001
1_fixup_args                            :0.000001            0.000001            0.000000
2_ir_processing                         :0.000001            0.001911            0.000001
3_with_lifting                          :0.000001            0.002063            0.000001
4_inline_closure_likes                  :0.000001            0.005272            0.000001
5_rewrite_semantic_constants            :0.000001            0.000131            0.000000
6_dead_branch_prune                     :0.000000            0.000118            0.000000
7_generic_rewrites                      :0.000000            0.009968            0.000001
8_make_function_op_code_to_jit_function :0.000001            0.000052            0.000000
9_inline_inlinables                     :0.000001            0.000161            0.000001
10_dead_branch_prune                    :0.000000            0.000108            0.000000
11_find_literally                       :0.000000            0.000098            0.000000
12_literal_unroll                       :0.000000            0.000063            0.000000
13_reconstruct_ssa                      :0.000000            0.005832            0.000001
14_nopython_type_inference              :0.000001            1.144870            0.000001
15_annotate_types                       :0.000001            0.003121            0.000001
16_strip_phis                           :0.000001            0.002152            0.000001
17_inline_overloads                     :0.000001            0.001972            0.000001
18_pre_parfor_pass                      :0.000001            0.000687            0.000001
19_nopython_rewrites                    :0.000000            0.007209            0.000001
20_parfor_pass                          :0.000001            0.501806            0.000002
21_ir_legalization                      :0.000001            0.005024            0.000001
22_nopython_backend                     :0.000000            2.010890            0.000001
23_dump_parfor_diagnostics              :0.000001            0.000005            0.000000

And it looks like 14_nopython_type_inference and 22_nopython_backend are where we are suffering the most, which explains the 3-4 second compile time. Blindly, I tried declaring an explicit function signature to see if it had any effect on nopython_type_inference but, unsurprisingly, that did not help. Unfortunately, the link above did not provide any concrete suggestions for reducing the time spent in these hotspots.

Why don’t you use caching (cache=True)? This reduces the first call from 10 s to about 0.1 s (measured, of course, on the second run of your example script). The ~0.1 s also does not seem to depend much on the number or size of the compiled functions: it is in this range for very simple examples as well as for much more complex functions like those in your package.

Using Numba without any caching is much like recompiling a rather large C program before every execution.

Thanks, @max9111. I am aware of cache=True and understand that all subsequent calls to an njit function (beyond the first) will be fast, even across multiple Python sessions. However, I think caching should be considered independently of finding ways to reduce the compilation time itself (i.e., they are not mutually exclusive, and I am open to doing both). No matter what, the first call of an njit function will incur some compilation time, and I accept this. Since this is a Python package that users download/install, I’m wondering if there’s a way to sneak this compilation time into, say, __init__.py (by calling stumpy.stump() there) or, alternatively, to hide it inside import stumpy. Any thoughts or ideas would be welcomed!

A far-from-perfect but working solution that includes caching might be to determine the cache path in __init__.py.
Once you have that, you can check whether any *.nbc or *.nbi files exist.

If you don’t find any, you could inform the user that a compilation step has to be performed the first time the package is used.
After that, you can call your main functions with different small examples (different input datatypes, or by directly providing signatures) and inform the user which signatures have been compiled.

Of course improving the compilation time itself would also be good…


Numba 0.53+ has an Event API for capturing compilation time information, both in Numba’s compilation pipeline and in LLVM. There’s a demonstration notebook here:

(you can run this live on Binder; the notebook is called Numba_053_profiling_the_compiler.ipynb)

The docs on the API are here:
https://numba.readthedocs.io/en/stable/developer/event_api.html

Hope this helps?


@stuartarchibald

Do you know a “clean” method to determine whether cached compilation results are available for an njit-decorated function (and for which signatures)?

@stuartarchibald or someone else probably knows a cleaner way, but for a particular .nbi file I used something like the following while investigating this issue. Most of it is cribbed from caching.py.

    import pickle

    import numba


    def dump_nbi(nbi_paths):
        """Print the overloads recorded in Numba cache index (*.nbi) files."""
        for nbi in nbi_paths:
            if not nbi.endswith('.nbi'):
                print(f"skipping filename that doesn't end with '.nbi' --> '{nbi}'")
                continue

            # The index file starts with a pickled numba version string,
            # followed by a pickled (stamp, overloads) payload.
            with open(nbi, 'rb') as f:
                version = pickle.load(f)
                data = f.read()
            if version != numba.__version__:
                print(
                    f"skipping, program numba version '{numba.__version__}' doesn't"
                    f" match index file numba version '{version}'"
                )
                continue
            stamp, overloads = pickle.loads(data)

            print(f'\nindex={nbi}')
            print(f'numba_version={version}')
            print(f'stamp={stamp}')
            for overload, fname in overloads.items():
                sig, magic_tuple, code_hash = overload
                print(f'overload_fname={fname}')
                print(f'overload_key_signature={sig}')
                print(f'overload_key_magic_tuple={magic_tuple}')
                print(f'overload_key_code_hash={code_hash}')

@max9111 I’m not sure there is, though it sounds like it could be useful. Do you mean something like some_jitted_function.has_cached(some_signature)? There is a .stats property on the jitted functions which has the cache path and hits/misses, but this is only updated on invocation for given types.