Best Practices for Adopting Numpy-Based Code for Compatibility/Performance with Numba

Hello,

this is my tech stack:
Py: 3.10.2
Numpy: 2.0.2 or 1.26.4
Numba: 0.60.0 (latest)

I have a standard Py class with a few class methods that heavily use numpy. The class does the following:

(1) it receives a set of data points, usually between 1,000-7,000, and a multi-dimensional lookup table (numpy array)
(2) it loops over all points and then runs a minimizer algorithm for each point
(3) an objective function looks up the minimizer’s current guessed/estimated input in the lookup table (via trilinear interpolation) and compares it against the desired target for the current data point until a tolerance delta threshold is met - the max number of lookups per data point is 5,000.

The above can lead to millions of loop cycles and lookups.

I’ve now converted the class over to be numba compatible, starting off with making the class a @jitclass.
After a long adventure (I’ll share details in another post), the code is now finally functional, and it performas faster than standard Py code, but what I found along the way is that certain numpy functionality performs way worse in numba that one would expect - it decreases performance, by a lot. Naturally, I’m looking to optimize my code for best performance.

Here are my questions:

(A) Numpy
What are the best practices when adopting numpy based code for numba, specifically optimizing the code for peak performance?
Which numpy functions should be avoided in numba (hence: rewritten in plain Py) because they don’t perform well?

(B) Loops
What are the best practices for optimization of loops in numba?

(C) Caching
Can I cache the entire @jitclass? If so, when a new @jitclass instance is created with different class parameters (a different set of data points and a different lookup table) but the data types are obviously always the same, would it create a new cache version or would it re-use the first cache?
Or do I need to cache individual member functions?

(D) Parallelization
What is the best approach to integrate parallelization for this class? Can I run the entire class in parallel, or should I focus on the functions that include loops, or the functions that the loops are calling in each loop cycle?

Thanks!

for the numba devs:
here is some feedback after converting over 1,000+ lines of standard Py code that uses custom types and a lot of numpy:
The #1 error will always be related to type declarations that numba either expects and/or does not recognize/understand - but unfortunately the error messages returned by numba are very often inconclusive or impossible to debug.

example 1:
numba throws an error stating that a specific numpy or Py package method is unsupported, which is often misleading. As it turns out, this is usually a type error, i.e. a type that numba will not recognize or accept, for whatever reason.

example 2:

TypeError: missing a required argument: ‘<arg_name>’

when this error occurs, the stack trace does NOT tell you at what line in the Py code the error occurred; it only points to the original numba-jitted function at the top of the stack trace - but the actual error may occur in any of the many sub-functions that are called later on, which makes this nearly impossible to debug in larger code bases

example 3:
returning data from a try/except block in a function can throw this error:

An unsupported bytecode sequence has been encountered: op_LIST_EXTEND at the start of a block. This could be due to the use of a branch in a tuple unpacking statement.

Highly inconclusive and/or misleading. And numba does not tell you at what line in the code base this error occurs, so it is very hard to find. This is usually also a data-type-related error.

Maybe some of this can be improved for an easier experience adopting code to numba.

Thanks!

It’s difficult to make specific recommendations without example code. However, based on what you’ve shared, here are a few points of advice:

  1. You might consider refactoring your code to eliminate the use of jitclass. Unfortunately functions that take jitclasses cannot be cached (i.e. @njit(cache=True) doesn’t save you any compile time on subsequent runs). You could use structrefs instead, but trust me, I’ve helped a lot of people go that route and it’s just a pain in the butt to get working. Numpy arrays are very fast to box/unbox, so jitted functions that take numpy arrays as arguments will be very fast, and are cacheable. If you can, it might be easiest to just leave the object implementation on the Python side (assuming all the relevant members are numpy arrays), and pass those numpy arrays into jitted subroutines as needed.

  2. Numba actually reimplements a lot of numpy by replacing it with low-level LLVM code that it can optimize aggressively. Not all of those reimplementations are as fast (algorithmically) as numpy. The good news is that if you come up with a way of implementing parts of your algorithm as loops instead of calling numpy’s tensor operations, numba can very often optimize those loops into something much faster than you could achieve with plain old numpy.
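
For instance, here is a minimal sketch of that pattern (the function and variable names are made up for illustration): a plain Python object would hold the numpy arrays, and a cached @njit function does the heavy lifting with explicit loops instead of a vectorized numpy expression:

import numpy as np
from numba import njit

# Hypothetical example: squared distance of every point to a target.
# The vectorized equivalent would be np.sum((points - target) ** 2, axis=1);
# writing the loop out lets LLVM fuse the operations and avoid temporaries.
@njit(cache=True)
def squared_distances(points, target):
    n = points.shape[0]
    out = np.empty(n, dtype=points.dtype)
    for i in range(n):
        acc = 0.0
        for j in range(points.shape[1]):
            d = points[i, j] - target[j]
            acc += d * d
        out[i] = acc
    return out

points = np.random.rand(5000, 3)
target = np.array([0.5, 0.5, 0.5])
dists = squared_distances(points, target)  # numpy arrays box/unbox cheaply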

If you have further questions it would help if you shared your code, and specific concerns about numpy functions that you think are running too slow.

@DannyWeitekamp
thanks!

it would help if you shared your code

the jitclass is approx. 1,000 lines of code, not sure if it makes sense sharing all of it here
but it basically goes like this:

# create new jitclass instance, pass in data points array and lookup table for this task
jc = My_JITClass(data_points, lookup_table)

# jitclass then runs this logic
for data_pt in data_points:    # <---- process all data points passed to the class
    # a - run some basic point evaluation
    # b - now call the minimizer fn for this pt
    # the minimizer fn will run a loop of up to 5,000 cycles to try to find a solution for the current data point
    for i in range(5000):
        # a - create a new guess
        # b - look up the current guess in the lookup table using trilinear interpolation
        # --> my_lookup_fn uses the lookup_table array that was stored in a class var at class instantiation
        my_lookup_fn()
        # c - if result is below threshold, break loop

Unfortunately functions that take jitclasses cannot be cached

does that mean other functions calling a jitclass cannot be cached or member functions of a jitclass cannot be cached?

You might consider refactoring your code to eliminate the use of jitclass

okay, thoughts about this:
(1) I used jitclass to store the lookup table in a class var, which is then used in the innermost part of the process, my_lookup_fn()
if I was to refactor the jitclass into a bunch of single functions, then I would have to pass down the lookup_table array (and other vars) all the way through to my_lookup_fn() - not sure if that would cost more time over potentially millions of loop cycles than the cache would save… (?)
(2) more importantly, since data_points and lookup_table always change on each usage of the cache, would the cache work with these new input vars or be re-compiled every time?
(3) is running jitted functions faster than executing the same logic via jitclass?

If you have further questions it would help if you shared your code, and specific concerns about numpy functions that you think are running too slow

the few numpy functions I identified as bottlenecks I have already removed; the ones I have remaining are:
np.asarray
np.array
np.ndenumerate
np.array_equal
are any of those of concern performance wise?

Do you have any suggestion regarding my question (D) in the OP, is parallelization doable with jitclass, and, given the multi-loop structure of the code above, which part of the code would benefit from running in parallel?

Thanks!

does that mean other functions calling a jitclass cannot be cached or member functions of a jitclass cannot be cached?

Both. self is a jitclass in a member function so the same issue applies. Implementation-wise the issue with jitclasses is that they don’t serialize the same way (using pickle) between executions, so if you try to cache them they fill the cache, but never retrieve the compiled code from it correctly between runs. There is essentially no way to use jitclass at all and cache the jit compilation (it’s a very annoying bug).

(1) I used jitclass to store the lookup table in a class var, which is then used in the innermost part of the process, my_lookup_fn()
if I was to refactor the jitclass into a bunch of single functions, then I would have to pass down the lookup_table array (and other vars) all the way through to my_lookup_fn() - not sure if that would cost more time over potentially millions of loop cycles than the cache would save… (?)

It is hard to say for certain. As always with questions of performance you need to make sure you’re profiling your code (i.e. measure the execution time). This probably won’t change the execution time much since internally it is just passing one more pointer to the function, and if the optimizer decides to inline the inner bits of your loop then it might end up lifting that argument assignment out of the inner parts of the loop anyway.
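
As a rough sketch of what that refactor could look like (all names here are placeholders, and the interpolation body is omitted), the table simply becomes the first argument at every level:

import numpy as np
from numba import njit

@njit(cache=True)
def my_lookup_fn(lookup_table, x, y, z):
    # ... trilinear interpolation into lookup_table would go here ...
    return lookup_table[int(x), int(y), int(z)]

@njit(cache=True)
def minimize_point(lookup_table, point):
    value = 0.0
    for i in range(5000):
        # ... generate a new guess, then look it up:
        value = my_lookup_fn(lookup_table, point[0], point[1], point[2])
        # ... break once the result is below the tolerance threshold ...
    return value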

(2) more importantly, since data_points and lookup_table always change on each usage of the cache, would the cache work with these new input vars or be re-compiled every time?

When you set @jit(cache=True) you are telling numba to write the results of compiling the decorated function to a file, and then reuse that compiled program the next time you run the decorated function. data_points and lookup_table are inputs to that compiled program, so they are not written into the cache. They can change between each use.
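
In other words (a small made-up example, not your actual code), the cache keys on the argument types, not the argument values:

import numpy as np
from numba import njit

@njit(cache=True)
def total(data_points):
    s = 0.0
    for i in range(data_points.shape[0]):
        s += data_points[i]
    return s

# The compiled machine code for the float64[:] signature is typically written
# into a __pycache__ directory next to the source file on the first run and
# reloaded on later runs. The array contents are ordinary inputs and can be
# different on every call without triggering a recompile.
total(np.random.rand(1000))
total(np.random.rand(7000))  # same types -> the same cached compilation is reused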

(3) is running jitted functions faster than executing the same logic via jitclass?

It should be about the same. Maybe a tiny bit faster in some cases. They will certainly run considerably faster the first time they are called in your Python script if you write them with @jit(cache=True), since in that case they don’t need to be recompiled on the first call.

the few numpy functions I identified as bottlenecks I have already removed; the ones I have remaining are:
np.asarray
np.array
np.ndenumerate
np.array_equal
are any of those of concern performance wise?

np.asarray and np.array can be issues if you are using them to convert something that is not a numpy array into a numpy array. For instance, you can usually gain some extra speed by removing instances where you make intermediate lists that you turn into numpy arrays. If you know how big the array will be (or the maximum of how big it will be) you can allocate a new array with np.empty() and then fill it in as needed (don’t be afraid to use loops for that, numba will compile it so it is fast). If you don’t use the full array then you can just slice it.
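
A minimal sketch of that pattern (the function and names are made up): preallocate at the maximum possible size, fill in a loop, and slice off the unused tail:

import numpy as np
from numba import njit

@njit(cache=True)
def positive_values(data):
    # allocate for the worst case instead of growing a list and calling np.array() on it
    out = np.empty(data.shape[0], dtype=data.dtype)
    n = 0
    for i in range(data.shape[0]):
        if data[i] > 0.0:
            out[n] = data[i]
            n += 1
    return out[:n]  # return only the filled portion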

I’ve never used np.ndenumerate but from what I know about how numba usually implements generators I would expect that your code could be optimized much more aggressively if you replaced calls to np.ndenumerate with nested loops over the .shape of the array.
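
For example, something along these lines (a sketch for a 3D table; the names are made up):

import numpy as np
from numba import njit

@njit(cache=True)
def count_above(table, threshold):
    # nested loops over .shape instead of `for idx, val in np.ndenumerate(table)`
    count = 0
    for i in range(table.shape[0]):
        for j in range(table.shape[1]):
            for k in range(table.shape[2]):
                if table[i, j, k] > threshold:
                    count += 1
    return count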

I’m not sure about np.array_equal, but you could certainly write your own implementation of it as a @jit function and try to profile the difference.
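
A hand-rolled version for 1D arrays could look like the following sketch; returning early on the first mismatch can already beat a full comparison:

from numba import njit

@njit(cache=True)
def arrays_equal_1d(a, b):
    if a.shape[0] != b.shape[0]:
        return False
    for i in range(a.shape[0]):
        if a[i] != b[i]:
            return False
    return True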

Do you have any suggestion regarding my question (D) in the OP, is parallelization doable with jitclass, and, given the multi-loop structure of the code above, which part of the code would benefit from running in parallel?

You could try using prange. I can’t promise that it will make your code faster. Not every program is a good candidate for multithreading, but it is very easy to replace range() with prange() and measure whether or not there is an improvement. Usually applying it to the outermost loop is a good idea (assuming the results of each iteration of the loop do not depend on something computed in the previous one).
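
Applied to the structure you described, that would look roughly like this (a sketch; the per-point minimizer body is omitted and data_points is assumed to be a 2D array with one row per point). Each iteration writes only to its own slot of the results array, so the iterations stay independent:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def solve_all(data_points, lookup_table):
    results = np.empty(data_points.shape[0])
    for p in prange(data_points.shape[0]):   # outermost loop over the data points
        # ... per-point minimizer with its up-to-5,000-cycle inner loop goes here ...
        results[p] = data_points[p, 0]       # placeholder for the real result
    return results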

np.asarray and np.array can be issues if you are using them to convert something that is not a numpy array into a numpy array

did you mean something that is not a list?

You could try using prange

So can I use @njit(parallel=True) on member functions of jitclass?
Regarding prange, can I use it in nested loops, such as:

for i in prange(10):
    for j in prange(10):
        for k in prange(10):
           # do something

Thanks!

did you mean something that is not a list?

No that includes lists. Typically lists are implemented with numba.typed.List which is oddly a lot slower than numpy arrays (on all fronts: making them, appending to them, inserting, reading etc.). Again, this is an annoying bug in numba. If you’re going for the fastest possible execution time, avoid making intermediate lists at all. Stick to numpy arrays wherever you can.

Regarding prange, can I use it in nested loops, such as: …

No, you probably only want to use prange in the outermost loop:

for i in prange(10): # probably will only help if much larger than 10
    for j in range(10):
        for k in range(10):
           # do something

prange() spools up several threads that work independently on slices of its range. Starting a new thread has a pretty large overhead cost, so prange is only likely to speed things up if it is splitting up a job that would normally take at least hundreds of milliseconds. However, thread creation times are very system-dependent, so again, be sure you are timing your code’s execution time to test whether prange is helping or hurting.

No, you probably only want to use prange in the outermost loop

as it turns out, @jit(parallel=True) is not supported on jitclass member functions :expressionless:

Jitted functions
okay, I’ve restructured the code from a jitclass into individual jitted functions.
Execution time improves just by doing that, which is good.

Parallel mode
Using parallel=True with prange on any of the three loops in my main function that is initially called increases execution time, and I’m not sure why - so there’s no benefit here

Cache mode
this does work in testing and it greatly speeds up execution time on subsequent runs
What is the location of the numba disk cache file?
Will this work in Python single .exe files?

Thanks!!

Glad to hear things are a bit faster now.

What is the location of the numba disk cache file?

See this for where numba looks for its cache.

You can customize it by setting the NUMBA_CACHE_DIR environment variable.
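
If you go the environment-variable route, it needs to be in place before the cached functions are compiled; one option (the path here is just an example) is to set it at the very top of your entry script:

import os

# set before numba is imported / before any cached function is compiled
os.environ["NUMBA_CACHE_DIR"] = "/tmp/my_numba_cache"  # example path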

In the past, I’ve done something like the following to control where numba puts its cache:

from numba.misc.appdirs import AppDirs
from numba import config

appdirs = AppDirs(appname="myapp", appauthor=False)
config.CACHE_DIR = appdirs.user_cache_dir
print(config.CACHE_DIR)

On Linux this goes to /home/<user>/.cache/myapp, but in principle this approach should work on any OS, which may be simpler than using an environment variable.

Will this work in Python single .exe files?

No clue, but jitted code tends to compile down to binaries that are highly optimized for your particular machine. They are not generic binaries that will necessarily execute properly on other machines, even ones with the same operating system. So you probably don’t want to copy the cache and bundle it with a single .exe. If you go that route, it would be best to let numba recompile the code for the client’s machine. There are plans to expand numba’s ahead-of-time (AOT) compilation abilities, which would emit more generic binaries.

I have another question regarding caching in numba:
when a function is decorated with @njit(cache=True), and that function, on execution, calls multiple other functions that are decorated with just @njit but without the cache=True option, will the cache for the main calling function also cover those other functions, or will numba always compile those sub-functions on the fly?

In my testing, it seems that numba creates a dedicated .nbi and .nbc file for every function that is decorated with @njit(cache=True), but if I only use the cache option on the main function (that calls the sub-functions), then the cache-improved execution time is the same as when decorating all sub-functions with @njit(cache=True)

in other words, if the above is true, then one could avoid creating unnecessary cache files for sub-functions by not using the cache option on them.

Thanks!

I don’t have as much insight into the nitty gritty details of the compiler as the devs, but my understanding from my experience is that usually every individual jitted function will recompile its whole dependency tree… at least the jitted dependencies are recompiled. There are some exceptions… like I believe typed.List and typed.Dict have precompiled endpoints. What I’ve read on these forums is that the motivation for mostly recompiling everything instead of compiling individual functions and cross-linking them is that the LLVM compiler can be a bit more aggressive with optimizations when each jitted endpoint is treated as a whole program. I’m not sure I completely believe that argument since Link-Time Optimization (LTO) is a thing. But it sounds like some of the new AOT stuff (the devs are calling it PIXIE) strikes more of a balance in that regard.

The short answer is that numba will only compile what it needs, and do that compilation on the first execution (or retrieve it from the cache). So there isn’t any penalty to adding cache=True to functions that you don’t use directly as endpoints.

thanks for the insight. much appreciated!

So there isn’t any penalty to adding cache=True to functions that you don’t use directly as endpoints

agreed, with the small exception that one could argue a slight downside of decorating every function with cache=True is the unnecessary additional disk space those cache files occupy, if those functions are only ever called from other functions and hence are already cached as dependencies there

it’s pretty useful though that we can pick what we want as a standalone disk cache file!