How can fused multiply-add (fma) be enforced in jit-compiled code?

This question is about Fused Multiply-Add (fma). It is related to Numba issue #10071 but has a slightly different scope.
Rather than asking for support of Python's relatively new math.fma (added in Python 3.13), which computes a * b + c as a fused operation with only a single rounding, I want to ask whether it is possible to force certain parts of jit-compiled code to use fma instead of leaving that decision to the compiler. Many modern CPUs support fma instructions natively, so math.fma support would not even be necessary for this purpose.

The LLVM fastmath flags include the contract flag (AllowContract), which allows LLVM to fuse operations into fma, but I find the results rather inconsistent when looking at the assembly code. Depending on where a particular expression is located, it is only sometimes compiled to fma instructions, probably because of compiler optimization heuristics.

This makes perfect sense from a performance perspective. However, some algorithms, like floating-point error-compensated dot products, explicitly require fma for the correctness of the result rather than for performance. Relying on the fastmath flags is then unreliable because they do not guarantee that fma is used.
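To illustrate why fusion is a correctness issue here, a minimal pure-Python sketch of a compensated dot product (the Dot2 algorithm of Ogita, Rump and Oishi). The helper fma_emulated is hypothetical: it stands in for a true FMA by using exact fractions.Fraction arithmetic followed by a single rounding, so the sketch runs on any Python version. In jit-compiled code it would be replaced by a genuinely fused operation.

```python
from fractions import Fraction


def fma_emulated(a: float, b: float, c: float) -> float:
    # Exact rational a * b + c, rounded to float once (what a fused op does).
    return float(Fraction(a) * Fraction(b) + Fraction(c))


def two_sum(a, b):
    # Knuth's TwoSum: s + e == a + b exactly.
    s = a + b
    z = s - a
    e = (a - (s - z)) + (b - z)
    return s, e


def two_prod(a, b):
    # p + e == a * b exactly, but only because the FMA rounds once.
    p = a * b
    return p, fma_emulated(a, b, -p)


def dot2(x, y):
    # Carry the rounding errors of every product and sum in s.
    p, s = two_prod(x[0], y[0])
    for xi, yi in zip(x[1:], y[1:]):
        h, r = two_prod(xi, yi)
        p, q = two_sum(p, h)
        s += q + r
    return p + s


a = 1.0 + 2.0**-30
p = a * a                       # rounded to 1 + 2**-29
unfused_err = a * a - p         # 0.0: the product is rounded before subtracting
fused_err = fma_emulated(a, a, -p)  # equals 2**-60, the true rounding error

x, y = [a, -1.0], [a, p]
naive = sum(xi * yi for xi, yi in zip(x, y))  # 0.0
compensated = dot2(x, y)                      # equals 2**-60, the exact result
```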

What I attempted

I tried to implement my own jit-compatible fma64 function based on the Numba documentation on how to implement intrinsics (see code below). Getting to this implementation was more of an AI- and search-engine-assisted trial-and-error process, but it seems to work reliably, as shown by the code example in a comment on the aforementioned GitHub issue.

Is this a reasonable approach, and is my implementation in line with good Numba practices? I would highly appreciate feedback because I'm not very deep into LLVM and intrinsics. If not, what other ways are there to enforce the use of fma?

"""
This module implements a ``numba``-compatible fused multiply-add (FMA) operation that
does not rely on LLVM deciding on its own whether to fuse operations under the
respective ``fastmath`` flags, but instead directly invokes the LLVM intrinsic
function ``llvm.fma.f64``.

"""

# --- Imports ---

from llvmlite import ir
from numba import njit
from numba import types as numba_types
from numba.extending import intrinsic

# --- Functions ---


@intrinsic
def fma64(typingctx, a, b, c):
    if a == b == c != numba_types.float64:
        raise TypeError("fma64 expects three float64s")

    sig = numba_types.float64(a, b, c)

    def codegen(context, builder, signature, args):
        mod = builder.module
        try:
            fn = mod.get_global("llvm.fma.f64")
        except KeyError:
            fty = ir.FunctionType(
                return_type=ir.DoubleType(),
                args=(
                    ir.DoubleType(),
                    ir.DoubleType(),
                    ir.DoubleType(),
                ),
            )
            fn = ir.Function(mod, fty, name="llvm.fma.f64")

        return builder.call(fn, args)

    return sig, codegen

@njit(
    numba_types.float64(
        numba_types.float64,
        numba_types.float64,
        numba_types.float64,
    ),
    inline="always",
)
def fma(
    a: float,
    b: float,
    c: float,
) -> float:
    """
    Computes ``a * b + c`` using a fused multiply-add (FMA) operation.

    """

    return fma64(a, b, c)  # type: ignore


# --- Main ---

if __name__ == "__main__":

    a = 1.5
    b = 2.5
    c = 3.5

    print(fma(a, b, c))  # Expected output: 7.25

    assert any(
        ["fmadd" in value for value in fma.inspect_asm().values()]  # type: ignore
    ), "FMA not found in assembly!"

The only issue I see is that your type checking is a bit off: the condition a == b == c != numba_types.float64 only raises when all three types are identical and not float64. For instance, (int, int, float) and (float, float, int) would not trigger the error in the first line. It might save you some headaches in the long term to just put in a cast to float64 for each argument. Sorry, it's been a while since I wrote an intrinsic, so I don't have an example on hand.

@DannyWeitekamp Thanks for the feedback!
Do I understand correctly that you would basically recommend doing the following?

import numpy as np
from numba import njit

@njit(
    inline="always",
)
def fma(
    a: float,
    b: float,
    c: float,
) -> float:
    """
    Computes ``a * b + c`` using a fused multiply-add (FMA) operation.

    """

    # Cast every argument to float64 before handing it to the intrinsic.
    return fma64(np.float64(a), np.float64(b), np.float64(c))  # type: ignore

Will this have performance implications when the input is already float64?

Just to follow up on the casting suggestion by @DannyWeitekamp: you can also incorporate the cast into your intrinsic as a sitofp instruction (which is what float64(…) lowers to for signed integer arguments):

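from llvmlite import ir
from numba import types as numba_types
from numba.core.cgutils import get_or_insert_function
from numba.extending import intrinsic

@intrinsic
def fma64(typingctx, a_ty, b_ty, c_ty):
    def codegen(context, builder, signature, args):
        fty = ir.FunctionType(return_type=ir.DoubleType(), args=[ir.DoubleType() for _ in range(3)])
        fn = get_or_insert_function(builder.module, fty, "llvm.fma.f64")
        # sitofp converts a *signed integer* to double; it is only valid
        # for signed integer inputs, not for arguments that are already floats.
        args = [builder.sitofp(arg, ir.DoubleType()) for arg in args]
        return builder.call(fn, args)
    return numba_types.float64(a_ty, b_ty, c_ty), codegen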

This is just for general reference, probably not generic enough for your use case.

@milton has added exactly what I was suggesting. As for a potential performance hit, there almost certainly would not be any drop if the incoming type is already float64; the compiler should optimize the cast away as a no-op. In any case, in my experience casting numerical types tends to have negligible overhead, but as always, if you want to know exactly how much overhead there is, you can run some speed tests to profile it.

@milton @DannyWeitekamp
Thank you so much for the replies! I will try it out.