Using default args without sacrificing performance

Hi all. Below is an MRE (minimal reproducible example) of a confusing performance issue involving default argument values, which I stumbled upon while converting some old pure-Python functions to njitted ones. In my actual script, a pure-Python function that calls an njitted function and slightly modifies its result is many times faster than an njitted function with default arguments that does the same call and the same simple modification.

```python
from numba import njit
from time import perf_counter
from random import randint

test_a, test_b = randint(9000,9999), randint(9000,9999)
extra_arg = 2  # used for a simple modification of the primary, intensive function's output

# What I care about most is the timing of the very first call after the function
# has already been compiled for the right dtypes and cached. In this minimal
# reproducible example I emulate that by compiling each func with 'warm-up' calls.

# Test problematic NJIT
@njit
def heavytest_njit(x, y):
    total = 0
    for i in range(x):
        for j in range(y):
            total += x+y
    return total
@njit
def simplemod_njit(a, b, also=extra_arg):
    return heavytest_njit(a, b) * also
# Compile only (no caching); the 3rd arg is omitted so the default value is used.
heavytest_njit(10, 30)
simplemod_njit(11, 31)
# Time it
restart = perf_counter()
simplemod_njit(test_a, test_b)
print(f"1 NJIT func using global var as default arg took:\t {round((perf_counter()-restart)*1e6,1)} microseconds  <-- much slower, which surprised me")

# Test ideal NJIT: default arg overridden at the call site
@njit
def simplemod_njit2(a, b, also=extra_arg):
    return heavytest_njit(a, b) * also
# Compile only (no caching)
simplemod_njit2(12, 32, 3)
# Time it
restart = perf_counter()
simplemod_njit2(test_a, test_b+1, extra_arg)
print(f"1 NJIT func with default arg overwritten took:\t\t {round((perf_counter()-restart)*1e6,1)} microseconds  <-- usually fastest njit option")

# Test NJIT without any default arg
@njit
def simplemod_njit3(a, b, also):
    return heavytest_njit(a, b) * also
# Compile only (no caching)
simplemod_njit3(13, 33, 4)
# Time it
restart = perf_counter()
simplemod_njit3(test_a+1, test_b, extra_arg)
print(f"1 NJIT func without default arg took:\t\t\t\t {round((perf_counter()-restart)*1e6,1)} microseconds")

# Test a pure-Python wrapper around the NJIT func, which I found by accident to be very fast
def simplemod_mixed(a, b, also=extra_arg):
    return heavytest_njit(a, b) * also
# Time it
restart = perf_counter()
simplemod_mixed(test_a-1, test_b)
print(f"1 Pure python call of NJIT func took:\t\t\t\t {round((perf_counter()-restart)*1e6,1)} microseconds")

# Test the pure Python version as yet another fun demo of numba's awesomeness.
def heavytest(x, y):
    total = 0
    for i in range(x):
        for j in range(y):
            total += x+y
    return total
def simplemod_pure(a, b, also=extra_arg):
    return heavytest(a, b) * also
# Time it
restart = perf_counter()
simplemod_pure(test_a, test_b-1)
print(f"1 Pure python equivalent func took:\t\t\t {round((perf_counter()-restart)*1e6,1)} microseconds")
```

I'm struggling to understand this behavior. Is it just that reading a global variable is very slow for numba? If so, is the second version faster because that step gets skipped when a third argument is provided? I was also surprised in this MRE that merely allowing an argument to default to a global variable sometimes makes the call take about twice as long as passing the argument explicitly when the njitted func is called. Is that expected behavior?
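On the "reading a global" part: in plain Python, a default like `also=extra_arg` is evaluated once, when the `def` statement runs, and the value is stored on the function object, so the global is not re-read on every call. The minimal sketch below demonstrates this with a hypothetical function `scale`. It says nothing about numba's own call overhead; as far as I know from the numba docs, jitted code similarly treats referenced globals as compile-time constants rather than looking them up at run time.

```python
# Default values are evaluated once, when the ``def`` statement runs,
# and stored on the function object in ``__defaults__``.
extra_arg = 2

def scale(value, also=extra_arg):  # hypothetical stand-in for simplemod_*
    return value * also

# Rebinding the global afterwards does not change the stored default.
extra_arg = 100

print(scale.__defaults__)   # (2,)
print(scale(3))             # 6: uses the default captured at def time
print(scale(3, extra_arg))  # 300: the current global is used only when passed explicitly
```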

The njitted funcs are wonderfully fast on subsequent calls, but the very first call’s performance is important in my app too. Pretty confused really. Any help at all would be appreciated!

I did work around the issue by providing explicit signatures (14 of them in my actual app, which is not ideal). I would still like to understand how numba handles global variables and default argument values during compilation.
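One alternative to listing many explicit signatures, offered here as an untested assumption rather than a confirmed fix, is to keep the default in a thin pure-Python wrapper and give the njitted kernel only required arguments, so every call reaches the kernel with the same argument types. This mirrors `simplemod_mixed` from the MRE. The names `_simplemod_kernel`, `simplemod` and `DEFAULT_ALSO` are hypothetical, and the `try/except` fallback only exists so the sketch runs where numba is not installed.

```python
try:
    from numba import njit
except ImportError:  # fallback so the sketch still runs where numba is absent
    def njit(func):
        return func

DEFAULT_ALSO = 2  # hypothetical module-level default

@njit
def heavytest_njit(x, y):
    total = 0
    for i in range(x):
        for j in range(y):
            total += x + y
    return total

@njit
def _simplemod_kernel(a, b, also):
    # No default here: callers always pass all three arguments explicitly.
    return heavytest_njit(a, b) * also

def simplemod(a, b, also=DEFAULT_ALSO):
    # Thin pure-Python wrapper: the default is resolved in Python, so the
    # jitted kernel is always called with three explicit integer arguments.
    return _simplemod_kernel(a, b, also)

print(simplemod(3, 4))     # 168: 3 * 4 iterations adding 7, times the default 2
print(simplemod(3, 4, 5))  # 420
```

The trade-off is one extra Python-level call per invocation, which the `simplemod_mixed` timing in the thread suggests is small relative to the jitted work.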

What timings do you get? I see some minor differences, but it’s hard to get a clear signal above the standard deviation, and certainly nothing like the order-of-magnitude difference you mention. Only the fully pure-Python version is several orders of magnitude slower.

Excluding compilation:

```python
r1 = %timeit -o simplemod_njit(test_a, test_b)
# 395 ns ± 23.9 ns per loop (mean ± std. dev. of 20 runs, 1,000,000 loops each)

r2 = %timeit -o simplemod_njit2(test_a, test_b, also=extra_arg)
# 326 ns ± 14.4 ns per loop (mean ± std. dev. of 20 runs, 1,000,000 loops each)

r3 = %timeit -o simplemod_njit3(test_a, test_b, extra_arg)
# 322 ns ± 12.7 ns per loop (mean ± std. dev. of 20 runs, 1,000,000 loops each)

r4 = %timeit -o simplemod_mixed(test_a, test_b)
# 451 ns ± 20.4 ns per loop (mean ± std. dev. of 20 runs, 1,000,000 loops each)
```
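Outside IPython, a rough equivalent of `%timeit -o` can be built from the standard-library `timeit.repeat`, reporting mean ± standard deviation per call, alongside a single-shot timer for the first call after warm-up. This is a sketch: the helper names `bench` and `first_call_us` and the default `repeat`/`number` values are illustrative choices, not what `%timeit` uses internally.

```python
import statistics
import timeit

def bench(func, *args, repeat=7, number=10_000):
    """Return (mean, stdev) in seconds per call, in the spirit of %timeit."""
    # Each entry of ``totals`` is the time for ``number`` back-to-back calls.
    totals = timeit.repeat(lambda: func(*args), repeat=repeat, number=number)
    per_call = [t / number for t in totals]
    return statistics.mean(per_call), statistics.stdev(per_call)

def first_call_us(func, *args):
    """Time one single call in microseconds (e.g. the first call after warm-up)."""
    start = timeit.default_timer()
    func(*args)
    return (timeit.default_timer() - start) * 1e6

# Usage with a cheap builtin; swap in simplemod_njit and its arguments.
mean, std = bench(abs, -3)
print(f"{mean * 1e9:.0f} ns ± {std * 1e9:.0f} ns per loop")
```

A single `first_call_us` measurement is only one sample, so comparing variants from one run each is likely to be dominated by noise, which is consistent with the spread seen in the `%timeit` numbers.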