Caching njitted functions that take other functions as arguments

We recently released numbakit-ode, a package that leverages numba to speed up ODE integration. Basically, we rewrote the SciPy code in a numba-compatible manner and got a 10x performance boost in each integration step. Multiple steps are performed in a tight loop, resulting in even better performance.

You can see the project here.

But we still have a large overhead due to the compilation of the integrator stepper function, and it would be nice to remove it by caching the compilation results. We have not been able to make that work, so I have written this simple example to show the problem.

```python
import numba as nb

@nb.njit
def func(t, y):
    return y


@nb.njit(cache=True)
def stepper(f, a, b):
    return 3 * f(a, b)


print(stepper(func, 4., 2.))
```

which results in

```
Traceback (most recent call last):
[...]
    data_name = overloads[key]
KeyError: [...]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
[...]
    return pickle.dumps(obj, protocol=-1)
TypeError: cannot pickle 'weakref' object
```

My intuition (albeit without enough knowledge of the compilation process) is that this should work, since compiling stepper only requires the signature of func, which should be cacheable. I wonder whether this is just a current limitation of Numba or whether there is a fundamental reason why it cannot work.
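Our reading of the traceback (not verified against Numba's internals) is that the key Numba tries to pickle into its on-disk cache index for `stepper` contains a weak reference tied to the `func` dispatcher, which `pickle` refuses to serialize. A key built only from the signature of `func` would be plain strings and would pickle fine. The toy sketch below illustrates that difference; both key layouts are hypothetical and are not Numba's real index format.

```python
import pickle
import weakref


def func(t, y):
    return y


def is_picklable(key):
    """Return True if `key` survives pickle.dumps (the call seen in the traceback)."""
    try:
        pickle.dumps(key, protocol=-1)
    except TypeError:
        return False
    return True


# Hypothetical key identifying the function argument by the object itself,
# held through a weak reference.
key_by_identity = ("stepper", weakref.ref(func), "float64", "float64")

# Hypothetical key recording only the function argument's signature.
key_by_signature = ("stepper", "float64(float64, float64)", "float64", "float64")

print(is_picklable(key_by_identity))   # False
print(is_picklable(key_by_signature))  # True
```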


I am unable to reproduce this issue on OSX with current master at 0.53.0dev0-112-gb369047a92. Perhaps you could attach the output of numba -s? Thanks!

Hi @hgrecco,

https://github.com/numba/numba/pull/6373 fixed this; it'll be in 0.53.0. Something along the lines of https://github.com/numba/numba/pull/6284 is also needed to make the cache work correctly, and we are aiming to get that into 0.53.0 too. For reference, 0.53.0 is scheduled for early Q1 2021. Development builds are available as noted here: https://numba.readthedocs.io/en/latest/user/faq.html#how-do-i-get-numba-development-builds

Hope this helps.


Hi @stuartarchibald,

This is awesome, thanks for the information and for the work on this issue. I will start testing that version and report back if I find any problems.

Loving numbakit-ode @hgrecco! So happy to see you started with DOP853 😍 We will be using it in poliastro very soon! https://github.com/poliastro/poliastro/issues/1042

@astrojuanlu Happy to hear this! On the weekend we will release version 0.3, which settles the public API. We are looking forward to seeing numbakit-ode being used.

Out of curiosity: @hgrecco are you finding performance issues like the one I described in https://github.com/numba/numba/issues/2952 ? Or are the objective functions you’re testing always the bottleneck?

We have observed this behavior, but in the typical calculations we do, using numba is still better than not using it.

Having said that, it would be great to have this issue fixed, as it would provide an extra and most welcome performance boost.

We have added an extended version of the first code example in #2952 to our benchmarks:

```
· Discovering benchmarks
· Running 2 total benchmarks (1 commits * 1 environments * 2 benchmarks)
[  0.00%] ·· Benchmarking existing-py_Users_grecco_anaconda3_envs_sci38_bin_python
[ 25.00%] ··· Running (nbcompat.Suite.time_newton--)..
[ 75.00%] ··· nbcompat.Suite.time_newton                          2/6 failed
[ 75.00%] ··· ========= ========== ==========
              --                numba
              --------- ---------------------
               variant     True      False
              ========= ========== ==========
                scipy    221±70μs   213±20μs
                simple    failed     failed
                nbkode   68.4±8μs   20.5±1μs
              ========= ========== ==========

[100.00%] ··· nbcompat.Suite.time_newton_fprime                           ok
[100.00%] ··· ========= ========== ============
              --                 numba
              --------- -----------------------
               variant     True       False
              ========= ========== ============
                scipy    182±40μs    160±10μs
                simple   16.7±5μs   2.61±0.1μs
                nbkode   132±50μs   7.78±0.6μs
              ========= ========== ============
```

where variant is:

  • scipy: the SciPy implementation of Newton's method
  • simple: the one included in your code (which only works when fprime is provided)
  • nbkode: the Newton method included in numbakit-ode, an njit-compatible implementation that closely follows the one in SciPy

and numba:

  • True: the njitted versions of all functions are used (except for SciPy's newton)
  • False: the py_func versions of all functions are used

The code is here: https://github.com/hgrecco/numbakit-ode/blob/main/benchmarks/nbcompat.py
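For readers who do not follow the link, here is a minimal sketch of the pattern the `simple` variant exercises: a Newton iteration that receives `f` and `fprime` as arguments. This is our illustration, not the benchmark code; `newton`, `f` and `fprime` here are hypothetical names, and numba is optional (a no-op decorator is used when it is not installed). The last lines mimic the numba=False column by calling the `.py_func` attribute that numba dispatchers expose.

```python
try:
    from numba import njit
except ImportError:  # numba not installed: run everything as plain Python
    def njit(fn):
        return fn


@njit
def f(x):
    return x * x - 2.0


@njit
def fprime(x):
    return 2.0 * x


@njit
def newton(func, dfunc, x0, tol, maxiter):
    """Newton iteration; func and dfunc are passed in as arguments."""
    x = x0
    for _ in range(maxiter):
        dfx = dfunc(x)
        if dfx == 0.0:
            break
        step = func(x) / dfx
        x -= step
        if abs(step) < tol:
            break
    return x


def py(fn):
    """Original Python function behind a dispatcher (identity in the fallback)."""
    return getattr(fn, "py_func", fn)


root = newton(f, fprime, 1.0, 1e-12, 50)                    # numba=True style
root_py = py(newton)(py(f), py(fprime), 1.0, 1e-12, 50)     # numba=False style
print(root, root_py)
```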

Just to double check, the issue being discussed here is essentially the performance of dispatch, as per: https://github.com/numba/numba/issues/2952#issuecomment-493367514 ?
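A rough way to look at the dispatch overhead discussed in #2952 is to time the same tiny kernel with `func` passed as an argument versus referenced as a global. This is a sketch, assuming numba is installed (otherwise a no-op decorator runs plain Python and the numbers say nothing about Numba); `stepper_global` is a hypothetical name, and we make no claim about the timings you will get.

```python
import timeit

try:
    from numba import njit
except ImportError:  # numba not installed: plain Python, timings not meaningful
    def njit(fn):
        return fn


@njit
def func(t, y):
    return y


@njit
def stepper(f, a, b):
    # With numba, f is an argument, so the dispatcher has to work out its
    # type on every call before selecting a compiled specialization.
    return 3 * f(a, b)


@njit
def stepper_global(a, b):
    # func is a global here; with numba it is resolved at compile time and
    # only the two floats are typed per call.
    return 3 * func(a, b)


# Warm-up calls (these trigger compilation when numba is available).
assert stepper(func, 4.0, 2.0) == stepper_global(4.0, 2.0)

n = 10_000
t_arg = timeit.timeit(lambda: stepper(func, 4.0, 2.0), number=n) / n
t_global = timeit.timeit(lambda: stepper_global(4.0, 2.0), number=n) / n
print(f"function argument: {t_arg * 1e6:.2f} us/call")
print(f"global function:   {t_global * 1e6:.2f} us/call")
```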

@stuartarchibald Indeed, @astrojuanlu expanded the original scope of this post to include issue 2952 about dispatching (I was unaware of that discussion on GitHub).

I think that makes sense, as several things hurt the performance of functions that take other functions as arguments, and therefore hinder wider adoption of numba in libraries that, like numbakit-ode, aim to provide generic algorithms.

So far, I have identified:

  • Lack of caching for compiled functions that take other functions as arguments
  • Lack of function pointers, i.e. a function type where only the signature matters
  • Slower dispatch when some of the arguments are functions
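To make the function-pointer point concrete, here is a toy pure-Python model (not Numba's implementation) of a dispatcher that keeps one "compiled" specialization per key of argument types. If a function argument is typed by its identity, two right-hand sides with the same signature produce two compilations; if it is typed by its signature alone, they share one, and the key is plain data that could also be written to an on-disk cache. `ToyDispatcher`, `with_signature`, `rhs_a` and `rhs_b` are hypothetical names.

```python
def with_signature(restype, *argtypes):
    """Hypothetical decorator that attaches a signature to a Python function."""
    def deco(fn):
        fn.signature = (restype, argtypes)
        return fn
    return deco


class ToyDispatcher:
    """Keeps one 'compiled' specialization per tuple of argument-type keys."""

    def __init__(self, py_func, by_signature):
        self.py_func = py_func
        self.by_signature = by_signature
        self.compiled = {}

    def _type_key(self, arg):
        if callable(arg):
            # Function argument: typed by its signature, or by object identity.
            return arg.signature if self.by_signature else ("function", id(arg))
        return type(arg).__name__

    def __call__(self, *args):
        key = tuple(self._type_key(a) for a in args)
        if key not in self.compiled:
            # Stand-in for an expensive compilation step.
            self.compiled[key] = self.py_func
        return self.compiled[key](*args)


@with_signature("float64", "float64", "float64")
def rhs_a(t, y):
    return y


@with_signature("float64", "float64", "float64")
def rhs_b(t, y):
    return -y


def stepper(f, a, b):
    return 3 * f(a, b)


by_identity = ToyDispatcher(stepper, by_signature=False)
by_signature = ToyDispatcher(stepper, by_signature=True)
for dispatcher in (by_identity, by_signature):
    dispatcher(rhs_a, 4.0, 2.0)
    dispatcher(rhs_b, 4.0, 2.0)

print(len(by_identity.compiled))   # 2
print(len(by_signature.compiled))  # 1
```

The trade-off, presumably, is that code compiled against a signature-only type has to call `f` indirectly and cannot specialize on (or inline) the particular function passed in.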