How to set type (precision) of operations inside a function?

This function:

import numba as nb

@nb.njit('(f4)(f4)', fastmath=True, inline='always')
def f(a):
  return 42*a

uses float64 for the internal multiplication and only casts the result back to the requested float32:

def f(a):

  # --- LINE 8 --- 
  #   $const4.0.1 = const(int, 42)  :: Literal[int](42)
  #   $binop_mul8.2 = $const4.0.1 * a  :: float64
  #   del a
  #   del $const4.0.1
  #   $12return_value.3 = cast(value=$binop_mul8.2)  :: float32
  #   del $binop_mul8.2
  #   return $12return_value.3

  return 42*a

Is there a way to set all floating-point operations to 32-bit precision? Similarly, can integer operations be set to 32-bit precision?

Happy Easter @pauljurczak,

can you try defining the Literal[int](42) constant as float32, either via the locals dict or with an explicit cast?
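
For the locals route, I mean something like this (untested sketch, building on your f above):

@nb.njit('f4(f4)', locals={'c': nb.float32}, fastmath=True, inline='always')
def f(a):
  c = 42  # typed as float32 via the locals dict, so the multiply stays in float32
  return c*a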

This topic has come up more than once over the years. I don’t remember any great solutions beyond the suggestions to manage the types manually through casting.
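
For example (a sketch with a made-up function, not code from any of those threads), manual casting means wrapping every constant in nb.float32 so type inference never widens to float64:

import numba as nb

@nb.njit('f4(f4)', fastmath=True)
def smoothstep(a):
    t = nb.float32(3.0) - nb.float32(2.0)*a  # both constants kept at 32 bits
    return a*a*t                             # stays float32 end to end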

This option:

@nb.njit('(f4)(f4)', fastmath=True, inline='always')
def f(a):
  return nb.float32(42.0)*a

still has float64 lurking:

  # --- LINE 8 --- 
  #   $4load_global.0 = global(nb: <module 'numba' from '/home/paul/upwork/pickleball/code/.venv/lib/python3.12/site-packages/numba/__init__.py'>)  :: Module(<module 'numba' from '/home/paul/upwork/pickleball/code/.venv/lib/python3.12/site-packages/numba/__init__.py'>)
  #   $14load_attr.2 = getattr(value=$4load_global.0, attr=float32)  :: class(float32)
  #   del $4load_global.0
  #   $const34.3.1 = const(float, 42.0)  :: float64
  #   $36call.4 = call $14load_attr.2($const34.3.1, func=$14load_attr.2, args=[Var($const34.3.1, so-59-numba.py:8)], kws=(), vararg=None, varkwarg=None, target=None)  :: (float64,) -> float32
  #   del $const34.3.1
  #   del $14load_attr.2
  #   $binop_mul46.6 = $36call.4 * a  :: float32
  #   del a
  #   del $36call.4
  #   $50return_value.7 = cast(value=$binop_mul46.6)  :: float32
  #   del $binop_mul46.6
  #   return $50return_value.7

  return nb.float32(42.0)*a

Does this affect the runtime or only the compilation?

I’m not completely sure either, @pauljurczak.

Your version temporarily creates a float64 literal (42.0) and casts it to float32 before the multiplication. While this cast doesn’t affect the actual computation (which is done in float32, as seen in the LLVM IR via the fmul instruction), it does add a bit of overhead during compilation and clutters the IR.

import numpy as np

@nb.njit(["f4(f4)"], fastmath=True)
def f(a):
    mult = np.float32(42)
    return mult * a
# f.inspect_types()
# f.inspect_llvm()

Here are the inferred types for f():

...
    #   $const34.3.1 = const(int, 42)  :: Literal[int](42)
    #   mult = call $14load_attr.1($const34.3.1, func=$14load_attr.1, args=[Var($const34.3.1, 3285166149.py:6)], kws=(), vararg=None, varkwarg=None, target=None)  :: (int64,) -> float32
    #   del $const34.3.1
    #   del $14load_attr.1

    mult = np.float32(42)

    # --- LINE 7 --- 
    #   $binop_mul50.7 = mult * a  :: float32
    #   del mult
    #   del a
    #   $54return_value.8 = cast(value=$binop_mul50.7)  :: float32
    #   del $binop_mul50.7
    #   return $54return_value.8

    return mult * a

If you want to avoid that implicit cast and get a cleaner IR, you can initialize the constant as a float32 directly via the locals dictionary:

@nb.njit(["f4(f4)"], locals={"mult": nb.f4}, fastmath=True)
def g(a):
    mult = 42
    return mult*a
# g.inspect_types()
# g.inspect_llvm()

Here are the inferred types for g():

def g(a):

    # --- LINE 13 --- 
    #   mult = const(int, 42)  :: float32

    mult = 42

    # --- LINE 14 --- 
    #   $binop_mul12.3 = mult * a  :: float32
    #   del mult
    #   del a
    #   $16return_value.4 = cast(value=$binop_mul12.3)  :: float32
    #   del $binop_mul12.3
    #   return $16return_value.4

    return mult*a

The result is a cleaner LLVM IR and slightly faster compile times, but at runtime both versions could perform similarly since the computation itself stays in float32.
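
If you want to check that yourself, one quick way (just a sketch, assuming the f and g defined above have already been compiled) is to pull the fmul instructions out of the generated LLVM IR and confirm they operate on float rather than double:

for func in (f, g):
    llvm_ir = next(iter(func.inspect_llvm().values()))
    mul_lines = [ln.strip() for ln in llvm_ir.splitlines() if 'fmul' in ln]
    print(func.py_func.__name__, mul_lines)  # expect 'fmul fast float ...' in both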
