Jit recompile with new arguments

As usual, thanks very much for making Numba!

There is a recompile() function available on jitted functions. It would be great if it allowed the user to change the original parameters from when the function was first jitted.

In particular, I would like to be able to switch between parallel=True and parallel=False in functions that I import from external Python packages and that have been jitted one way or the other by default.

There is a hack for doing this, as explained in post #1125 (I cannot write links in my posts here). But it is quite hacky, so I think it would be much nicer to be able to change the arguments when calling recompile() on a jitted function.

Would this be possible to implement? Or is it too difficult and not worth the hassle?

@esc do you have an opinion on this? Thanks!

I’m not sure this is the best advice, but every jitted function has a py_func attribute. You can grab that to get the original Python function and then “re-jit” it in whichever way you want (the example is a bit contrived, but you get the idea):

In [1]: from numba import njit

In [2]: @njit
   ...: def fun(x):
   ...:     return x + 1
   ...:

In [3]: fun(1)
Out[3]: 2

In [5]: njf = njit(parallel=True)(fun.py_func)

In [6]: njf(2)
/Users/vhaenel/git/numba/numba/core/typed_passes.py:329: NumbaPerformanceWarning:
The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.

To find out why, try turning on parallel diagnostics, see https://numba.readthedocs.io/en/stable/user/parallel.html#diagnostics for help.

File "<ipython-input-2-28a5c020e544>", line 2:
@njit
def fun(x):
^

  warnings.warn(errors.NumbaPerformanceWarning(msg,
Out[6]: 3

I am so sorry, I have just seen that you had already discovered the py_func approach — apologies for the noise in this case.

Thanks anyway! 🙂

Let’s consider the question a bit more broadly than my particular use-case: Would it be a good idea to be able to change @jit arguments when recompiling a jitted function?

And maybe it shouldn’t be done in the recompile function itself, because it appears that function only becomes available after the underlying Python function has been jit-compiled. So maybe there could be a function like change_args that would change the jit arguments, and then the next call to recompile would use those arguments.

I think this would be very useful e.g. when importing jitted functions from an external Python package, and there is a need to change the jit arguments.

What do the Numba developers think?

Hi @Hvass-Labs

Let’s consider the question a bit more broadly than my particular use-case: Would it be a good idea to be able to change @jit arguments when recompiling a jitted function?

First, some general remarks RE changing compilation flags. I don’t think this can be considered safe in general, though it might be possible to write it in a very careful and restricted way that permits this in some circumstances. An example of something unsafe happening might be a library with a function that is sensitive to the order of operations or to NaN handling; it would potentially be invalid to recompile this with fastmath=True, as that would break the original intent of the library.

Concrete example of this:

from numba import njit
import numpy as np


@njit(fastmath=False)
def nan_sensitive_foo(x):
    # Counts the non-NaN values. fastmath=True lets the compiler assume
    # NaNs never occur, so the np.isnan check may be optimized away and
    # the count can silently change.
    count = 0
    for item in x:
        if not np.isnan(item):
            count += 1
    return count

x = np.array([1, 2, np.nan, 4, 5])

print(f"Found {nan_sensitive_foo(x)} values that are not NaN")

# Recompile with fastmath

fastmath_nan_sensitive_foo = njit(fastmath=True)(nan_sensitive_foo.py_func)
print(f"Found {fastmath_nan_sensitive_foo(x)} values that are not NaN")

Whilst harder to demonstrate, similar issues potentially exist with parallel=True, particularly with regard to inadvertently nested parallel regions, which can lead to poor performance. An example of triggering this might be:

from numba import njit, prange, get_num_threads
import numpy as np

def gen(do_par=None):

    @njit(parallel=do_par)
    def some_calc(n):
        # Array operations will be parallelized if parallel=True
        a = np.ones(n)
        b = a * 4
        return b.sum()

    @njit(parallel=True)
    def foo(n):
        acc = 0
        # explicitly parallel loop
        for i in prange(n):
            acc += some_calc(n)
        return acc

    return foo, some_calc

n = 5

# foo_false has a parallel prange calling a serial some_calc
foo_false, some_calc_false = gen(False)
foo_false(n)

print('-' * 80)
# foo_true has a parallel prange calling a parallel some_calc
# prange threads will potentially launch more threads
foo_true, some_calc_true = gen(True)
foo_true(n)

# Show the parallel loops
some_calc_true.parallel_diagnostics(level=2)
foo_true.parallel_diagnostics(level=2)

RE: .recompile() etc. The Numba dispatcher object (the thing returned by @jit) is configured directly from the @jit decorator options, and the configuration is not intended to be changed once the dispatcher object is constructed. The recompile() method exists to handle cases such as a global variable that a @jit function refers to changing (recall that Numba considers global variables to be compile-time constants, so if they change, the function needs to be recompiled). Docs: see the “Frequently Asked Questions” page of the Numba documentation.

If you really want the suggested behavior, Numba has jit_module (docs: “Automatic module jitting with jit_module” in the Numba documentation). It might be possible to make use of this, or some derivative of it, to JIT-compile an entire module with a set of flags that sometimes change. However, I suspect this may well be quite brittle and create hard-to-debug situations.

Hope this helps.

Thanks very much for the detailed explanation!

I can see this would open up a can of worms that probably wouldn’t be worth the hassle for such rare use-cases.

I think I forgot to mention one important reason why I would like to switch between the parallel and serial jit-versions: my particular problem needs to be quite large before the parallel version is faster, and for smaller problem sizes the serial version is much faster. But I will have to find another elegant solution for switching between these.

The solution I ended up using is written here.