Override or set default arguments of jitted function

Hi. I’ve been struggling to come up with a clean solution for a problem I have and was hoping someone could give me some advice. I’ve looked around for possible implementations (like jitting functools.partial functions), but none seem to work.

I have defined a structure that behaves like a list of arrays but is actually stored as two 1D arrays (kind of like a sparse matrix without the indices, just data and indptr). I also have a Numba-jitted function that takes the two arrays (data and indptr), a function, and that function’s arguments, and applies the function to every “element” (group) of the structure. The arguments part is what I’m trying to remove. Here’s a snippet:

import numpy as np
from numba import njit


@njit
def apply_func_to_groups(data, indptr, func, *args):
    # indptr[i]:indptr[i+1] delimits group i inside the flat data array
    ngroups = len(indptr) - 1
    result = np.empty_like(data)
    for i in range(ngroups):
        group_slice = slice(indptr[i], indptr[i+1])
        result[group_slice] = func(data[group_slice], *args)
    return result

The drawback here is that this gets compiled separately for functions that take no arguments, functions that take 1 argument, 2 arguments, etc. I’d like to be able to just pass a function whose arguments are already bound (like what functools.partial does). Is there a way to achieve this?
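To make the data/indptr layout described above concrete, here is a minimal pure-Python sketch using plain lists instead of NumPy arrays. The helper names to_data_indptr and get_group are hypothetical and only serve the illustration; they are not part of the original code.

```python
def to_data_indptr(groups):
    """Flatten a list of sequences into a flat data list plus group boundaries."""
    data = []
    indptr = [0]  # indptr[i] is the start of group i; indptr[-1] is len(data)
    for group in groups:
        data.extend(group)
        indptr.append(len(data))
    return data, indptr


def get_group(data, indptr, i):
    """Return group i, i.e. data[indptr[i]:indptr[i + 1]]."""
    return data[indptr[i]:indptr[i + 1]]


groups = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
data, indptr = to_data_indptr(groups)
ngroups = len(indptr) - 1  # same formula as in apply_func_to_groups
third_group = get_group(data, indptr, 2)
print(data, indptr, ngroups, third_group)
```

There is always one more entry in indptr than there are groups, which is why the jitted function computes ngroups as len(indptr) - 1.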

Hi @jose-moralez, I think it’s an interesting question and I can probably help you, having spent some time looking at compilation times.
The first thing I would like to mention is that avoiding recompilation is harder than it sounds, so it’s important to consider whether you really need it. I measured the compilation time on my computer, and it was 0.13 seconds. Is this really a problem?

If you really want to go down the path of avoiding recompilation, then from looking at the function I can tell you that the biggest issue is not args but func. Yes, args triggers a recompilation for each new number (or combination of types) of arguments, but func triggers a recompilation for each different function. Even if you used some kind of partial (and there are some ways to achieve it), each different set of args would produce a different partial, each of which has to be compiled separately and then, being a different function, forces apply_func_to_groups to be recompiled as well.
I’ll show with an example:

@njit
def foo(x):
    return x

foo(1)  # first compilation
foo(2)  # no compilation
apply_func_to_groups(...., foo, 1)  # compiles apply_func_to_groups
apply_func_to_groups(...., foo, 2)  # no compilation


# emulating partial with wrapper functions

@njit
def foo_1():
    return foo(1)

@njit
def foo_2():
    return foo(2)

foo_1()  # compiles foo and foo_1
foo_2()  # compiles foo_2
apply_func_to_groups(...., foo_1)  # compiles apply_func_to_groups
apply_func_to_groups(...., foo_2)  # compiles apply_func_to_groups again!

The use of partial has increased the number of compilations from 2 (foo and apply_func_to_groups) to 5 (foo, foo_1, foo_2, and apply_func_to_groups twice).

So, in summary, avoiding recompilation is hard and requires understanding Numba’s type system in detail. Sometimes avoiding recompilation is possible and sometimes it is not (I learned all this working on an application, and 1 year later I still cannot avoid recompilation). Even when it is possible, you might have to make other compromises, so the solution will have to be tailored to your specific problem and will have to make assumptions about how many different func you expect to have, how many different args you expect to have, etc.

hope this helps!
Luk

Thank you @luk-f-a! You’re right, the problem is having many different functions, not the arguments. Each function takes around 300ms to compile by itself, and then plugging it into apply_func_to_groups takes another 400ms. I was looking at about 5 seconds of compilation and thought it was the arguments, but actually each different call to apply_func_to_groups takes around 700ms to compile, so using 6 different functions already adds up to about 4 seconds.
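As a quick check of that arithmetic, using only the approximate timings quoted above (the variable names are just for illustration):

```python
func_compile_ms = 300   # compiling one function by itself
apply_compile_ms = 400  # compiling apply_func_to_groups for that function
n_functions = 6

per_call_ms = func_compile_ms + apply_compile_ms
total_ms = n_functions * per_call_ms
print(per_call_ms, total_ms / 1000)  # milliseconds per function, total seconds
```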

Thanks for making this clear, that definitely helped me a lot!