@sklam, @stuartarchibald, @gmarkall, per our conversation today, here’s an updated script with all the Numba implementations we’ve tried for `np.maximum.reduce(..., axis=1)`:
```python
from contextlib import contextmanager

import numba
import numpy as np

X = np.random.normal(size=(5000, 6000))
#
# This is what we're trying to implement in Numba:
#
numpy_res = np.max(X, axis=1)
# This is what we're trying to match/beat:
# %timeit np.max(X, axis=1)
# 15.6 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
#
# NOTE: We do *not* want to enable `parallel=True`, `fastmath=True`, or any
# other options that are neither generalizable nor used by NumPy.
#


@numba.njit(inline="always")
def custom_max(x, y):
    if x > y:
        return x
    else:
        return y


#
# Case 1: Using `vectorize` along the reduced dimension
#
@numba.vectorize(["float64(float64, float64)"])
def vectorized_max(x, y):
    return custom_max(x, y)


@numba.njit
def vectorized_max_reduce_axis_1(x):
    res = np.full((x.shape[0],), -np.inf, dtype=x.dtype)
    x_transpose = np.transpose(x)
    for m in range(x.shape[1]):
        # Update `res` in place with the element-wise max against column `m`
        vectorized_max(res, x_transpose[m], res)
    return res


# Confirm that it works
assert np.array_equal(numpy_res, vectorized_max_reduce_axis_1(X))
# %timeit vectorized_max_reduce_axis_1(X)
# 94.2 ms ± 435 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


#
# Case 2: Manually written Numba reduction loops
#
@numba.njit
def manual_max_reduce_axis_1(x):
    res = np.full((x.shape[0],), -np.inf, dtype=x.dtype)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            res[i] = custom_max(res[i], x[i, j])
    return res


assert np.array_equal(numpy_res, manual_max_reduce_axis_1(X))
# %timeit manual_max_reduce_axis_1(X)
# 38.3 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


#
# Case 3: Calls to NumPy's `ufunc.reduce` via `objmode`
#
@numba.njit
def objmode_max_reduce_axis_1(x):
    # Drop back to the Python interpreter to call the ufunc's `reduce` method
    with numba.objmode(res="float64[:]"):
        res = vectorized_max.reduce(x, axis=1)
    return res


assert np.array_equal(numpy_res, objmode_max_reduce_axis_1(X))
# %timeit objmode_max_reduce_axis_1(X)
# 73.7 ms ± 262 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


#
# Case 4: Recompile Case 2 with more LLVM optimizations in the "cheap" pass
#
# This comes from the discussions in
# https://numba.discourse.group/t/numba-performance-doesnt-scale-as-well-as-numpy-in-vectorized-max-function/782
#
@contextmanager
def use_optimized_cheap_pass(*args, **kwargs):
    """Temporarily replace the cheap optimization pass with a better one."""
    # Relies on private Numba attributes (`_internal_codegen`, `_mpm_cheap`)
    from numba.core.registry import cpu_target

    context = cpu_target.target_context._internal_codegen
    old_pm = context._mpm_cheap
    new_pm = context._module_pass_manager(
        loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"
    )
    context._mpm_cheap = new_pm
    try:
        yield
    finally:
        context._mpm_cheap = old_pm


with use_optimized_cheap_pass():

    @numba.njit
    def opt_manual_max_reduce_axis_1(x):
        res = np.full((x.shape[0],), -np.inf, dtype=x.dtype)
        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                res[i] = custom_max(res[i], x[i, j])
        return res

    # `njit` without signatures compiles lazily on the first call, so make
    # that call inside the `with` block; otherwise the patched pass is unused.
    assert np.array_equal(numpy_res, opt_manual_max_reduce_axis_1(X))

# %timeit opt_manual_max_reduce_axis_1(X)
# 37.6 ms ± 872 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#
# Case 5: This is Stuart's original form of the reduction
#
# Apparently, removing the use of `custom_max` and manually inlining the
# `max` operation makes a large difference; however, we can't reasonably do
# this for every binary function that will perform a reduction.
#
with use_optimized_cheap_pass():

    @numba.njit
    def orig_opt_max_reduce_axis_1(x):
        res = np.full((x.shape[0],), -np.inf, dtype=x.dtype)
        for i in range(x.shape[0]):
            for j in range(x.shape[1]):
                tmp = x[i, j]
                if res[i] < tmp:
                    res[i] = tmp
        return res

    # First call (and therefore compilation) happens inside the `with` block
    assert np.array_equal(numpy_res, orig_opt_max_reduce_axis_1(X))

# %timeit orig_opt_max_reduce_axis_1(X)
# 20.3 ms ± 233 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
It looks like the “cheap” pass optimization hack only helps when the binary function is manually inlined. However, we need to generalize these reduction loops so that they call nearly arbitrary binary functions (e.g. `custom_max` in the example above), so we can’t reasonably inline the binary operation by hand as in @stuartarchibald’s example. The performance of that example is acceptable for our use case, though, so if we could get `res[i] = custom_max(res[i], x[i, j])` to be treated more or less the same as

```python
tmp = x[i, j]
if res[i] < tmp:
    res[i] = tmp
```

we might be fine.