I’m using Numba in a project to produce compiled functions at runtime. The functions in question are determined by analysing a user-defined directed acyclic graph of mathematical expressions. As such, the number of input arguments and the number of lines of code in the body of the function are variable and may be up to O(10³) and O(10⁵) respectively.
What I’m finding is that for functions with large bodies and large numbers of arguments, the first call to the Numba-jitted function (which triggers the compilation) becomes a serious bottleneck.
The code snippet below is a contrived example but gives the gist of what I’m doing in my project and illustrates the rapidly increasing compilation times. The table below contains some timings I ran locally. Once the number of lines in the function body reaches O(10³) or more, Numba’s compilation time goes from being merely the bottleneck to being prohibitively expensive. For reference, the Numba-jitted callables execute ~100 times faster than a pure Python/NumPy equivalent.
| NUM_ARGS | NUM_TANH | Compilation time / s |
|---|---|---|
| 10² | 10² | 0.56 |
| 10² | 10³ | 4.9 |
| 10² | 10⁴ | 360 |
| 10³ | 10² | 2.9 |
| 10³ | 10³ | 8.2 |
| 10³ | 10⁴ | 450 |
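Roughly speaking, the pure Python/NumPy baseline behind the ~100× figure is equivalent to the following (a simplified sketch for scalar inputs; `plain_function` is an illustrative name, not the exact code I benchmarked):

```python
import numpy as np

NUM_TANH = 1000  # match the generator's setting below

def plain_function(*xs):
    """Interpreted equivalent of the generated function body."""
    w = 1.0
    for x in xs:               # w0 = x0 * x1 * ... * x{N-1}
        w *= x
    for _ in range(NUM_TANH):  # w{i+1} = np.tanh(w{i})
        w = np.tanh(w)
    return w
```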
Does anyone have any tips or tricks for speeding up the compilation stage for large Numba-jitted functions? Have I made a rookie mistake here, or am I going about this the wrong way?
I know that using Numba in this fashion on very large mathematical expressions might not be its primary intended use case, but this approach still massively outperforms other tools such as Theano’s C code generators, so I’m curious whether the Numba-based approach can be improved further.
Note: I know the types of all of the input arguments; they will either be `double`s or 1-dimensional `np.array`s of `dtype=np.float64`.
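Since the types are known up front, I could presumably pass an explicit signature to `nb.njit` so that compilation happens eagerly at decoration time rather than on the first call. A minimal sketch of what I mean (this only moves the cost rather than reducing it; all-scalar arguments and the name `eager_njit` are just for illustration):

```python
import numba as nb

NUM_ARGS = 1000  # match the generator's setting below

# Passing a full signature makes nb.njit compile at decoration time
# instead of lazily on the first call. Scalars map to nb.float64; the
# 1-dimensional float64 arrays would be nb.float64[:] instead.
signature = nb.float64(*([nb.float64] * NUM_ARGS))
eager_njit = nb.njit(signature)  # would be used in place of the bare @nb.njit
```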
```python
import timeit

import numba as nb
import numpy as np

NUM_PROFILE_CALLS = 10000
NUM_PROFILE_REPEATS = 5
NUM_TANH = 1000 # adjust this to see how compilation time performance scales
NUM_ARGS = 1000 # adjust this to see how compilation time performance scales

def numbafy():
    """Produce a Numba-jitted function dynamically from a string.

    Allows me to mess around with different numbers of args and lines of
    expressions within the jitted function to see how compilation time and
    execution time are affected.
    """
    function_name = "numbafied_function"
    # make a str like "x0, x1, x2, ..."
    arg_syms = [f"x{i}" for i in range(NUM_ARGS)]
    args = ", ".join(arg_syms)
    # make some arbitrary intermediate calculations which use all args for
    # demonstration purposes; each line feeds the previous result into
    # np.tanh, i.e. w1 = np.tanh(w0), w2 = np.tanh(w1), ...
    intermediate_syms = [f"w{i}" for i in range(NUM_TANH + 1)]
    w0 = "*".join(arg_syms)
    intermediate_exprs = [f"{wj} = np.tanh({wi})"
                          for wi, wj in zip(intermediate_syms[:-1],
                                            intermediate_syms[1:])]
    intermediates = "\n    ".join(intermediate_exprs)
    return_val = intermediate_syms[-1]
    # put it all together as a string and exec() it into a fresh namespace;
    # relying on exec() writing into locals() inside a function is fragile,
    # so pass in the modules the generated source needs explicitly
    func_str = (
        f"@nb.njit\n"
        f"def {function_name}({args}):\n"
        f"    w0 = {w0}\n"
        f"    {intermediates}\n"
        f"    return {return_val}\n")
    namespace = {"nb": nb, "np": np}
    exec(func_str, namespace)
    compiled_func = namespace[function_name]
    return compiled_func

# time how long it takes to "numbafy" the function
numbafy_start_time = timeit.default_timer()
numbafied_function = numbafy()
numbafy_stop_time = timeit.default_timer()
print(f"Numbafy time: {numbafy_stop_time - numbafy_start_time}")
# generate some random test data
test_data = np.random.random(NUM_ARGS).tolist()
# time the first call to the numbafied function, which includes JIT compilation
compilation_start_time = timeit.default_timer()
_ = numbafied_function(*test_data)
compilation_stop_time = timeit.default_timer()
print(f"Compilation time: {compilation_stop_time - compilation_start_time}")
# profile the per-call execution time of the compiled function
profile_output = timeit.repeat(lambda: numbafied_function(*test_data),
                               number=NUM_PROFILE_CALLS,
                               repeat=NUM_PROFILE_REPEATS)
mean_execution_time = np.mean(profile_output) / NUM_PROFILE_CALLS
print(f"Execution time: {mean_execution_time}")
```