Hi @Hvass-Labs
Let’s consider the question a bit more broadly than my particular use-case: Would it be a good idea to be able to change @jit arguments when recompiling a jitted function?
First, some general remarks RE changing compilation flags. I don’t think that this can be considered safe in general, though it might be possible to write something in a very careful and restricted way that permits this in some circumstances. An example of something unsafe happening might be a library with a function that is sensitive to order of operations or NaN handling; it would potentially be invalid to recompile this with fastmath=True, as it would break the original intent of the library.
Concrete example of this:
from numba import njit
import numpy as np

@njit(fastmath=False)
def nan_sensitive_foo(x):
    count = 0
    for item in x:
        if not np.isnan(item):
            count += 1
    return count

x = np.array([1, 2, np.nan, 4, 5])
print(f"Found {nan_sensitive_foo(x)} values that are not NaN")

# Recompile with fastmath
fastmath_nan_sensitive_foo = njit(fastmath=True)(nan_sensitive_foo.py_func)
print(f"Found {fastmath_nan_sensitive_foo(x)} values that are not NaN")
Whilst harder to demonstrate, similar issues potentially exist with parallel=True, particularly with regard to inadvertently nested parallel regions, which could result in poor performance. An example of triggering this might be:
from numba import njit, prange, get_num_threads
import numpy as np

def gen(do_par=None):
    @njit(parallel=do_par)
    def some_calc(n):
        # Array operations will be parallelized if parallel=True
        a = np.ones(n)
        b = a * 4
        return b.sum()

    @njit(parallel=True)
    def foo(n):
        acc = 0
        # explicitly parallel loop
        for i in prange(n):
            acc += some_calc(n)
        return acc

    return foo, some_calc

n = 5

# foo_false has a parallel prange calling a serial some_calc
foo_false, some_calc_false = gen(False)
foo_false(n)

print('-' * 80)

# foo_true has a parallel prange calling a parallel some_calc
# prange threads will potentially launch more threads
foo_true, some_calc_true = gen(True)
foo_true(n)

# Show the parallel loops
some_calc_true.parallel_diagnostics(level=2)
foo_true.parallel_diagnostics(level=2)
RE: .recompile() etc. The Numba dispatcher object (the thing returned by @jit) is configured directly from the @jit decorator options; the configuration is not intended to be changed once the dispatcher object is constructed. The recompile() method exists to handle cases such as a change to a global variable that a @jit function refers to (recall that Numba considers global variables as compile time constants, so if they change, the function needs to be recompiled). Docs: Frequently Asked Questions — Numba 0+untagged.4124.gd4460fe.dirty documentation
If you really want the behavior suggested, Numba has jit_module (docs: Automatic module jitting with jit_module — Numba 0+untagged.4124.gd4460fe.dirty documentation). It might be possible to make use of this or some derivative of this to JIT compile an entire module with a set of flags that sometimes change. However, I suspect that this may well be quite brittle and create hard to debug situations.
Hope this helps.