Do you happen to have any thoughts or observations regarding the performance implications (if any) of calling/using first-class functions vs casting them?
Hey @nelson2005. It’s certainly worth running a benchmark, but in my experience recovering a first-class function from a pointer is very fast, perhaps almost negligible. I can’t speak to how this compares with passing them around as FunctionType objects in a typed List or something similar. I suspect it may be slightly faster, though. Typed Lists/Dicts are currently implemented in a way that prevents getitem from being inlined properly, whereas with an int64 pointer you can use NumPy arrays, which are very well optimized.
As for why recovering first-class functions is fast, even negligible: I’d have to poke around the Numba source to be sure, but I’m pretty sure that inside jitted code first-class functions aren’t dynamically allocated objects. They’re just structs that hold the function address. I suspect that once everything is compiled down to assembly you wouldn’t even be able to point to a section associated with actually constructing the FunctionType object. This is a hunch, but I suspect passing a first-class function object and passing an int64 would ultimately boil down to the same thing. That is also a testable hypothesis if you’re inclined to test it.
As for the cost of calling the function, I strongly doubt there would be any difference between the two methods, although the overhead of getting to that point might be substantial enough to worry about if you’re using typed Lists at the moment. I’ve been completely satisfied with the runtime performance of my _func_from_address intrinsic in performance-heavy sections; it doesn’t seem to contribute any overhead.
knowing that trick a while ago would have saved me a lot of pain
Seems an example is in order then:
_func_from_address and other helpful intrinsics defined here
Usage:
from numba import njit, types, f8
from numba.experimental.function_type import _get_wrapper_address
from numba.core import cgutils
from numba.extending import intrinsic

# --------------------------------
# : _func_from_address implementation
@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    '''Recovers a function from FunctionType and address '''
    func_type = func_type_ref.instance_type

    def codegen(context, builder, sig, args):
        _, addr = args
        sfunc = cgutils.create_struct_proxy(func_type)(context, builder)
        llty = context.get_value_type(types.voidptr)
        addr_ptr = builder.inttoptr(addr, llty)
        sfunc.addr = addr_ptr
        return sfunc._getvalue()

    sig = func_type(func_type_ref, addr)
    return sig, codegen

# -----------------------------------
# : example
@njit
def foo(a, b):
    return a + b

my_fn_type = types.FunctionType(f8(f8, f8))
foo_addr = _get_wrapper_address(foo, f8(f8, f8))

@njit
def apply_whatever(addr, a, b):
    f = _func_from_address(my_fn_type, addr)
    return f(a, b)

print("Should be 4.0: is ", apply_whatever(foo_addr, 1, 3))
Should be 4.0: is 4.0
Dear @DannyWeitekamp
First of all, thank you for sharing so much of your experience. That’s very valuable knowledge for me and likely many others too.
I have used a similar implementation to store a set of callbacks in an array and then call them repeatedly in a loop. From a performance standpoint, there is a difference between retrieving a function from its address and passing the function directly. The main differences I noticed were:
- Calling a jitted function that takes another jitted function as an argument from Python is quite expensive. In most cases, one avoids such repeated calls, but I mention it for completeness.
- I think the function-call overhead of the two methods differs, and almost every optimization is omitted when passing the address. In my case, where I called these small callbacks about a million times, the performance loss was significant.
- In the end, I used the function address version because otherwise the caching would not have worked properly.
Below is a very simple example that demonstrates my points.
import numpy as np
import numba as nb
from numba import types
from numba.core import cgutils
from numba.extending import intrinsic
from numba.experimental.function_type import _get_wrapper_address

# warm up numba
nb.njit(lambda: None)()

@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    ...  # same body as the _func_from_address implementation above

@nb.njit(cache=True, inline='never')
def callback1():
    np.random.rand()

@nb.njit(cache=True, inline='never')
def callback2():
    np.random.rand()

callback_sig = types.void()
callback_type = types.FunctionType(callback_sig)
callback2_addr = _get_wrapper_address(callback2, callback_sig)

@nb.njit(cache=True)
def takes_function(func, repeats=1):
    for _ in range(repeats):
        func()

@nb.njit(cache=True)
def takes_address(func_addr, repeats=1):
    func = _func_from_address(callback_type, func_addr)
    for _ in range(repeats):
        func()
%timeit -n 1 -r 1 takes_function(callback1)
%timeit -n 1 -r 1 takes_address(callback2_addr)
# Second run in a Notebook after restarting the Kernel
668 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
10.9 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
%timeit takes_function(callback1)
%timeit takes_address(callback2_addr)
9.21 µs ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
205 ns ± 7.69 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%timeit takes_function(callback1, repeats=100_000)
%timeit takes_address(callback2_addr, repeats=100_000)
407 µs ± 35.2 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
782 µs ± 19.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
@sschaer Great points. Your last two benchmarks seem to prove my hypothesis wrong. As for your other points, I’ve noticed similar things. As you say, Numba doesn’t seem to apply all the possible optimizations when loading the function from an address. In my experience, directly calling a jitted function from another is always fastest, probably partly because of inlining (which you control for in your tests), and of course there’s no getting around that when using first-class functions. But I wonder whether, with some digging, there is a way to explicitly turn some flags on or off that would force the compiler to apply the same optimizations in both the second and the last benchmark here.
For instance, in principle, if a function dispatcher is passed as an argument, the compiler has a chance to look at its source and decide whether it has side effects. If it has none, some flag might be turned on that allows for loop lifting/vectorization. If that is indeed what is happening here, then someone with hard-core performance needs could probably make that decision themselves and, in select cases, switch those flags on manually inside the intrinsic. I have no idea if that’s possible, but it may be worth looking into.
Okay, so it does seem to be possible to force the compiler to do more optimizations over first-class function calls in a loop. The trick is that you have to add the attributes to the function call site instead of to the function object, as described in the Stack Overflow question “LLVM: Setting function attributes on a pointer”.
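Here is a minimal llvmlite sketch, independent of Numba, showing what a call-site attribute looks like at the IR level. The callee is only known as an integer address, so there is no function declaration to attach attributes to. Instead, nounwind and readnone are attached to the call instruction itself, which is the same attrs argument the _call_fast intrinsic passes to builder.call. The module and function names are made up for illustration, and the sketch only builds and prints the IR. It does not compile, optimize or run it.

```python
from llvmlite import ir

double = ir.DoubleType()
i64 = ir.IntType(64)
callback_ty = ir.FunctionType(double, [])  # callee signature: f8()

module = ir.Module(name="call_site_attrs_demo")
caller = ir.Function(module, ir.FunctionType(double, [i64]), name="call_via_address")
builder = ir.IRBuilder(caller.append_basic_block("entry"))

# Turn the raw int64 address into a typed function pointer
fn_ptr = builder.inttoptr(caller.args[0], callback_ty.as_pointer())

# The attributes are attached to this call instruction only, not to any function
result = builder.call(fn_ptr, [], attrs=("nounwind", "readnone"))
builder.ret(result)

ir_text = str(module)
print(ir_text)
```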
So now consider this example with a new _call_fast intrinsic which annotates the call instruction with ‘nounwind’ and ‘readnone’.
import numpy as np
import numba as nb
from numba import types, f8
from numba.core import cgutils
from numba.extending import intrinsic
from numba.experimental.function_type import _get_wrapper_address
from cre.utils import PrintElapse  # timing context manager from the cre package

@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    ...  # same body as the _func_from_address implementation above

# --------------------------------
# : _call_fast implementation
@intrinsic
def _call_fast(typingctx, func, args):
    '''Calls a FunctionType with 'nounwind' and 'readnone' set '''
    func_type = func

    def codegen(context, builder, sig, _args):
        _, inp_types = sig.args
        arg_types = func_type.signature.args
        func, args = _args

        # Unpack args and cast types to the types specified in the func signature
        args = cgutils.unpack_tuple(builder, args, len(inp_types))
        args = [context.cast(builder, a, it, at) for a, it, at in zip(args, inp_types, arg_types)]

        # Grab the function address
        sfunc = cgutils.create_struct_proxy(func_type)(context, builder, func)
        llty = context.get_value_type(func_type.ftype)
        fn_addr = builder.bitcast(sfunc.addr, llty)

        # Call the function with special attributes
        ret = builder.call(fn_addr, args, cconv=func_type.cconv, attrs=("nounwind", 'readnone'))
        return ret

    sig = func.signature.return_type(func, args)
    return sig, codegen

@nb.njit(cache=True, inline='never')
def callback1():
    return np.random.rand()

@nb.njit(cache=True, inline='never')
def callback2():
    return np.random.rand()

callback_sig = f8()
callback_type = types.FunctionType(callback_sig)
callback2_addr = _get_wrapper_address(callback2, callback_sig)

@nb.njit(cache=True)
def takes_function(func, repeats=1):
    z = 0
    for _ in range(repeats):
        z += func()
    return z

@nb.njit(cache=True)
def takes_address(func_addr, repeats=1):
    z = 0
    func = _func_from_address(callback_type, func_addr)
    for _ in range(repeats):
        z += func()
    return z

@nb.njit(cache=False)
def takes_address_fast(func_addr, repeats=1):
    z = 0
    func = _func_from_address(callback_type, func_addr)
    for _ in range(repeats):
        z += _call_fast(func, ())
    return z

# Each function is called once to compile it, then timed on a second call
takes_function(callback1, repeats=100000)
with PrintElapse("takes_function"):
    takes_function(callback1, repeats=100000)

takes_address(callback2_addr, repeats=100000)
with PrintElapse("takes_address"):
    takes_address(callback2_addr, repeats=100000)

takes_address_fast(callback2_addr, repeats=100000)
with PrintElapse("takes_address_fast"):
    takes_address_fast(callback2_addr, repeats=100000)
One difference from your benchmarks, @sschaer: the calling function needs to do something with the return value, or the compiler will prune the call. In any case, the speedup in this test seems considerable.
takes_function: 0.40 ms
takes_address: 1.34 ms
takes_address_fast: 0.09 ms
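Of course, one should use something like this with caution, because not all functions can safely be annotated this way. ‘nounwind’ promises LLVM that the callee never unwinds, and ‘readnone’ promises that it neither reads nor writes memory visible to the caller. If a callee breaks those promises, LLVM is free to reorder, hoist, or delete the call. Strictly speaking, even the np.random.rand() callbacks above update the random-number generator’s state.
The cases where this helps seem situational. Below are the timings with the callbacks changed to return a+b, with signature f8(f8,f8):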
takes_function: 0.12 ms
takes_address: 0.34 ms
takes_address_fast: 0.35 ms
Interesting. Any thoughts about what might affect that situationality?
No clue. Probably one of those things that requires digging into the LLVM or assembly to understand better.