 # [Question] Math typing in CUDA

This post is a response to a question posted on Gitter:

> I have discovered that [my code] is compute bound, and most of the math is being done in double precision. This was a bit surprising to me since all the values being passed into the kernel were single precision. I have since discovered that the `math` and `cmath` libraries seem to always output doubles. I am using functions like `sin`, `arctan2`, `cmath.exp`. Are there single precision variants available somewhere?

# Typing of `math` and `cmath` functions

If the `math` and `cmath` functions are presented with single-precision inputs, then they should produce single-precision outputs. Consider the following example:

``````
from numba import cuda, float32
import numpy as np
import math
import cmath

@cuda.jit
def f(x, xc):
    i = cuda.grid(1)

    if i < len(x):
        s = math.sin(x[i])
        a = math.atan2(x[i], float32(0.5))
        x[i] = s * a

        e = cmath.exp(x[i])
        xc[i] = e * s

x = np.random.random(10).astype(np.float32)
xc = np.zeros_like(x).astype(np.complex64)
print(x)

f[1, 32](x, xc)

print(x)
print(xc)
``````

Running it with the Numba CLI tool to produce an annotated HTML file containing the typing:

``````
$ numba --annotate-html mathtyping.html repro.py
``````

shows that the functions produce `float32` outputs when given `float32` input. For example, for the call to `math.sin`, the relevant IR is:

``````
$24load_global.0 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
$32binary_subscr.4 = getitem(value=x, index=i) :: float32
$34call_method.5 = call $26load_method.1($32binary_subscr.4, func=$26load_method.1, args=[Var($32binary_subscr.4, repro.py:12)], kws=(), vararg=None) :: (float32,) -> float32
s = $34call_method.5 :: float32
``````

The call itself (`$34call_method.5`) is typed as taking a `float32` and returning a `float32` (`:: (float32,) -> float32`), and the assignment of its result (`s = $34call_method.5 :: float32`) shows that the variable `s` is also typed as `float32`.

Similar behaviour occurs for the other functions, and for complex math operations: if the inputs are `complex64`, then the outputs should be `complex64` too. I'll attach a complete IR dump at the end of this post for further inspection.
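The same precision-preserving behaviour can be sanity-checked on the host with plain NumPy (a rough analogue for illustration only - this is NumPy's type handling, not the CUDA target):

``````python
import numpy as np

# Single-precision inputs produce single-precision outputs,
# for both real and complex math functions.
x = np.float32(0.5)
z = np.complex64(0.5 + 0.25j)

print(np.sin(x).dtype)  # float32
print(np.exp(z).dtype)  # complex64
``````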

# Potential issue / solution

It may be that the root of your problem is that the functions you're calling are already being presented with `float64` inputs, which can happen if you have intermediate computations that produce a `float64`, or constants in the code that are not explicitly typed as `float32`.

## Typing of constants

For example, if I change the call to `math.atan2` such that it becomes:

``````
        a = math.atan2(x[i], 0.5)
``````

then the IR is now:

``````
$38load_global.6 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
$46binary_subscr.10 = getitem(value=x, index=i) :: float32
$const48.11 = const(float, 0.5) :: float64
$50call_method.12 = call $40load_method.7($46binary_subscr.10, $const48.11, func=$40load_method.7, args=[Var($46binary_subscr.10, repro.py:13), Var($const48.11, repro.py:13)], kws=(), vararg=None) :: (float64, float64) -> float64
a = $50call_method.12 :: float64
``````

What happened: the constant `0.5` is typed as `float64` (`$const48.11 = const(float, 0.5) :: float64`), and the typing of `$50call_method.12` is such that the version that accepts two `float64` operands and returns a `float64` is selected (`:: (float64, float64) -> float64`). Therefore, the computation is done using `float64`, and the type of `a` is also `float64` (`a = $50call_method.12 :: float64`).
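This follows from how Python itself represents literals: every untyped Python `float` constant is a 64-bit IEEE 754 double, so Numba has nothing narrower to infer. A quick stdlib check (illustrative only) confirms this:

``````python
import struct
import sys

# A Python float literal like 0.5 is always a 64-bit double, which is
# why an untyped constant is typed as float64 during type inference.
print(sys.float_info.mant_dig)     # 53 significand bits - double precision
print(len(struct.pack('d', 0.5)))  # 8 bytes
``````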

## Typing of arithmetic operations

Some arithmetic operations on `float32` and other operands can also produce `float64` results. In general I think most operations involving only `float32` operands should also produce `float32`, but operations combining integers and `float32` can result in `float64`. Supposing I add a line to the example (with `int32` also imported from `numba`):

``````
        s = s + int32(1)
``````

Then the IR for this is:

``````
$40load_global.7 = global(int32: int32) :: class(int32)
$const42.8 = const(int, 1) :: Literal[int](1)
$44call_function.9 = call $40load_global.7($const42.8, func=$40load_global.7, args=[Var($const42.8, reproint.py:13)], kws=(), vararg=None) :: (int64,) -> int32
$46binary_add.10 = s + $44call_function.9 :: float64
``````

The addition of a `float32` (`s`) and an `int32` (`$44call_function.9`) results in a `float64` (`$46binary_add.10 = s + $44call_function.9 :: float64`). This is a policy decision embedded in Numba - a `float32` has only about 7 significant figures of precision, which is not enough to represent all values in the range of an `int32`, so the operation is computed with `float64`, which has about 16 significant figures.
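To see why this policy makes sense, here is a small NumPy illustration (not Numba itself, but NumPy's promotion rules make the same choice here): `float32` has a 24-bit significand, so integers above 2**24 can no longer be represented exactly.

``````python
import numpy as np

# 2**24 is exactly representable in float32, but 2**24 + 1 is not -
# it rounds back down to 2**24, silently losing the +1.
exact = np.float32(2**24)
rounded = np.float32(2**24 + 1)

print(int(exact))    # 16777216
print(int(rounded))  # 16777216 - the +1 is lost

# NumPy's promotion rules make the same float64 choice for this combination:
print(np.result_type(np.float32, np.int32))  # float64
``````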

## General strategy

The introduction of `float64` early on in a function tends to propagate all the way through it, causing the majority of operations to be computed with double precision arithmetic. In general my strategy for dealing with this is to:

1. Make sure all constants are explicitly typed - `float32` where possible.
2. Inspect the typing using the Numba CLI tool.
3. Make modifications and add casts to `float32` where possible when the IR indicates that `float64` values are being produced.
4. Repeat steps 2-3 as necessary, until there are no more `float64` variables, or until it’s not possible to get rid of any more.
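As a rough host-side illustration of step 3 (NumPy here, but the same idea applies inside a kernel): once the IR shows a `float64` being produced, cast it back down to `float32` explicitly so the widening does not propagate.

``````python
import numpy as np

x = np.float32(1.5)

# Without a cast, a float64 operand widens the whole expression downstream.
widened = x * np.float64(2.0)
print(widened.dtype)  # float64

# With an explicit cast back to float32, the narrow type is restored
# and subsequent operations stay in single precision.
narrowed = np.float32(x * np.float64(2.0))
print(narrowed.dtype)  # float32
``````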

For floating point arithmetic, I can usually ensure that all arithmetic happens in single precision. For integer arithmetic (not really covered much in this post or the question, but see Numba Enhancement Proposal 1 for more details), it is harder to ensure that all operations are done on 32-bit operands - sometimes it is possible, but integer arithmetic is very prone to widening in Numba. A future addition (currently a "nice-to-have" / "wishlist" item) would provide an alternative typing mode that tries to keep operand widths narrow, but the work to do this is presently unscheduled.

# Full annotated IR dump

This is the output of `numba --annotate repro.py` for the example code above, but using the `--annotate-html` option is generally preferable as you can explore the IR line-by-line more easily in the browser.

``````
-----------------------------------ANNOTATION-----------------------------------
# File: repro.py
# --- LINE 7 ---

@cuda.jit

# --- LINE 8 ---

def f(x, xc):

    # --- LINE 9 ---
    # label 0
    #   x = arg(0, name=x)  :: array(float32, 1d, C)
    #   xc = arg(1, name=xc)  :: array(complex64, 1d, C)
    #   $const6.2 = const(int, 1)  :: Literal[int](1)
    #   $8call_method.3 = call $4load_method.1($const6.2, func=$4load_method.1, args=[Var($const6.2, repro.py:9)], kws=(), vararg=None)  :: (int32,) -> int32
    #   del $const6.2
    #   i = $8call_method.3  :: int32
    #   del $8call_method.3

    i = cuda.grid(1)

    # --- LINE 10 ---

    # --- LINE 11 ---
    #   $14load_global.5 = global(len: <built-in function len>)  :: Function(<built-in function len>)
    #   $18call_function.7 = call $14load_global.5(x, func=$14load_global.5, args=[Var(x, repro.py:9)], kws=(), vararg=None)  :: (array(float32, 1d, C),) -> int64
    #   $20compare_op.8 = i < $18call_function.7  :: bool
    #   del $18call_function.7
    #   bool22 = global(bool: <class 'bool'>)  :: Function(<class 'bool'>)
    #   $22pred = call bool22($20compare_op.8, func=bool22, args=(Var($20compare_op.8, repro.py:11),), kws=(), vararg=None)  :: (bool,) -> bool
    #   del bool22
    #   del $20compare_op.8
    #   branch $22pred, 24, 96

    if i < len(x):

        # --- LINE 12 ---
        # label 24
        #   del $22pred
        #   $32binary_subscr.4 = getitem(value=x, index=i)  :: float32
        #   $34call_method.5 = call $26load_method.1($32binary_subscr.4, func=$26load_method.1, args=[Var($32binary_subscr.4, repro.py:12)], kws=(), vararg=None)  :: (float32,) -> float32
        #   del $32binary_subscr.4
        #   s = $34call_method.5  :: float32
        #   del $34call_method.5

        s = math.sin(x[i])

        # --- LINE 13 ---
        #   $46binary_subscr.10 = getitem(value=x, index=i)  :: float32
        #   $48load_global.11 = global(float32: float32)  :: class(float32)
        #   $const50.12 = const(float, 0.5)  :: float64
        #   $52call_function.13 = call $48load_global.11($const50.12, func=$48load_global.11, args=[Var($const50.12, repro.py:13)], kws=(), vararg=None)  :: (float64,) -> float32
        #   del $const50.12
        #   $54call_method.14 = call $40load_method.7($46binary_subscr.10, $52call_function.13, func=$40load_method.7, args=[Var($46binary_subscr.10, repro.py:13), Var($52call_function.13, repro.py:13)], kws=(), vararg=None)  :: (float32, float32) -> float32
        #   del $52call_function.13
        #   del $46binary_subscr.10
        #   a = $54call_method.14  :: float32
        #   del $54call_method.14

        a = math.atan2(x[i], float32(0.5))

        # --- LINE 14 ---
        #   $62binary_multiply.17 = s * a  :: float32
        #   del a
        #   x[i] = $62binary_multiply.17  :: (array(float32, 1d, C), int64, float32) -> none
        #   del $62binary_multiply.17

        x[i] = s * a

        # --- LINE 15 ---

        # --- LINE 16 ---
        #   $78binary_subscr.24 = getitem(value=x, index=i)  :: float32
        #   del x
        #   $80call_method.25 = call $72load_method.21($78binary_subscr.24, func=$72load_method.21, args=[Var($78binary_subscr.24, repro.py:16)], kws=(), vararg=None)  :: (complex64,) -> complex64
        #   del $78binary_subscr.24
        #   e = $80call_method.25  :: complex64
        #   del $80call_method.25

        e = cmath.exp(x[i])

        # --- LINE 17 ---
        #   $88binary_multiply.28 = e * s  :: complex64
        #   del s
        #   del e
        #   xc[i] = $88binary_multiply.28  :: (array(complex64, 1d, C), int64, complex64) -> none
        #   del xc
        #   del i
        #   del $88binary_multiply.28
        #   jump 96
        # label 96
        #   del xc
        #   del x
        #   del i
        #   del $22pred
        #   $const96.0 = const(NoneType, None)  :: none
        #   $98return_value.1 = cast(value=$const96.0)  :: none
        #   del $const96.0
        #   return $98return_value.1

        xc[i] = e * s
``````

This is a great explanation for a commonly encountered issue on the `CUDA` target, thanks for the write up @gmarkall!

``````
import math
from numba import njit, float32

a = float32(60.0)
b = float32(2.3)

@njit
def funct(a, b):
    return math.sin(a / b)

c = funct(a,b)
print(type(c))
print(c)
``````

This prints the output:

``````
<class 'float'>
0.8158610507428679
``````

Meanwhile, running it using the Numba CLI and dumping to HTML gives the signature:

``````
Function name: funct
in file: typetest.py
with signature: (float32, float32) -> float32
``````

So! Here we have a very annoying contradiction. Numba is reporting a `float32` output from the function, while Python itself is reporting the output as `float64`. Is this an issue with `type()`?

There is no contradiction here - what you are seeing is the difference between native types used in Numba-compiled code, and objects used in the Python interpreter.

• In the Numba-compiled function, the data types are all `float32`: the computation is done in `float32`, and the output is a native `float32` value - just the raw 32-bit floating point representation.
• Python doesn’t represent floating point values using the native representation, but uses an object instead. When the Numba-compiled function returns, it must box the native value in a Python object, so that it can be handled correctly by the Python interpreter. For floating point values in Python, there is only the `float` class, so this is what `float32` must be boxed as.
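A loose analogy can be seen with NumPy scalars (this is NumPy's conversion machinery, not Numba's boxing code, but it illustrates the same idea): a native 32-bit value, once converted to a plain Python object, can only be a `float`.

``````python
import numpy as np

native = np.float32(0.5)
print(native.dtype)         # float32 - the native 32-bit representation
print(type(native.item()))  # <class 'float'> - converted to a Python float
``````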

If you were to use a NumPy array instead:

``````
from numba import njit, float32
import numpy as np

a = float32(60.0)
b = float32(2.3)

@njit
def funct(a, b):
    ret = np.zeros(1, dtype=np.float32)
    # Write into the float32 array so the 32-bit dtype is retained
    ret[:] = np.sin(a / b)
    return ret

c = funct(a, b)
print(c.dtype)
print(c)
``````

then you would see:

``````
float32
[0.81586105]
``````

In this case the returned value retains the 32-bit data type.

(Note that I had to use an array rather than a NumPy scalar for this example, due to https://github.com/numba/numba/issues/5990 - Numba doesn’t differentiate between NumPy scalars and CPython scalars, and treats them all as if they are CPython scalars)