I have a few njit functions that are being used inside of the STUMPY Python package and I am noticing that it takes around 3-4 seconds to compile the function in the following example:
$ conda install -c conda-forge stumpy

import numpy as np
import stumpy
import time

if __name__ == "__main__":
    T = np.random.rand(100)
    m = 50

    start = time.time()
    mp = stumpy.stump(T, m)
    print(time.time() - start)  # Around 3.94 s compilation + execution time

    start = time.time()
    mp = stumpy.stump(T, m)
    print(time.time() - start)  # Around 0.001 s execution time only
The relevant njit code is documented here. I was wondering if there is any advice or general tips on how to reduce this compilation time, perhaps, in providing better hints to the compiler? Are there any Numba tools that can help me diagnose where the compilation hot spots are? Any help or guidance would be greatly appreciated!
And it looks like 14_nopython_type_inference and 22_nopython_backend are where we are suffering the most, which explains the 3-4 second compile time. Blindly, I tried declaring an explicit function signature to see if it had any effect on the nopython_type_inference but, unsurprisingly, that did not help. Unfortunately, the link above did not provide any concrete suggestions for reducing the time spent in these hotspots.
Why don’t you use caching (cache=True)? This reduces the first call from 10 s to 0.1 s (of course, only from the second run of your example script onward). The timing of about 0.1 s also does not seem to depend much on the number or size of the compiled functions: it is in this range for very simple examples, but also for much more complex functions (your package).
Using Numba without any caching is roughly equivalent to having a rather large C program and recompiling it before each execution.
Thanks, @max9111. So, I am aware of cache=True and understand that all subsequent calls to that njit function (beyond the first time) would always be fast even across multiple Python sessions but I think that this should be considered independently of finding ways to reduce the compilation time (i.e., they are not mutually exclusive and I want to/am open to doing both). No matter what, the first call of an njit function will incur some compilation time and I accept this. Since this is a Python package that users download/install, I’m wondering if there’s a way to sneak this compilation time into, say, __init__.py (by calling the stumpy.stump() in there) or maybe, alternatively, by somehow hiding this compilation time inside of import stumpy? Any thoughts or ideas would be welcomed!
A far from perfect, but working solution including caching might be to determine the caching path in the __init__.py.
If you have that you can determine if there are some *.nbc or *.nbi files.
If you don’t find any you could inform the user that on the first time of using your package, a compilation process has to be performed.
After this you can call your main functions with different small examples (different input datatypes, or by directly providing signatures) and inform the user which signatures are compiled.
Of course improving the compilation time itself would also be good…
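The steps above could be sketched like this. It assumes Numba's default on-disk cache location (a __pycache__ directory next to the cached module's source); the warm-up call in the comment is a placeholder, not STUMPY's actual initialization code:

```python
import glob
import os

def cache_is_populated(package_dir):
    """Check whether any Numba on-disk cache files (*.nbi index files
    and *.nbc data files) exist under the package's __pycache__."""
    pycache = os.path.join(package_dir, "__pycache__")
    return bool(glob.glob(os.path.join(pycache, "*.nbi")))

# Hypothetical usage in __init__.py:
# if not cache_is_populated(os.path.dirname(__file__)):
#     print("First use of this package: compiling, this may take a while...")
#     _ = stump(np.random.rand(64), 8)  # warm up one signature
```

If users set NUMBA_CACHE_DIR, the cache lives elsewhere, so a robust check would need to consult that environment variable too.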
Numba 0.53+ has an Event API for capturing compilation time information, both in Numba’s compilation pipeline and in LLVM. There’s a demonstration notebook here:
(you can see this live in binder from here: Binder, the notebook is called Numba_053_profiling_the_compiler.ipynb)
Do you know a “clean” method to determine whether cached compilation results are available (and for which signatures) for an njit-decorated function?
@stuartarchibald or someone else probably knows a cleaner way, but for a particular nbi file I used something like this while investigating for this issue. Most of it is cribbed out of caching.py.
import io
import pickle

import numba

for nbi in args.nbi_info:  # args and logger come from the surrounding script
    if not nbi.endswith('.nbi'):
        logger.warning(f"skipping filename that doesn't end with '.nbi' --> '{nbi}'")
        continue
    # cu.super_slurp reads the raw bytes of the index file
    with io.BytesIO(cu.super_slurp(nbi, audit_logr=logger)) as f:
        version = pickle.load(f)
        data = f.read()
    if version != numba.__version__:
        logger.warning(
            f"skipping, program numba version '{numba.__version__}' doesn't match index"
            f" file numba version '{version}'"
        )
        continue
    stamp, overloads = pickle.loads(data)
    print(f'\nindex={nbi}')
    print(f'numba_version={version}')
    print(f'stamp={stamp}')
    for overload, fname in overloads.items():
        print(f'overload_fname={fname}')
        sig, flags, stuff = overload
        print(f'overload_key_signature={sig}')
        print(f'overload_key_magic_tuple={flags}')
        print(f'overload_key_code_hash={stuff}')
@max9111 I’m not sure there is, but it sounds like this could be useful? Do you mean something like some_jitted_function.has_cached(some_signature)? There is a .stats property on the jitted functions which has the cache path and hits/misses, but this is only updated on invocation for the given types.