Feedback on custom caching strategy in PyTensor

Hi devs!

I’m a developer of PyTensor, a lazy computational graph library that can be transpiled into Numba.

We’re doing a strong push on our Numba backend, with the goal of making it our default backend (the current default is a custom C backend).

Our problem: Numba caching

One of the largest issues still holding us back is compile times. PyTensor-generated functions are usually composed of many small functions, easily hundreds or thousands of them called in succession. Each node in our graph translates roughly to one numpy function (think np.reshape(x) or np.transpose(x)), with some exceptions for fused kernels.
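For a sense of scale, the generated code has roughly this shape. This is purely illustrative (made-up node functions, not actual PyTensor output):

import numpy as np
from numba import njit

# Each graph node becomes one small jitted function...
@njit
def node_0(x):
    return np.transpose(x)

@njit
def node_1(x):
    return np.exp(x)

# ...and an outer function calls them in succession, one per node.
@njit
def fgraph(x):
    return node_1(node_0(x))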

We found that Numba compile times are actually not much different from our C backend’s, but caching was just lacking. It’s not uncommon for our users to recompile the same function many times, or small variations of a function where many nodes are identical (e.g., a 2d transpose is always the same), so caching can really amortize compile time, even during the first run of a script.

What was failing

We had many cases where Numba caching just failed.

  1. Many of our functions, including the outer one that coordinates the calls for each node, are string-generated, and Numba simply fails to cache these.
  2. Even non-string-generated functions would often fail to cache whenever they referenced other njit functions created in a function closure. I documented this in Caching redefined functions. One workaround was to use register_jitable for those inner functions (see the sketch after this list).
  3. Not all functions are safe to cache. Some functions (say those wrapping Cython pointer functions) can’t be cached safely by Numba, and therefore higher-order functions that use them, such as an Op that represents a loop over an inner function, or our outer function, can’t be cached either. Numba has no mechanism to automatically infer the cacheability of an outer function from the cacheability of its inner functions.
  4. On the other hand, cache invalidation can’t be trusted to always spot changes. The Numba docs mention several limitations.
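To make problem 2 concrete, here is a minimal sketch with made-up names. With @njit on the closure-created inner function, caching of the outer function tends to fail; register_jitable sidesteps this because the inner body is compiled as part of whichever jitted caller uses it:

from numba import njit
from numba.extending import register_jitable

def make_outer(factor):
    # If this were @njit instead, `outer` below would often fail to
    # cache, since Numba cannot validate the closure-held dispatcher.
    @register_jitable
    def inner(x):
        return x * factor

    @njit(cache=True)
    def outer(x):
        return inner(x) + 1

    return outer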

Our working solution

We implemented a working solution in #1637. The general strategy is:

  1. When we produce a function to be jitted internally, we also create a unique hash key that (together with Numba’s type dispatching) should be enough to uniquely identify it. When a function can’t be safely cached we return None for the key.
  2. We wrote a custom Numba CacheLocator for our generated functions. Every time we define a function to be jitted we store it in a weakref dict, with the hash as the value. The cache locator later simply checks if the Python function is in that dictionary, and uses the hash key as the disambiguator for Numba (see the sketch after this list).
  3. This wasn’t enough, because Numba would still do its own cache validation (or fail to). For instance, problem 2 from above, with fresh njit functions in the closure contents, would still break caching. We bypassed it by hoisting each function to the global scope of a string-generated identity function. Numba seems to just trust (or not be able to probe) objects in the global scope.
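For reference, the locator part looks roughly like the sketch below. Note that this leans on Numba internals (_CacheLocator and CacheImpl._locator_classes are not public API and may change), and the class name and cache directory here are placeholders rather than the actual code from #1637:

import os
import tempfile
import weakref

from numba.core.caching import CacheImpl, _CacheLocator

# Generated function -> stable hash key; weakref'd so entries disappear
# once PyTensor drops the generated function.
_function_keys = weakref.WeakKeyDictionary()

class PyTensorCacheLocator(_CacheLocator):
    def __init__(self, py_func, key):
        self._py_func = py_func
        self._key = key

    @classmethod
    def from_function(cls, py_func, py_file):
        key = _function_keys.get(py_func)
        if key is None:
            # Not one of ours (or not safely cacheable): let Numba's
            # built-in locators handle it.
            return None
        return cls(py_func, key)

    def get_cache_path(self):
        # Placeholder location for the on-disk cache.
        return os.path.join(tempfile.gettempdir(), "pytensor_cache")

    def get_source_stamp(self):
        # The hash key already encodes everything relevant, so the
        # stamp is constant for a given key and never invalidates.
        return self._key

    def get_disambiguator(self):
        return self._key

# Try our locator before the built-in ones (again, internal API).
CacheImpl._locator_classes.insert(0, PyTensorCacheLocator)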

Looking for feedback

We can now perfectly cache all cacheable functions generated by PyTensor, and compile times are quite comparable with our C backend. We also have a plan for how to handle Cython pointer functions: cache_numba_func_ptr.ipynb · GitHub, which should allow us to cache everything (except for user-defined extensions, of course).

However, this all feels a bit hacky, and especially point 3 may rely on undefined behavior of Numba caching. We don’t want it to break if Numba goes in a different direction.

We also noted some runtime slowdown in our benchmarked function with caching on versus off, which we think must be due to the extra identity functions in front of each real function from point 3.

This overly long post is a kind request for feedback. Do you spot anything that could be simplified / done differently?

Best regards!

We cache string-generated Numba functions in Bodo. Search for CacheLocator in the Bodo repo in the bodo-ai GitHub org. Ping me if you have questions.

@DrTodd13 thanks for the pointer. It seems similar to what we’re doing, except you don’t seem to be concerned with nested functions? Unless inner functions have unique names, the hash is based only on the string contents of the outer function, right?

Having the dict in the class scope is nicer, although I think a weakref dict is still better for us.

Can you provide some pseudo-code for the case you are concerned about? I’m a bit confused about what you mean by nested functions in this context. The hash is based only on the string contents of the function.

from numba import njit

@njit
def inner(x):
    return x * 2

@njit
def outer(x):
    return inner(x) + 1

If outer is string-generated, you can’t know if the cache is still valid by checking just its visual representation: inner may have changed. In our case we make sure the hash of outer incorporates the hash of inner.

This may have no correspondence to your use cases.
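For illustration, the composed key could look something like this (a sketch with made-up names, not the actual PyTensor code):

import hashlib

def outer_key(outer_src, inner_keys):
    # Hash the generated source of the outer function together with the
    # already-computed keys of every inner jitted function it calls, so
    # a change anywhere down the call tree invalidates the outer cache.
    h = hashlib.sha256(outer_src.encode())
    for key in sorted(inner_keys):
        h.update(key.encode())
    return h.hexdigest()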

From this inner/outer example, the way I see it, the problem is that either the inner function’s instructions are simply inlined into the outer (which will always happen for simple functions, like in your example), or the entirety of inner’s code (and everything it needs in turn, and so on) is copied into the outer’s code (LLVM IR, and then machine code) and subsequently deposited into the outer’s cache. So the problem of outer potentially ending up with a stale cache can be observed even if none of the functions were string-generated.

A way to solve this problem could be to swap JIT’d functions (such as inner) that are intended to be called by other JIT’d functions (such as outer) with pure declarations.

Consider this:

inner.py

from numba import float64, njit
from numbox.core.proxy.proxy import proxy

inner_sig = float64(float64)

factor = 2

# @proxy(inner_sig, jit_options={"cache": True})
@njit(cache=True)
def inner(x):
    return x * factor

run.py

from inspect import getfile, getmodule
from numba import njit # noqa: F401
from .inner import factor, inner # noqa: F401

def _anchor():
    pass

def create_outer():
    outer_txt = """
@njit(cache=True)
def outer(x):
    return inner(x) + 1
"""
    outer_code = compile(outer_txt, getfile(_anchor), mode="exec")
    ns = getmodule(_anchor).__dict__
    exec(outer_code, ns)
    return ns["outer"]

if __name__ == "__main__":
    outer = create_outer()
    n = 1.41
    assert abs(outer(n) - (factor * n + 1)) < 1e-15, outer(n)

I left two options (one of them commented out) in inner.py: you can decorate with either njit or proxy. Try njit first, and change factor between runs. The outer’s cache, keyed on run.py, will become stale and the assert will fail.

On the other hand, if you use proxy and change the factor, then outer gets to keep its cache (as long as its own module, run.py, hasn’t been poked) while inner’s cache gets refreshed if inner itself is updated.

I wrote about what led to proxy here. The source is here.
