Feedback on custom caching strategy in PyTensor

From this inner / outer example, as I see it, the problem is the following. Either inner's instructions are simply inlined into outer (which is what happens for simple functions like the one in your example), or inner's entire code, along with everything it needs in turn, is copied into outer's code (LLVM IR, and then machine code) and subsequently stored in outer's cache. Because outer's cache is only checked against outer's own source file, a change to inner does not invalidate it. So outer can end up with a stale cache even if none of the functions are string-generated.
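To make the failure mode concrete, here is a toy model in plain Python. No Numba is involved; `sources`, `stamp` and `get_outer` are hypothetical stand-ins that loosely mimic a cache entry being validated only against the caller's own source file. Because the "compiled" outer captures inner's behaviour at compile time, editing inner does not invalidate outer's entry:

```python
# Toy model (hypothetical, not Numba's implementation): a cache entry for a
# function is validated only against a stamp of that function's *own* source.
sources = {
    "inner.py": "factor = 2",
    "run.py": "def outer(x): return inner(x) + 1",
}


def stamp(filename):
    # Stand-in for a source-file stamp (e.g. mtime/size); here a hash of the text.
    return hash(sources[filename])


def read_factor():
    # Parse the current value of factor out of the toy inner.py.
    return int(sources["inner.py"].split("=")[1])


cache = {}  # name -> (stamp of its own source file, "compiled" callable)


def get_outer():
    entry = cache.get("outer")
    if entry is not None and entry[0] == stamp("run.py"):
        return entry[1]  # cache hit: only run.py's stamp was checked
    factor = read_factor()  # inner is "inlined": factor is baked in here

    def compiled(x):
        return x * factor + 1

    cache["outer"] = (stamp("run.py"), compiled)
    return compiled


print(get_outer()(1.5))  # 4.0 with factor = 2
sources["inner.py"] = "factor = 3"  # edit inner.py only
print(get_outer()(1.5))  # still 4.0: stale, the change to inner is invisible
```

Only touching the toy run.py (changing its stamp) would force outer to be rebuilt with the new factor.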

One way to solve this problem could be to replace JIT'd functions that are meant to be called by other JIT'd functions (such as inner, called by outer) with pure declarations, so that the caller's compiled code only refers to the callee instead of containing a copy of it.
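The declaration idea can be sketched with another plain-Python toy model. The names `registry`, `define_inner` and `compile_outer` are hypothetical, and this is not how numbox implements it; in real JIT'd code the "registry lookup" would correspond to a symbol resolved outside outer's machine code. The cached outer holds only a reference to inner, so an updated inner is picked up without recompiling outer:

```python
# Toy model (hypothetical, not numbox's implementation): the caller stores only
# a *reference* to the callee (a declaration) and resolves it at call time.
registry = {}  # symbol name -> current implementation of that symbol


def define_inner(factor):
    # (Re)compiling inner only replaces its registry entry.
    def inner(x):
        return x * factor

    registry["inner"] = inner


def compile_outer():
    # outer's "compiled code" embeds the symbol name, not inner's body.
    def outer(x):
        return registry["inner"](x) + 1

    return outer


define_inner(2)
outer_cache = compile_outer()  # "cached" once, never recompiled below
print(outer_cache(1.5))  # 4.0
define_inner(3)  # inner is updated and recompiled
print(outer_cache(1.5))  # 5.5: the same cached outer sees the new inner
```

The trade-off is an indirect call: since the caller only sees a declaration, the compiler presumably cannot inline inner's body into outer.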

Consider this:

inner.py

```python
from numba import float64, njit
from numbox.core.proxy.proxy import proxy

inner_sig = float64(float64)

factor = 2


# @proxy(inner_sig, jit_options={"cache": True})
@njit(cache=True)
def inner(x):
    return x * factor
```

run.py

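```python
from inspect import getfile, getmodule

# njit and inner look unused, but the exec'd source below needs them
# in this module's namespace.
from numba import njit  # noqa: F401

# Relative import: run this as a module (python -m <package>.run), not as a plain script.
from .inner import factor, inner  # noqa: F401


def _anchor():
    pass


def create_outer():
    outer_txt = """
@njit(cache=True)
def outer(x):
    return inner(x) + 1
"""
    # Compile against run.py's filename so that outer is tied to (and cached for) run.py.
    outer_code = compile(outer_txt, getfile(_anchor), mode="exec")
    ns = getmodule(_anchor).__dict__
    exec(outer_code, ns)
    return ns["outer"]


if __name__ == "__main__":
    outer = create_outer()
    n = 1.41
    assert abs(outer(n) - (factor * n + 1)) < 1e-15, outer(n)
```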
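The reason outer's cache is associated with run.py is the filename passed to `compile`: it becomes the code object's `co_filename`, which is also what `inspect.getfile` reports for the function. As far as I understand, Numba's cache locators derive both the cache location and the source stamp used for invalidation from that file (an assumption worth checking against the Numba caching notes). A stdlib-only check of the first part, with a hypothetical path standing in for `getfile(_anchor)`:

```python
import inspect

src = """
def outer(x):
    return x + 1
"""

# Hypothetical path; in run.py this role is played by getfile(_anchor).
fake_file = "/path/to/run.py"
code = compile(src, fake_file, mode="exec")
ns = {}
exec(code, ns)
outer = ns["outer"]

print(outer.__code__.co_filename)  # /path/to/run.py
print(inspect.getfile(outer))      # /path/to/run.py
```

The file does not need to exist for this to work; the filename is simply recorded on the code object.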

I left two options in inner.py (one of them commented out): you can decorate inner with either njit or proxy. Try njit first, and change factor between runs. outer's cache, which is associated with run.py, will become stale, and the assert will fail.

On the other hand, if you use proxy and change factor, outer keeps its cache (as long as its own module, run.py, hasn't been touched), while inner's cache is refreshed whenever inner itself is updated.

I wrote about what led to proxy here. The source is here.