This question is about Fused Multiply-Add (fma). It is related to Numba-Issue 10071, but has a slightly different scope.
Rather than asking for support of Python’s relatively new math.fma (added in Python 3.13), which computes a * b + c as a fused operation with only a single rounding, I want to ask whether it is possible to force certain parts of jit-compiled code to use fma instead of letting the compiler decide. Many modern CPUs support fma instructions natively, so math.fma support would not even be necessary for this purpose.
The LLVM fastmath flags include the AllowContract flag, which allows LLVM to use fma, but the results look rather inconsistent when I inspect the assembly code. Depending on where a particular expression is located, it is only sometimes compiled to fma instructions, probably based on some compiler optimization heuristics.
This makes perfect sense from a performance perspective. However, some algorithms, like floating-point error-compensated dot products, explicitly require fma for the correctness of the result rather than for performance. Relying on the fastmath flags is then not an option, because they only permit fusion and do not guarantee it.
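To make the correctness argument concrete, here is a minimal pure-Python sketch of the TwoProduct building block used in compensated dot products. The names fma_reference and two_product are hypothetical helpers for illustration only; fma_reference emulates a single-rounding fma for finite inputs with exact rational arithmetic, so it is a slow reference and not something to use in jit-compiled code.

```python
from fractions import Fraction


def fma_reference(a: float, b: float, c: float) -> float:
    """Emulate a fused multiply-add for finite floats (hypothetical helper).

    The product and sum are computed exactly as rationals, and the result is
    rounded to a float only once, which is what a hardware fma does.
    """
    return float(Fraction(a) * Fraction(b) + Fraction(c))


def two_product(a: float, b: float, fma=fma_reference):
    """Split a * b into a rounded product p and its rounding error e."""
    p = a * b            # rounded product
    e = fma(a, b, -p)    # error of that rounding, recovered via the fused operation
    return p, e


a = 1.0 + 2.0**-30
b = 1.0 - 2.0**-30

p, e = two_product(a, b)

# Without fusion the product is rounded before the subtraction,
# so the error term is lost entirely.
unfused = a * b - p

print(p, e, unfused)
```

With these inputs the exact product is 1 - 2**-60, which rounds to 1.0. The fused version recovers the error term -2**-60, while the unfused expression returns 0.0. If the compiler silently picks the unfused form, the compensation step does nothing.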
What I attempted
I tried to implement my own jit-compatible fma64 function based on the Numba documentation on how to implement intrinsics (see code below). Getting to this implementation was more of an AI- and search-engine-assisted trial-and-error process, but it seems to work reliably, as shown by the code example in a comment on the aforementioned GitHub issue.
Is this a reasonable way to do it, and is my implementation aligned with good Numba practices? I would highly appreciate feedback because I am not very familiar with LLVM and intrinsics. If it is not, what are other ways to enforce the use of fma?
"""
This module implements a ``numba``-compatible fused multiply-add (FMA) operation that
does not solely rely on LLVM considering fusing operations on its own given the
respective ``fastmath`` flags, but directly invokes the LLVM intrinsic function
``llvm.fma.f64``.
"""
# --- Imports ---
from llvmlite import ir
from numba import njit
from numba import types as numba_types
from numba.extending import intrinsic
# --- Functions ---
@intrinsic
def fma64(typingctx, a, b, c):
    # Typing phase: reject anything other than three float64 arguments.
    if not (a == b == c == numba_types.float64):
        raise TypeError("fma64 expects three float64s")

    sig = numba_types.float64(a, b, c)

    def codegen(context, builder, signature, args):
        mod = builder.module
        # Reuse the declaration of the LLVM intrinsic if the module already has it,
        # otherwise declare it as double(double, double, double).
        try:
            fn = mod.get_global("llvm.fma.f64")
        except KeyError:
            fty = ir.FunctionType(
                return_type=ir.DoubleType(),
                args=(
                    ir.DoubleType(),
                    ir.DoubleType(),
                    ir.DoubleType(),
                ),
            )
            fn = ir.Function(mod, fty, name="llvm.fma.f64")

        return builder.call(fn, args)

    return sig, codegen


@njit(
    numba_types.float64(
        numba_types.float64,
        numba_types.float64,
        numba_types.float64,
    ),
    inline="always",
)
def fma(
    a: float,
    b: float,
    c: float,
) -> float:
    """
    Computes ``a * b + c`` using a fused multiply-add (FMA) operation.
    """
    return fma64(a, b, c)  # type: ignore
# --- Main ---
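if __name__ == "__main__":
    a = 1.5
    b = 2.5
    c = 3.5
    print(fma(a, b, c))  # Expected output: 7.25

    # Check the generated assembly for a fused instruction
    # (e.g. the vfmadd* family on x86-64, fmadd on AArch64).
    assert any(
        "fmadd" in value for value in fma.inspect_asm().values()  # type: ignore
    ), "FMA not found in assembly!"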