[Question] Math typing in CUDA

This post is a response to a question posted on Gitter:

I have discovered that [my code] is compute bound, and most of the math is being done in double precision. This was a bit surprising to me since all the values being passed into the kernel were single precision. I have since discovered that the math and cmath libraries seem to always output doubles. I am using functions like sin, arctan2, cmath.exp. Are there single precision variants available somewhere?

Typing of math and cmath functions

If the math and cmath functions are presented with inputs that are single-precision, then they should produce single-precision outputs. Consider the following example:

from numba import cuda, float32
import numpy as np
import math
import cmath

@cuda.jit
def f(x, xc):
    i = cuda.grid(1)

    if i < len(x):
        s = math.sin(x[i])
        a = math.atan2(x[i], float32(0.5))
        x[i] = s * a

        e = cmath.exp(x[i])
        xc[i] = e * s

x = np.random.random(10).astype(np.float32)
xc = np.zeros_like(x).astype(np.complex64)

f[1, 32](x, xc)


then running it with the Numba CLI tool to produce an annotated HTML file containing the typing:

$ numba --annotate-html mathtyping.html repro.py

shows that the functions produce float32 outputs when given float32 inputs. For example, for the call to math.sin, the relevant IR is:

$24load_global.0 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
$26load_method.1 = getattr(value=$24load_global.0, attr=sin) :: Function(<built-in function sin>)
$32binary_subscr.4 = getitem(value=x, index=i) :: float32
$34call_method.5 = call $26load_method.1($32binary_subscr.4, func=$26load_method.1, args=[Var($32binary_subscr.4, repro.py:12)], kws=(), vararg=None) :: (float32,) -> float32
s = $34call_method.5 :: float32

The call itself ($34call_method.5) is typed as taking a float32 and returning a float32 (:: (float32,) -> float32), and the assignment of its result (s = $34call_method.5 :: float32) shows the variable s being given the float32 type.

Similar behaviour occurs for the other functions, and for complex math operations: if the inputs are complex64, then the outputs should be complex64 too. I've attached a complete IR dump at the end of this post for further inspection.

Potential issue / solution

It may be that the root of your problem is that the functions you're calling are already being presented with float64 inputs, which can happen if you have intermediate computations that produce a float64, or constants in the code that are not explicitly typed as float32.

Typing of constants

For example, if I change the call to math.atan2 such that it becomes:

        a = math.atan2(x[i], 0.5)

then the IR is now:

$38load_global.6 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
$40load_method.7 = getattr(value=$38load_global.6, attr=atan2) :: Function(<built-in function atan2>)
$46binary_subscr.10 = getitem(value=x, index=i) :: float32
$const48.11 = const(float, 0.5) :: float64
$50call_method.12 = call $40load_method.7($46binary_subscr.10, $const48.11, func=$40load_method.7, args=[Var($46binary_subscr.10, repro.py:13), Var($const48.11, repro.py:13)], kws=(), vararg=None) :: (float64, float64) -> float64
a = $50call_method.12 :: float64

What happened: the constant 0.5 is typed as float64 ($const48.11 = const(float, 0.5) :: float64), and the typing of $50call_method.12 is such that the version that accepts two float64 operands and returns a float64 is selected (:: (float64, float64) -> float64). Therefore, the computation is done using float64, and the type of a is also float64 (a = $50call_method.12 :: float64).

Typing of arithmetic operations

Some arithmetic operations on float32 and other operands can also produce float64 results. In general I think most operations involving only float32 should also produce float32, but operations combining integers and float32 can result in float64. Supposing I add a line to the example (with int32 also imported from numba):

         s = s + int32(1)

Then the IR for this is:

$40load_global.7 = global(int32: int32) :: class(int32)
$const42.8 = const(int, 1) :: Literal[int](1)
$44call_function.9 = call $40load_global.7($const42.8, func=$40load_global.7, args=[Var($const42.8, reproint.py:13)], kws=(), vararg=None) :: (int64,) -> int32
$46binary_add.10 = s + $44call_function.9 :: float64
s.1 = $46binary_add.10 :: float64

The addition of a float32 (s) and an int32 ($44call_function.9) results in a float64 ($46binary_add.10 = s + $44call_function.9 :: float64). This is a policy decision embedded in Numba: a float32 only has about 7 significant figures of precision, which is not enough to represent every value in the range of an int32, so the operation is computed with float64, which has about 16 significant figures.
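To make the precision argument concrete, here is a small check using NumPy alone (so it illustrates the numbers involved, not Numba's typing itself). A float32 has a 24-bit significand, so integers just above 2**24 can no longer all be represented, while a float64 holds every int32 exactly. NumPy's own type promotion makes the same float64 choice for this pair of types.

```python
import numpy as np

# 2**24 + 1 is well inside the int32 range but needs 25 significant bits
n = 2**24 + 1  # 16777217

print(int(np.float32(n)))  # 16777216: rounded to the nearest representable float32
print(int(np.float64(n)))  # 16777217: float64 represents every int32 exactly

# NumPy also promotes float32 combined with int32 to float64
print(np.result_type(np.float32, np.int32))  # float64
```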

General strategy

The introduction of float64 early on in a function tends to propagate all the way through it, causing the majority of operations to be computed with double precision arithmetic. In general my strategy for dealing with this is to:

  1. Make sure all constants are explicitly typed - float32 where possible.
  2. Inspect the typing using the Numba CLI tool.
  3. Make modifications and add casts to float32 where possible when the IR indicates that float64 values are being produced.
  4. Repeat steps 2-3 as necessary, until there are no more float64 variables, or until it’s not possible to get rid of any more.

For floating point arithmetic, I can usually succeed in ensuring that all arithmetic happens in single precision. For integer arithmetic (not really covered much in this post or the question, but see Numba Enhancement Proposal 1 for more details), it is harder to ensure that all operations are done on 32-bit operands: sometimes it is possible, but integer arithmetic is very prone to widening in Numba. A future addition (currently a "nice-to-have" / "wishlist" item) would provide an alternative typing mode that tries to keep operand widths narrow, but the work to do this is presently unscheduled.

Full annotated IR dump

This is the output of numba --annotate repro.py for the example code above, but using the --annotate-html option is generally preferable as you can explore the IR line-by-line more easily in the browser.

# File: repro.py
# --- LINE 7 ---

@cuda.jit

# --- LINE 8 ---

def f(x, xc):

    # --- LINE 9 ---
    # label 0
    #   x = arg(0, name=x)  :: array(float32, 1d, C)
    #   xc = arg(1, name=xc)  :: array(complex64, 1d, C)
    #   $2load_global.0 = global(cuda: <module 'numba.cuda' from '/home/gmarkall/numbadev/numba/numba/cuda/__init__.py'>)  :: Module(<module 'numba.cuda' from '/home/gmarkall/numbadev/numba/numba/cuda/__init__.py'>)
    #   $4load_method.1 = getattr(value=$2load_global.0, attr=grid)  :: Function(<class 'numba.cuda.stubs.grid'>)
    #   del $2load_global.0
    #   $const6.2 = const(int, 1)  :: Literal[int](1)
    #   $8call_method.3 = call $4load_method.1($const6.2, func=$4load_method.1, args=[Var($const6.2, repro.py:9)], kws=(), vararg=None)  :: (int32,) -> int32
    #   del $const6.2
    #   del $4load_method.1
    #   i = $8call_method.3  :: int32
    #   del $8call_method.3

    i = cuda.grid(1)

# --- LINE 10 ---

    # --- LINE 11 ---
    #   $14load_global.5 = global(len: <built-in function len>)  :: Function(<built-in function len>)
    #   $18call_function.7 = call $14load_global.5(x, func=$14load_global.5, args=[Var(x, repro.py:9)], kws=(), vararg=None)  :: (array(float32, 1d, C),) -> int64
    #   del $14load_global.5
    #   $20compare_op.8 = i < $18call_function.7  :: bool
    #   del $18call_function.7
    #   bool22 = global(bool: <class 'bool'>)  :: Function(<class 'bool'>)
    #   $22pred = call bool22($20compare_op.8, func=bool22, args=(Var($20compare_op.8, repro.py:11),), kws=(), vararg=None)  :: (bool,) -> bool
    #   del bool22
    #   del $20compare_op.8
    #   branch $22pred, 24, 96

    if i < len(x):

        # --- LINE 12 ---
        # label 24
        #   del $22pred
        #   $24load_global.0 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)  :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
        #   $26load_method.1 = getattr(value=$24load_global.0, attr=sin)  :: Function(<built-in function sin>)
        #   del $24load_global.0
        #   $32binary_subscr.4 = getitem(value=x, index=i)  :: float32
        #   $34call_method.5 = call $26load_method.1($32binary_subscr.4, func=$26load_method.1, args=[Var($32binary_subscr.4, repro.py:12)], kws=(), vararg=None)  :: (float32,) -> float32
        #   del $32binary_subscr.4
        #   del $26load_method.1
        #   s = $34call_method.5  :: float32
        #   del $34call_method.5

        s = math.sin(x[i])

        # --- LINE 13 ---
        #   $38load_global.6 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)  :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
        #   $40load_method.7 = getattr(value=$38load_global.6, attr=atan2)  :: Function(<built-in function atan2>)
        #   del $38load_global.6
        #   $46binary_subscr.10 = getitem(value=x, index=i)  :: float32
        #   $48load_global.11 = global(float32: float32)  :: class(float32)
        #   $const50.12 = const(float, 0.5)  :: float64
        #   $52call_function.13 = call $48load_global.11($const50.12, func=$48load_global.11, args=[Var($const50.12, repro.py:13)], kws=(), vararg=None)  :: (float64,) -> float32
        #   del $const50.12
        #   del $48load_global.11
        #   $54call_method.14 = call $40load_method.7($46binary_subscr.10, $52call_function.13, func=$40load_method.7, args=[Var($46binary_subscr.10, repro.py:13), Var($52call_function.13, repro.py:13)], kws=(), vararg=None)  :: (float32, float32) -> float32
        #   del $52call_function.13
        #   del $46binary_subscr.10
        #   del $40load_method.7
        #   a = $54call_method.14  :: float32
        #   del $54call_method.14

        a = math.atan2(x[i], float32(0.5))

        # --- LINE 14 ---
        #   $62binary_multiply.17 = s * a  :: float32
        #   del a
        #   x[i] = $62binary_multiply.17  :: (array(float32, 1d, C), int64, float32) -> none
        #   del $62binary_multiply.17

        x[i] = s * a

# --- LINE 15 ---

        # --- LINE 16 ---
        #   $70load_global.20 = global(cmath: <module 'cmath' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/cmath.cpython-38-x86_64-linux-gnu.so'>)  :: Module(<module 'cmath' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/cmath.cpython-38-x86_64-linux-gnu.so'>)
        #   $72load_method.21 = getattr(value=$70load_global.20, attr=exp)  :: Function(<built-in function exp>)
        #   del $70load_global.20
        #   $78binary_subscr.24 = getitem(value=x, index=i)  :: float32
        #   del x
        #   $80call_method.25 = call $72load_method.21($78binary_subscr.24, func=$72load_method.21, args=[Var($78binary_subscr.24, repro.py:16)], kws=(), vararg=None)  :: (complex64,) -> complex64
        #   del $78binary_subscr.24
        #   del $72load_method.21
        #   e = $80call_method.25  :: complex64
        #   del $80call_method.25

        e = cmath.exp(x[i])

        # --- LINE 17 ---
        #   $88binary_multiply.28 = e * s  :: complex64
        #   del s
        #   del e
        #   xc[i] = $88binary_multiply.28  :: (array(complex64, 1d, C), int64, complex64) -> none
        #   del xc
        #   del i
        #   del $88binary_multiply.28
        #   jump 96
        # label 96
        #   del xc
        #   del x
        #   del i
        #   del $22pred
        #   $const96.0 = const(NoneType, None)  :: none
        #   $98return_value.1 = cast(value=$const96.0)  :: none
        #   del $const96.0
        #   return $98return_value.1

        xc[i] = e * s

This is a great explanation of a commonly encountered issue on the CUDA target, thanks for the write-up @gmarkall!

import math
from numba import njit, float32

a = float32(60.0)
b = float32(2.3)

@njit
def funct(a, b):
    return math.sin(a / b)

c = funct(a, b)
print(type(c))

This prints the output:

<class 'float'>

Meanwhile, running it using the Numba CLI and dumping to HTML gives the signature:

Function name: funct
in file: typetest.py
with signature: (float32, float32) -> float32 

So! Here we have a very annoying contradiction. Numba is reporting a float32 output from the function, while Python itself is reporting the output as float64. Is this an issue with "type()"?

There is no contradiction here - what you are seeing is the difference between native types used in Numba-compiled code, and objects used in the Python interpreter.

  • In the Numba-compiled function, the data types are all float32, the computation is done with float32, and the output value is float32. This is a native value, which is just the 32-bit floating point representation.
  • Python doesn’t represent floating point values using the native representation, but uses an object instead. When the Numba-compiled function returns, it must box the native value in a Python object, so that it can be handled correctly by the Python interpreter. For floating point values in Python, there is only the float class, so this is what float32 must be boxed as.

If you were to use a NumPy array instead:

from numba import njit, float32
import numpy as np

a = float32(60.0)
b = float32(2.3)

@njit
def funct(a, b):
    ret = np.zeros(1, dtype=np.float32)
    ret[0] = np.sin(a / b)
    return ret

c = funct(a, b)

then the returned value retains the 32-bit data type, because it comes back to Python as a NumPy array with dtype float32 rather than being boxed as a Python float.

(Note that I had to use an array rather than a NumPy scalar for this example, due to https://github.com/numba/numba/issues/5990: Numba doesn't differentiate between NumPy scalars and CPython scalars, and treats them all as if they are CPython scalars.)