Any numba equivalent for casting a raw pointer to a StructRef, Dict, List etc?

Dear @DannyWeitekamp

First of all, thank you for sharing so much of your experience. That's very valuable knowledge to me and likely to many others too.
I have used a similar implementation to store a set of callbacks in an array and then call them repeatedly in a loop. From a performance standpoint, retrieving a function from its address and passing the function directly are not equivalent. The main differences I noticed were:

  • Calling a jitted function that takes another jitted function as an argument from Python is quite expensive. In most cases, one avoids such repeated calls, but I mention it for completeness.
  • I think the call overhead of the two methods differs, and almost every optimization is omitted when the address is passed. In my case, where these small callbacks were called about a million times, the performance loss was significant.
  • In the end, I used the function address version because otherwise the caching would not have worked properly.
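
For anyone unfamiliar with calling a function through its address, here is a minimal sketch that uses only the standard `ctypes` module. It is an analogy and not Numba's mechanism: the names `CALLBACK`, `_double`, `_square` and `call_by_address` are made up for illustration. It shows the same pattern of storing callback addresses as plain integers and rebuilding a callable from each one before use.

```python
import ctypes

# C-compatible signature: int64 f(int64). Hypothetical, for illustration only.
CALLBACK = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64)

def _double(x):
    return 2 * x

def _square(x):
    return x * x

# Keep the callback objects alive; their addresses are only valid while they exist.
callbacks = [CALLBACK(_double), CALLBACK(_square)]

# Store the raw addresses as plain integers, e.g. to put them in an array.
addresses = [ctypes.cast(cb, ctypes.c_void_p).value for cb in callbacks]

def call_by_address(addr, arg):
    # Rebuild a callable from the integer address, then call it.
    func = CALLBACK(addr)
    return func(arg)

results = [call_by_address(addr, 7) for addr in addresses]
print(results)
```

The caller only sees an integer and a signature, so nothing about the callee's body is known at the call site. That is the general reason an address-based call is harder to optimize than a direct one.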

Below is a very simple example that demonstrates my points.

import numpy as np
import numba as nb
from numba import types 
from numba.core import cgutils
from numba.extending import intrinsic
from numba.experimental.function_type import _get_wrapper_address
# warm up numba
nb.njit(lambda: None)()

@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    ...

@nb.njit(cache=True, inline='never')
def callback1():
    np.random.rand()

@nb.njit(cache=True, inline='never')
def callback2():
    np.random.rand()

callback_sig = types.void()
callback_type = types.FunctionType(callback_sig)
callback2_addr = _get_wrapper_address(callback2, callback_sig)

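@nb.njit(cache=True)
def takes_function(func, repeats=1):
    # The callback is passed directly as a jitted function.
    for _ in range(repeats):
        func()

@nb.njit(cache=True)
def takes_address(func_addr, repeats=1):
    # The callback is rebuilt from its raw address before being called.
    func = _func_from_address(callback_type, func_addr)
    for _ in range(repeats):
        func()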

# Second run in a Notebook after restarting the Kernel
%timeit -n 1 -r 1 takes_function(callback1)
%timeit -n 1 -r 1 takes_address(callback2_addr)
668 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
10.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

%timeit takes_function(callback1)
%timeit takes_address(callback2_addr)
9.21 µs ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
205 ns ± 7.69 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

%timeit takes_function(callback1, repeats=100_000)
%timeit takes_address(callback2_addr, repeats=100_000)
407 µs ± 35.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
782 µs ± 19.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
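
The `%timeit` lines above only work in IPython or a notebook. Outside a notebook, roughly the same three measurements could be taken with the standard `timeit` module. The sketch below is only an assumption about how one might structure that: `stand_in` is a hypothetical pure-Python placeholder for `takes_function` / `takes_address`, and the helper names are made up.

```python
import time
import timeit

def stand_in(repeats=1):
    # Hypothetical placeholder for takes_function / takes_address.
    total = 0
    for i in range(repeats):
        total += i
    return total

def first_call_seconds(func, *args):
    # Single timed call, comparable to `%timeit -n 1 -r 1` on a fresh process.
    start = time.perf_counter()
    func(*args)
    return time.perf_counter() - start

def steady_state_seconds(func, *args, number=1000, repeat=7):
    # Best per-call time over several runs, once any warm-up cost is paid.
    runs = timeit.repeat(lambda: func(*args), number=number, repeat=repeat)
    return min(runs) / number

first = first_call_seconds(stand_in, 100_000)
steady = steady_state_seconds(stand_in, 100_000, number=10, repeat=3)
print(f"first call: {first * 1e3:.3f} ms, steady state: {steady * 1e6:.1f} µs")
```

Keeping the first call separate from the repeated calls matters here, because the first call includes one-off costs such as loading from the cache, while the later calls show the per-call dispatch overhead and the loop cost.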