After reading @stuartarchibald's replies and going over the low-level code a bit with @gmarkall, it seems like we could match NumPy's performance for ufunc.reduce-like operations within Numba if we add some user-configurable flags for Numba's optimization passes (e.g. the lines that @stuartarchibald changed).
Just to clarify and recap the entire situation: when Aesara is asked to convert a graph representing np.max(x, axis=1) to a Numba-njit-ed function, it does so piece by piece. That starts with a scalar max function, for which we can generate a custom vectorized function like the following (well, at least after numba/numba Pull Request #7119, "Add support for `np.broadcast_to`" by guilhermeleobas, goes through):
import numpy as np
import numba


@numba.njit
def vectorized_max(x, y, out=None):
    if out is None:
        out = np.empty((x.shape[0],), dtype=np.float64)

    for i in range(out.shape[0]):
        if x[i] > y[i]:
            out[i] = x[i]
        else:
            out[i] = y[i]

    return out
With this function, Aesara can then implement the axis=1 part using something like the following:
@numba.njit
def max_reduce_axis_1(x):
    x_transpose = np.transpose(x)
    res = np.full((x.shape[0]), -np.inf, dtype=np.float64)

    for m in range(x.shape[1]):
        vectorized_max(res, x_transpose[m], res)

    return res
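As a sanity check on the decomposition itself (independent of Numba), here is a minimal plain-NumPy sketch of the same column-by-column reduction. The name `max_reduce_axis_1_reference` is made up for illustration, and it uses `np.maximum` with `out=` as a stand-in for the generated `vectorized_max`; it only shows what result the njit-ed version is expected to produce, not how Aesara generates it.

```python
import numpy as np


def max_reduce_axis_1_reference(x):
    # Start from the identity element for max, one slot per row of x.
    res = np.full((x.shape[0],), -np.inf, dtype=np.float64)
    # Fold each column into the running result in place,
    # mirroring vectorized_max(res, x_transpose[m], res).
    for m in range(x.shape[1]):
        np.maximum(res, x[:, m], out=res)
    return res


x = np.array([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]])
result = max_reduce_axis_1_reference(x)
```

The result should agree with both `np.max(x, axis=1)` and `np.maximum.reduce(x, axis=1)`, which is the behavior the Numba version is meant to reproduce.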
It seems like the resulting max_reduce_axis_1 can currently be optimized to the same degree as @stuartarchibald’s all-in-one example. Apparently, all that’s left is to get the extra SIMD-related optimizations that were enabled by setting loop_vectorize=True, slp_vectorize=True, and possibly opt=3 in the “cheap” optimization pass.
Again, a quick fix might involve additional user-configurable options that allow the “cheap” pass to be adjusted. However, since we definitely don’t want to increase the overall compile time when the Numba backend is used (especially given the plan to make it the default backend), it would be best if these options could be enabled only for the compilation of these specific ufunc.reduce-like functions.
I’m not sure if that’s possible via numba.config options (e.g. if we temporarily set those options, force immediate njit compilation of the function, and then unset the options), but, if it is, I would be willing to put in a PR for this.
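To make the "temporarily set, compile, unset" idea concrete, here is a rough sketch of the pattern using a stand-in object rather than numba.config itself. Everything here is hypothetical: `temporary_options`, `fake_config`, and `force_compile` are names made up for illustration, and the option names are placeholders. Whether any numba.config option actually reaches the “cheap” pass is exactly the open question above.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def temporary_options(config, **overrides):
    # Remember the current values so they can be restored afterwards.
    saved = {name: getattr(config, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(config, name, value)
        yield config
    finally:
        # Restore the originals even if compilation raises.
        for name, value in saved.items():
            setattr(config, name, value)


# Stand-in for a global config object (assumption: NOT numba.config).
fake_config = SimpleNamespace(OPT=2, LOOP_VECTORIZE=0, SLP_VECTORIZE=0)
seen_during_compile = {}


def force_compile():
    # Placeholder for forcing immediate compilation of the njit-ed function;
    # here it only records which options were active at that moment.
    seen_during_compile.update(vars(fake_config))


with temporary_options(fake_config, OPT=3, LOOP_VECTORIZE=1, SLP_VECTORIZE=1):
    force_compile()
```

The point of the `try`/`finally` is that the more aggressive settings apply only to the one forced compilation and are rolled back afterwards, so the compile time of everything else is unaffected.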