Hi devs!
I’m a developer of PyTensor, a lazy computational-graph library whose graphs can be transpiled to Numba.
We’re doing a strong push on our Numba backend, with the goal of making it our default (the current default is a custom C backend).
Our problem: Numba caching
One of the largest issues still holding us back is compile times. PyTensor-generated functions are usually composed of many small functions, easily hundreds or thousands of them, called in succession. Each node in our graph translates roughly to one numpy function (think np.reshape(x) or np.transpose(x)), with some exceptions for fused kernels.
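For a sense of the shape of the problem, here is a minimal illustrative sketch (hypothetical code, not PyTensor’s actual codegen; real graphs have hundreds of such node functions):

```python
import numpy as np
from numba import njit

# One tiny jitted function per graph node...
@njit
def node_0(x):
    return np.transpose(x)

@njit
def node_1(x):
    return np.exp(x)

# ...called in succession by a (string-generated) outer function.
@njit(cache=True)
def fgraph(x):
    tmp = node_0(x)
    return node_1(tmp)
```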
We found that Numba compile times are actually not much different from our C backend’s, but caching was just lacking. It’s not uncommon for our users to recompile the same function many times, or small variations of a function where many nodes are identical (e.g., a 2d transpose is always the same), so caching can really amortize compile time, even during the first run of a script.
What was failing
We had many cases where numba caching just failed.
- Many of our functions, including the outer one that coordinates the calls for each node, are string-generated, and numba simply fails to cache these.
- Even non-string-generated functions would often fail to cache whenever they referenced other njit functions created in a function closure. I documented this in Caching redefined functions. One work-around was to use register_jitable for those inner functions (see the sketch after this list).
- Not all functions are safe to cache. We sometimes have nested functions: say, an Op that represents a loop over an inner function, or our outer function. Some functions can’t be cached safely by numba (e.g., those holding Cython function pointers), and therefore higher-order functions that use them can’t be cached either. Numba has no mechanism to automatically infer the cacheability of an outer function from the cacheability of its inner functions.
- OTOH, cache invalidation can’t be trusted to always spot changes. The numba docs mention several limitations.
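A minimal reproduction of the second failure mode, with the register_jitable work-around (a sketch of our understanding; the factory pattern mirrors how we generate per-node functions, but the names are hypothetical):

```python
from numba import njit
from numba.extending import register_jitable

def make_op_broken():
    @njit  # fresh closure object every time the factory runs
    def inner(x):
        return x + 1

    @njit(cache=True)  # caching often fails or misses here, because
    def outer(x):      # `inner` lives in a closure cell of `outer`
        return inner(x)

    return outer

def make_op_cacheable():
    @register_jitable  # inlined into jitted callers at compile time,
    def inner(x):      # so it no longer blocks caching of `outer`
        return x + 1

    @njit(cache=True)
    def outer(x):
        return inner(x)

    return outer
```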
Our working solution
We implemented a working solution in #1637. The general strategy is:
- When we produce a to-be-jitted function internally, we also create a unique hash key that (together with numba’s type dispatching) should be enough to uniquely identify it. When a function can’t be safely cached, we return None for the key (see the first sketch after this list).
- We wrote a custom numba CacheLocator for our generated functions. Every time we define a to-be-jitted function, we store it in a weakref dict with its hash key as the value. The cache locator later simply checks whether the python function is in that dictionary and uses the hash key as the disambiguator for Numba.
- This wasn’t enough, because Numba would still do its own cache validation (or fail to), e.g. problem 2 from above, with fresh njit functions in the closure contents. We bypassed it by hoisting each function into the global scope of a string-generated identity function (see the second sketch below). Numba seems to simply trust, or be unable to probe, objects in the global scope.
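Sketched out, points 1 and 2 look roughly like this. This relies on numba’s internal _CacheLocator API from numba.core.caching, which may change between releases; PYTENSOR_CACHE_KEYS and the cache path are illustrative stand-ins, not our exact implementation:

```python
import weakref
from numba.core.caching import CacheImpl, _CacheLocator

# Generated py function -> precomputed stable hash key.
# Uncacheable functions simply never get registered here.
PYTENSOR_CACHE_KEYS = weakref.WeakKeyDictionary()

class PyTensorCacheLocator(_CacheLocator):
    def __init__(self, py_func, hash_key):
        self._py_func = py_func
        self._hash_key = hash_key

    @classmethod
    def from_function(cls, py_func, py_file):
        # Only claim functions we registered ourselves; return None
        # to let numba's default locators handle everything else.
        key = PYTENSOR_CACHE_KEYS.get(py_func)
        if key is not None:
            return cls(py_func, key)
        return None

    def get_cache_path(self):
        return "/tmp/pytensor_cache"  # illustrative location

    def get_source_stamp(self):
        # The hash key fully identifies the generated source,
        # so no file/mtime-based invalidation is needed.
        return self._hash_key

    def get_disambiguator(self):
        return self._hash_key

# Make numba's caching machinery consult our locator first.
CacheImpl._locator_classes.insert(0, PyTensorCacheLocator)
```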
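And one plausible reading of the hoisting trick from point 3 (again a hypothetical sketch, not our exact wrapper): each closure-created function is stashed in the globals of a string-generated trampoline, so numba’s cache validation never has to look inside a closure cell:

```python
def hoist_behind_identity(inner_fn, name):
    # String-generate a trivial forwarding wrapper whose *globals*
    # hold the real function. Numba trusts (or cannot probe) the
    # global scope, so cache validation no longer trips over the
    # closure-created `inner_fn`.
    src = f"def {name}(*args):\n    return _inner(*args)\n"
    namespace = {"_inner": inner_fn}
    exec(src, namespace)
    return namespace[name]
```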
Looking for feedback
We can now perfectly cache all cacheable functions generated by PyTensor, and compile times are quite comparable with our C backend’s. We also have a plan for handling Cython function pointers: cache_numba_func_ptr.ipynb · GitHub, which should allow us to cache everything (except user-defined extensions, of course).
However, this all feels a bit hacky, and point 3 especially may rely on undefined behavior of numba’s caching. We don’t want it to break if numba goes in a different direction.
We also noticed some slowdown in our benchmarked functions with caching on versus off, which we think must be due to the extra identity functions in front of each real function from point 3.
This overly long post is a kind request for feedback. Do you spot anything that could be simplified or done differently?
Best regards!