Hello everyone,
Between Christmas and New Year I had some time to look in a bit more detail at two issues that have come up a few times in the numba backend of pytensor. (Just for background: pytensor is a fork of aesara that uses numba to compile computation graphs with autodiff, similar to jax or tensorflow, and is used as the backend in pymc.) I'd be curious to hear what you think, and I hope it makes at least some sense.
The issues are:
- Long compile times. We often have compile times longer than 10 s, and I've even seen a very large model where compilation took 30 min.
- Missed vectorization opportunities. From my experience in C or Rust I often expect llvm's loop vectorizer to transform a loop, but in many cases this doesn't happen in numba.
They seem rather independent, but I think they might actually have similar causes, which is why I’m combining them here.
Let’s start with compilation time: in pytensor we tend to have a relatively large number of functions that each do relatively little work and get combined into one big function, which is then exposed to the user. As a numba-only proxy, let’s look at this:
```python
# Make sure we have an empty cache (ipython syntax....)
%env NUMBA_CACHE_DIR=/home/adr/.cache/numba
!rm -rf ~/.cache/numba

import numba
import numpy as np


def make_func(func):
    # A small function that calls a previous function and reshapes the result
    @numba.njit(no_cpython_wrapper=True)
    def foo(x):
        return func(x).reshape((3, 3))
    return foo


@numba.njit(no_cpython_wrapper=True)
def first(x):
    return x


# We build a chain of 50 of those simple functions
current = first
for i in range(50):
    current = make_func(current)


# Wrap it in a final function with a cpython wrapper
@numba.njit
def outer_logp_wrapper(x):
    return current(x)


# Compile the function
%time outer_logp_wrapper(np.zeros(9))
# CPU times: user 18.7 s, sys: 53.6 ms, total: 18.8 s
# Wall time: 18.8 s
```
So compiling this function chain takes about 20 s. What is it doing in all that time? From what I currently understand it does the following:
For each of the 50 functions it creates an llvm module containing the function itself plus declarations for the functions it calls, together with a `CodeLibrary`. That code library is then finalized. During finalization it optimizes each function in the module separately using rewrites provided by llvm, and then links in the modules of the functions it uses (i.e. the previous function's finalized, and thus already optimized, module). It then optimizes the whole linked module (there are also some other passes that I don't think change anything for my purpose here), and finally produces machine code for it.
So this means that by the time it has produced the code for the final function, it has also optimized and produced independent machine code for each of the previous functions. The first function will have gone through the llvm optimization pipeline and codegen 50 times (I think)!
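To make that ordering concrete, here is a minimal llvmlite-level sketch of the flow as I understand it. `inner_ir` and `outer_ir` are hypothetical placeholders for the IR numba lowers for two chained functions; this is not numba's actual code path, just the shape of it:

```python
import llvmlite.binding as llvm

llvm.initialize()
llvm.initialize_native_target()
llvm.initialize_native_asmprinter()

# Roughly the module-level pipeline numba configures at high optimization levels
pmb = llvm.create_pass_manager_builder()
pmb.opt_level = 3
mpm = llvm.create_module_pass_manager()
pmb.populate(mpm)

# Placeholders: in numba these modules come out of the CodeLibrary machinery
inner_mod = llvm.parse_assembly(inner_ir)
outer_mod = llvm.parse_assembly(outer_ir)

# Current behaviour (as I understand it): the callee's module is fully
# optimized (and machine code is generated for it) on its own ...
mpm.run(inner_mod)

# ... and then it is linked into the caller and the whole thing is
# optimized again, once per level of the call chain.
outer_mod.link_in(inner_mod)
mpm.run(outer_mod)
```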
Would there be an alternative? I think it should be possible to split `finalize` into two parts: a first part (let's call it pre-finalization), where we just optimize the individual functions and make sure nothing further can be added to the module, and a second part (post-finalization) that recursively links in all required modules, runs module-level optimizations, and does llvm codegen.
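In terms of the sketch above, that split would roughly mean: pre-finalization only runs the per-function passes on each module, and only post-finalization links everything together and runs the module-level pipeline (and codegen) once. Again just a sketch of the ordering, not a patch:

```python
# Pre-finalization: only the cheap function-level passes run on each
# individual module (assuming freshly parsed inner_mod / outer_mod and the
# same pass manager builder pmb as in the sketch above).
fpm = llvm.create_function_pass_manager(inner_mod)
pmb.populate(fpm)
fpm.initialize()
for func in inner_mod.functions:
    fpm.run(func)
fpm.finalize()

# Post-finalization: link all required modules into the final one, then run
# the module-level optimizations and llvm codegen exactly once.
outer_mod.link_in(inner_mod)
mpm.run(outer_mod)
```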
In an act of vandalism I went through the code and broke everything except this example, but made it so that we only link at the very end and avoid some (I don't think I got all?) of the intermediate codegen. See the code here. This got the compile time of the example down to ~4 s! And for pytensor functions I saw similar improvements (and some segfaults; I definitely broke stuff).
So how does this affect the runtime of the compiled function? Surprisingly, I think it often improves it: from what I can tell, the approach of running module-level optimization on each function's module first, and then again after linking, runs into optimization-ordering issues and misses opportunities because of it.
Let’s look at this example (I’m sure there are simpler ones, but still):
```python
%env NUMBA_CACHE_DIR=/home/adr/.cache/numba
!rm -rf ~/.cache/numba

# To see vectorization output of llvm
# import llvmlite.binding as llvm
# llvm.set_option('', '--debug-only=loop-vectorize')

import numba
import numpy as np


@numba.njit(error_model="numpy")
def inner(a1, a2, b1):
    return (
        np.exp((a1 / a2) / a2),
        np.exp(((a1 / a2) * b1) / a2),
    )


@numba.njit(error_model="numpy")
def outer_logp_wrapper(N, a1, a2, b1):
    out1 = np.empty(N)
    out2 = np.empty(N)
    for i in range(N):
        val1, val2 = inner(a1[i], a2[i], b1[i])
        out1[i] = val1
        out2[i] = val2
    return out1, out2


N = 10_000
a1, a2, b1 = np.random.randn(3, N)
_ = outer_logp_wrapper(N, a1, a2, b1)
%timeit outer_logp_wrapper(N, a1, a2, b1)
```
The loop in `outer_logp_wrapper` doesn't get vectorized on numba main:

```
LV: Not vectorizing: Found unvectorizable type   %0 = insertelement <2 x double> undef, double %.10.i, i32 0
```
The reason seems to be that when the module for the `inner` function is optimized, this includes the loop-vectorize and also the slp-vectorize passes. The first is supposed to have higher priority, so it is executed first (I think…), but because there is no loop it does nothing. llvm then runs the slp-vectorizer, and this successfully uses some vector instructions in `inner`. The modules for `outer_logp_wrapper` and the already optimized `inner` module are then merged, and the optimizer runs again. It inlines the `inner` function, but because that already contains vector instructions, it can no longer vectorize the loop in `outer_logp_wrapper`.
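One rough way to check this (assuming `inner` has been compiled by the call above) is to look at the optimized IR numba keeps for `inner` and search for vector types:

```python
# After the first call, numba has compiled `inner` for float64 scalars.
# Its optimized IR already contains <2 x double> instructions from the
# slp-vectorizer.
sig = inner.signatures[0]
ir_text = inner.inspect_llvm(sig)
vector_lines = [line for line in ir_text.splitlines() if "<2 x double>" in line]
print("\n".join(vector_lines[:5]))
```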
If we use the crazy-vandalism branch again, which only runs the module-level optimization passes on the final linked module, we don't have that problem: the first module is only optimized using the function passes, and those don't include slp-vectorize. llvm then inlines the `inner` function, and the loop vectorizer succeeds. After that it also runs slp-vectorize, which doesn't have anything left to vectorize in that place.
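Correspondingly, a quick and somewhat crude way to check whether the loop in `outer_logp_wrapper` was actually vectorized is to look for packed SIMD registers in the generated assembly; what exactly shows up depends on the target CPU features, so treat this as a heuristic:

```python
# Heuristic: on x86-64 with AVX, a vectorized loop typically uses the
# packed ymm/zmm registers.
sig = outer_logp_wrapper.signatures[0]
asm = outer_logp_wrapper.inspect_asm(sig)
print(any(reg in asm for reg in ("ymm", "zmm")))
```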
This then leads to much better performance (at least in this particular case; I'm not sure how well this generalizes to other workloads).
Sorry for the long post and I hope it is more-or-less understandable.
And by the way: Happy new year!