Tips or tricks for speeding up compilation time on first call of large Numba-jitted NumPy-containing functions?

I’m using Numba in a project to produce compiled functions at runtime. The functions in question are determined based on analysis of a user-defined directed acyclic graph of mathematical expressions. As such, the number of input arguments and the number of lines of code in the function body are variable and may be up to O(10³) and O(10⁵) respectively.

What I’m finding is that for functions with large bodies and numbers of arguments, the time taken when calling the Numba-jitted function for the first time is causing a bottleneck.

The code snippet below is a contrived example but gives the gist of what I’m doing in my project and illustrates the rapidly increasing compilation times. The table below contains some timings I ran locally. As the number of lines in the function body reaches O(10³) or more, Numba’s compilation time goes from being the bottleneck to being prohibitively expensive. For reference, the Numba-jitted callables execute ~100 times faster than a pure Python/NumPy function.

NUM_ARGS  NUM_TANH  Compilation time / s
10²       10²       0.56
10²       10³       4.9
10²       10⁴       360
10³       10²       2.9
10³       10³       8.2
10³       10⁴       450

Does anyone have any tips or tricks for speeding up the compilation stage for large Numba-jitted functions? Have I made a rookie mistake here, or am I going about this the wrong way?

I know using Numba in this fashion on very large mathematical expressions might not be its primary intended use case, but this approach still massively outperforms using other tools such as Theano’s C code generators so I’m curious if the Numba-based approach can be improved further.

Note: I know the types of all of the input arguments, they will either be doubles or 1-dimensional np.arrays of dtype=np.float64.

import timeit

import numba as nb
import numpy as np

NUM_PROFILE_CALLS = 10000
NUM_PROFILE_REPEATS = 5
NUM_TANH = 1000 # adjust this to see how compilation time performance scales
NUM_ARGS = 1000 # adjust this to see how compilation time performance scales

def numbafy():
	"""Produce a numba-jitted function dynamically from a string.

	Allows me to mess around with different numbers of args and lines of 
	expressions within the jitted function to see how compilation time and 
	execution time are affected.
	"""

	function_name = "numbafied_function"

	# make a str like "x0, x1, x2, ..."
	arg_syms = [f"x{i}" for i in range(NUM_ARGS)]
	args = ", ".join(arg_syms)

	# make some arbitrary intermediate calculations which use all args for 
	# demonstration purposes
	intermediate_syms = [f"w{i}" for i in range(NUM_TANH+1)]
	w0 = "*".join(arg_syms)
	intermediate_exprs = [f"{wj} = np.tanh({wi})" 
		for wi, wj in zip(intermediate_syms[:-1], intermediate_syms[1:])]
	intermediates = "\n    ".join(intermediate_exprs)

	return_val = str(intermediate_syms[-1])

	# put it all together as a string and exec()
	func_str = (
		f"@nb.njit\n"
		f"def {function_name}({args}):\n"
		f"    w0 = {w0}\n"
		f"    {intermediates}\n"
		f"    return {return_val}\n")
	namespace = {"nb": nb, "np": np}
	exec(func_str, namespace)
	compiled_func = namespace[function_name]

	return compiled_func

# time how long it takes to "numbafy" the function
numbafy_start_time = timeit.default_timer()
numbafied_function = numbafy()
numbafy_stop_time = timeit.default_timer()
print(f"Numbafy time: {numbafy_stop_time - numbafy_start_time}")

# generate some random test data
test_data = np.random.random(NUM_ARGS).tolist()

# time first call to numbafied function to time compilation time
compilation_start_time = timeit.default_timer()
_ = numbafied_function(*test_data)
compilation_stop_time = timeit.default_timer()
print(f"Compilation time: {compilation_stop_time - compilation_start_time}")

# profile execution time
profile_output = timeit.repeat(lambda: numbafied_function(*test_data), 
	number=NUM_PROFILE_CALLS, repeat=NUM_PROFILE_REPEATS)
mean_execution_time = (np.mean(profile_output) / (NUM_PROFILE_CALLS))
print(f"Execution time: {mean_execution_time}")

hi @brocksam, I also use exec to write functions like these and have seen this behaviour. Regarding the input variables, you don’t need to pass your variables individually as single arguments. You could put all of them into an array. That’d be far more efficient from every point of view.

Regarding the function body, writing individual lines works well until you have about 100 lines. After that, what I do lately is to write “in-text” loops, i.e. loops inside the generated source:

	func_str = (
		f"@nb.njit\n"
		f"def {function_name}(x):\n"
		f"    w = np.prod(x)\n"
		f"    for _ in range({NUM_TANH+1}):\n"
		f"        w = np.tanh(w)\n"
		f"    return w\n")

Results

Numbafy time: 0.0005314410082064569
Compilation time: 0.06545790802920237
Execution time: 5.542060061125085e-06

I also get carried away by the text approach, and need to remind myself to go back to basics.

Hope this helps!


Thank you @luk-f-a, that’s really helpful input! I had been using individual arguments rather than a single array as you suggest because of Numba’s requirement that passed lists must be homogeneous. Thanks to your suggestion, though, I’ve realised I can partition my arguments into a much smaller number of homogeneous iterables to massively reduce the number of arguments passed, and then just index into the arguments rather than unpacking them to individual variables.

I perhaps made my contrived example a bit too simplified. In actuality the function body will contain statements with no real structure. There will be a variety of different mathematical operations, with the operands being arbitrary combinations of both the function arguments and previously computed intermediate variables - so something like:

func_str = (
	f"@nb.njit\n"
	f"def {function_name}(x):\n"
	f"    w0 = x[0] * x[1]\n"
	f"    w1 = np.sin(x[2])\n"
	...
	f"    w174 = x[54] + w111 + w150\n"
	f"    w175 = np.exp(w174)\n"
	...
	f"    return np.array([w1744, w1745, w1746])\n")

In this case I can’t see how your suggestion of “in-text” loops could easily be applied. Or am I missing something here?!

Do you know what it is that causes the drop in compilation-time performance when the function body grows beyond ~100 lines? Compilation time seems to be pretty constant up to ~100 lines but then begins to increase steeply beyond that. Given that I know the types of the function arguments as well as the types of each intermediate calculation in the function body, I wonder if there is an optimisation to drastically improve the compilation time.

No, I don’t know. Something in the analysis Numba does evidently doesn’t scale well with the number of lines.

I can think of a trick: split your long function into shorter ones. You are in full control of the number of lines and their content, so it should be possible to break after 100 lines, figure out (symbolically) which intermediates are still live (something like return w3, w174, x), and then create the signature of the next function to match that output, i.e. def f2(w3, w174, x):.
After that, calculate the full result as f10(f9(...f2(f1(x))...)), for a total of 1000 lines split into 10 functions of 100 lines each.