Dear @DannyWeitekamp
First of all, thank you for sharing so much of your experience. That’s very valuable knowledge for me, and likely for many others too.
I have used a similar implementation to store a set of callbacks in an array and then call them repeatedly in a loop. From a performance standpoint, there is a difference between retrieving a function from its address and passing the function directly. The main differences I noticed were:
- Calling, from Python, a jitted function that takes another jitted function as an argument is quite expensive. In most cases such repeated calls can be avoided, but I mention it for completeness.
- The call overhead inside the jitted code also differs between the two methods. I think almost every optimization is lost when the function is retrieved from its address. In my case, where I called these small callbacks about a million times, the performance loss was significant.
- In the end, I still used the function-address version, because otherwise caching would not have worked properly.
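To make the "array of callback addresses" pattern concrete outside Numba, here is a minimal stdlib-only analogy using `ctypes`. It is not Numba's mechanism and says nothing about Numba's performance; it only shows the bookkeeping: keep only raw function addresses in an array, then rebuild a callable from each address and call it in a loop. The names `VOIDFUNC`, `make_callback` and `run_all` are made up for this sketch.

```python
import array
import ctypes

# C signature "void f(void)", the same shape as the types.void() callbacks.
VOIDFUNC = ctypes.CFUNCTYPE(None)

calls = []

def make_callback(name):
    """Return a Python function that records its name when called."""
    def cb():
        calls.append(name)
    return cb

# The ctypes wrapper objects must stay alive: the raw addresses below are
# only valid while these objects exist.
wrappers = [VOIDFUNC(make_callback(n)) for n in ("a", "b", "c")]

# Store only the raw addresses, as unsigned 64-bit integers in an array.
addresses = array.array("Q", (ctypes.cast(w, ctypes.c_void_p).value for w in wrappers))

def run_all(addrs, repeats=1):
    """Rebuild a callable from each address, then call all of them repeatedly."""
    funcs = [VOIDFUNC(a) for a in addrs]  # prototype(address) -> foreign function
    for _ in range(repeats):
        for f in funcs:
            f()

run_all(addresses, repeats=2)
print(calls)  # → ['a', 'b', 'c', 'a', 'b', 'c']
```

Every call here goes through a `ctypes` thunk back into Python, so only the structure carries over, not the timings.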
Below is a very simple example that demonstrates my points.
```python
import numpy as np
import numba as nb
from numba import types
from numba.core import cgutils
from numba.extending import intrinsic
from numba.experimental.function_type import _get_wrapper_address

# warm up numba
nb.njit(lambda: None)()

@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    ...  # implementation omitted in this post

# Two callbacks with identical bodies: callback1 is passed directly,
# callback2 is passed by its address.
@nb.njit(cache=True, inline='never')
def callback1():
    np.random.rand()

@nb.njit(cache=True, inline='never')
def callback2():
    np.random.rand()

callback_sig = types.void()
callback_type = types.FunctionType(callback_sig)
callback2_addr = _get_wrapper_address(callback2, callback_sig)

@nb.njit(cache=True)
def takes_function(func, repeats=1):
    # func is a jitted function passed as an argument
    for _ in range(repeats):
        func()

@nb.njit(cache=True)
def takes_address(func_addr, repeats=1):
    # rebuild a callable first-class function from the raw address
    func = _func_from_address(callback_type, func_addr)
    for _ in range(repeats):
        func()
```
First call of each function, measured on the second run in a notebook after restarting the kernel:

```python
%timeit -n 1 -r 1 takes_function(callback1)
%timeit -n 1 -r 1 takes_address(callback2_addr)
```

```
668 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
10.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
```
Overhead of a single call from Python:

```python
%timeit takes_function(callback1)
%timeit takes_address(callback2_addr)
```

```
9.21 µs ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
205 ns ± 7.69 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
```
100,000 callback calls inside the jitted loop:

```python
%timeit takes_function(callback1, repeats=100_000)
%timeit takes_address(callback2_addr, repeats=100_000)
```

```
407 µs ± 35.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
782 µs ± 19.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
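A quick back-of-the-envelope reading of these timings, using only the mean values reported above. Each loop iteration also runs `np.random.rand()`, so the per-iteration figures are not pure call overhead; since both callbacks have identical bodies, the difference between them is the more meaningful number. The helper `per_iteration_ns` exists only for this illustration.

```python
# Mean timings reported above, in seconds.
REPEATS = 100_000
loop_function = 407e-6   # takes_function(callback1, repeats=100_000)
loop_address = 782e-6    # takes_address(callback2_addr, repeats=100_000)
call_function = 9.21e-6  # takes_function(callback1), one call from Python
call_address = 205e-9    # takes_address(callback2_addr), one call from Python

def per_iteration_ns(total_seconds, repeats):
    """Average cost of one loop iteration in nanoseconds."""
    return total_seconds / repeats * 1e9

fn_ns = per_iteration_ns(loop_function, REPEATS)
addr_ns = per_iteration_ns(loop_address, REPEATS)

print(f"passed function : {fn_ns:.2f} ns/iteration")
print(f"from address    : {addr_ns:.2f} ns/iteration")
print(f"extra per call  : {addr_ns - fn_ns:.2f} ns")
print(f"loop slowdown   : {loop_address / loop_function:.2f}x")
print(f"Python dispatch : {call_function / call_address:.0f}x faster for the address version")
```

That is roughly 3.75 ns of extra cost per callback call when going through the address, so about a million calls add only a few milliseconds in absolute terms, yet the loop with these tiny callbacks is almost twice as slow.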