Hi. I’ve been struggling to come up with a clean solution for a problem I have and was hoping someone could give me some advice. I’ve looked around for possible implementations (like jitting `functools.partial` functions), but none seem to work.
I have defined a structure that behaves like a list of arrays but is actually stored as two 1D arrays, `data` and `indptr` (a bit like a sparse matrix without the indices array). I also have a Numba-jitted function that takes the two arrays, a function, and that function’s arguments (the arguments part is what I’m trying to remove), and applies the function to every “element” (group) of the structure. Here’s a snippet:
```python
import numpy as np
from numba import njit

@njit
def apply_func_to_groups(data, indptr, func, *args):
    ngroups = len(indptr) - 1
    result = np.empty_like(data)
    for i in range(ngroups):
        # group i occupies data[indptr[i]:indptr[i + 1]]
        group_slice = slice(indptr[i], indptr[i + 1])
        result[group_slice] = func(data[group_slice], *args)
    return result
```
The drawback is that this gets compiled separately for functions that take no extra arguments, one argument, two arguments, and so on. I’d like to be able to pass a single function with its arguments already bound (like what `functools.partial` does). Is there a way to achieve this?
Hi @jose-moralez, I think it’s an interesting question, and I can probably help, since I’ve spent some time looking at compilation times.
The first thing I would like to mention is that avoiding recompilation is harder than it sounds, so it’s important to consider whether you really need it. I measured the compilation time on my computer, and it was 0.13 seconds. Is this really a problem?
If you really want to go down the route of avoiding recompilation, then looking at the function I can tell you that the biggest issue is not `args` but `func`. Yes, `args` triggers a recompilation for each number of arguments, but `func` triggers a recompilation for each different function. Even if you used some kind of partial (and there are some ways to achieve it), every different set of `args` would produce a different partial; each of those has to be compiled separately, and, because they are different functions, each one forces `apply_func_to_groups` to be recompiled.
I’ll show this with an example:

```python
from numba import njit

@njit
def foo(x):
    return x

foo(1)  # first compilation
foo(2)  # no compilation
apply_func_to_groups(...., foo, 1)  # compiles apply_func_to_groups
apply_func_to_groups(...., foo, 2)  # no compilation

# using partial
@njit
def foo_1():
    return foo(1)

@njit
def foo_2():
    return foo(2)

foo_1()  # compiles foo and foo_1
foo_2()  # compiles foo_2
apply_func_to_groups(...., foo_1)  # compiles apply_func_to_groups
apply_func_to_groups(...., foo_2)  # compiles apply_func_to_groups again!
```
The use of partial has increased the number of compilations from 2 to 5.
So, in summary, avoiding recompilation is hard and requires understanding Numba’s type system in detail. Sometimes it is possible and sometimes it is not (I learned all this working on an application, and a year later I still cannot avoid recompilation). Even when it is possible, you might have to make other compromises, so the solution will have to be tailored to your specific problem and will have to make assumptions about how many different `func` you expect to have, how many different `args` you expect to have, etc.
Hope this helps!
Luk
Thank you @luk-f-a! You’re right, the problem is having many different functions, not the arguments. Each function takes around 300 ms to compile by itself, and plugging it into `apply_func_to_groups` takes another 400 ms. I was seeing about 5 seconds of compilation and thought it was caused by the arguments, but actually each call to `apply_func_to_groups` with a different function takes around 700 ms to compile, so using 6 different functions already adds up to about 4 seconds.
Thanks for making this clear, that definitely helped me a lot!