This post is a response to a question posted on Gitter:

> I have discovered that [my code] is compute bound, and most of the math is being done in double precision. This was a bit surprising to me since all the values being passed into the kernel were single precision. I have since discovered that the `math` and `cmath` libraries seem to always output doubles. I am using functions like `sin`, `arctan2`, `cmath.exp`. Are there single precision variants available somewhere?

# Typing of `math` and `cmath` functions

If the `math` and `cmath` functions are presented with single-precision inputs, then they should produce single-precision outputs. Consider the following example:

```python
from numba import cuda, float32
import numpy as np
import math
import cmath


@cuda.jit
def f(x, xc):
    i = cuda.grid(1)

    if i < len(x):
        s = math.sin(x[i])
        a = math.atan2(x[i], float32(0.5))  # constant explicitly typed as float32
        x[i] = s * a

        e = cmath.exp(x[i])
        xc[i] = e * s


# Host code: float32 input and complex64 output arrays
x = np.random.random(10).astype(np.float32)
xc = np.zeros_like(x).astype(np.complex64)
print(x)
f[1, 32](x, xc)
print(x)
print(xc)
```

Running it with the Numba CLI tool to produce an annotated HTML file containing the typing information:

```
$ numba --annotate-html mathtyping.html repro.py
```

shows that the functions produce `float32` outputs when given `float32` input. For example, for the call to `math.sin`, the relevant IR is:

```
$24load_global.0 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
$26load_method.1 = getattr(value=$24load_global.0, attr=sin) :: Function(<built-in function sin>)
$32binary_subscr.4 = getitem(value=x, index=i) :: float32
$34call_method.5 = call $26load_method.1($32binary_subscr.4, func=$26load_method.1, args=[Var($32binary_subscr.4, repro.py:12)], kws=(), vararg=None) :: (float32,) -> float32
s = $34call_method.5 :: float32
```

The method call (`$34call_method.5`) is typed as taking a `float32` and returning a `float32` (`:: (float32,) -> float32`), and the assignment of its result (`s = $34call_method.5 :: float32`) shows that the variable `s` is typed as `float32`.

Similar behaviour occurs for the other functions, and for complex math operations: if the inputs are `complex64`, then the outputs should be `complex64` too. I've attached a complete IR dump at the end of this post for further inspection.

# Potential issue / solution

It may be that the root of your problem is that the functions you're calling are already being presented with `float64` inputs. This can happen if you have intermediate computations that produce a `float64`, or constants in the code that are not explicitly typed as `float32`.

## Typing of constants

For example, if I change the call to `math.atan2` so that it becomes:

```
a = math.atan2(x[i], 0.5)
```

then the IR is now:

```
$38load_global.6 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
$40load_method.7 = getattr(value=$38load_global.6, attr=atan2) :: Function(<built-in function atan2>)
$46binary_subscr.10 = getitem(value=x, index=i) :: float32
$const48.11 = const(float, 0.5) :: float64
$50call_method.12 = call $40load_method.7($46binary_subscr.10, $const48.11, func=$40load_method.7, args=[Var($46binary_subscr.10, repro.py:13), Var($const48.11, repro.py:13)], kws=(), vararg=None) :: (float64, float64) -> float64
a = $50call_method.12 :: float64
```

What happened: the constant `0.5` is typed as `float64` (`$const48.11 = const(float, 0.5) :: float64`). Because one operand is now `float64`, the typing of `$50call_method.12` selects the version that accepts two `float64` operands and returns a `float64` (`:: (float64, float64) -> float64`), promoting the `float32` value `x[i]` along the way. Therefore, the computation is done using `float64`, and the type of `a` is also `float64` (`a = $50call_method.12 :: float64`).

## Typing of arithmetic operations

Some arithmetic operations on `float32` and other operands can also produce `float64` results. In general, most operations involving only `float32` operands should also produce `float32`, but operations combining integers and `float32` can result in `float64`. Suppose I add a line to the example (with `int32` also imported from `numba`):

```
s = s + int32(1)
```

Then the IR for this is:

```
$40load_global.7 = global(int32: int32) :: class(int32)
$const42.8 = const(int, 1) :: Literal[int](1)
$44call_function.9 = call $40load_global.7($const42.8, func=$40load_global.7, args=[Var($const42.8, reproint.py:13)], kws=(), vararg=None) :: (int64,) -> int32
$46binary_add.10 = s + $44call_function.9 :: float64
s.1 = $46binary_add.10 :: float64
```

The addition of a `float32` (`s`) and an `int32` (`$44call_function.9`) results in a `float64` (`$46binary_add.10 = s + $44call_function.9 :: float64`). This is a policy decision embedded in Numba. A `float32` only has about 7 significant figures of precision, which is not enough to represent every value in the range of an `int32`. The operation is therefore computed with `float64`, which has about 16 significant figures and can represent every `int32` exactly.

## General strategy

The introduction of a `float64` early on in a function tends to propagate all the way through it, causing the majority of operations to be computed with double precision arithmetic. In general, my strategy for dealing with this is to:

1. Make sure all constants are explicitly typed, as `float32` where possible.
2. Inspect the typing using the Numba CLI tool.
3. Make modifications and add casts to `float32` where possible when the IR indicates that `float64` values are being produced.
4. Repeat steps 2-3 as necessary, until there are no more `float64` variables, or until it's not possible to get rid of any more.

For floating point arithmetic, I can usually succeed in ensuring that all arithmetic happens in single precision. Integer arithmetic isn't covered much in this post or the question; see Numba Enhancement Proposal 1 for more details. With integer arithmetic, it is harder to ensure that all operations are done on 32-bit operands. Sometimes it is possible, but integer arithmetic is very prone to widening in Numba. A future addition (currently a “nice-to-have” / “wishlist” item) would provide an alternative typing mode that tries to keep operand widths narrow, but the work to do this is presently unscheduled.

# Full annotated IR dump

This is the output of `numba --annotate repro.py` for the example code above. Using the `--annotate-html` option is generally preferable, as you can explore the IR line-by-line more easily in the browser.

```
-----------------------------------ANNOTATION-----------------------------------
# File: repro.py
# --- LINE 7 ---
@cuda.jit
# --- LINE 8 ---
def f(x, xc):
# --- LINE 9 ---
# label 0
# x = arg(0, name=x) :: array(float32, 1d, C)
# xc = arg(1, name=xc) :: array(complex64, 1d, C)
# $2load_global.0 = global(cuda: <module 'numba.cuda' from '/home/gmarkall/numbadev/numba/numba/cuda/__init__.py'>) :: Module(<module 'numba.cuda' from '/home/gmarkall/numbadev/numba/numba/cuda/__init__.py'>)
# $4load_method.1 = getattr(value=$2load_global.0, attr=grid) :: Function(<class 'numba.cuda.stubs.grid'>)
# del $2load_global.0
# $const6.2 = const(int, 1) :: Literal[int](1)
# $8call_method.3 = call $4load_method.1($const6.2, func=$4load_method.1, args=[Var($const6.2, repro.py:9)], kws=(), vararg=None) :: (int32,) -> int32
# del $const6.2
# del $4load_method.1
# i = $8call_method.3 :: int32
# del $8call_method.3
i = cuda.grid(1)
# --- LINE 10 ---
# --- LINE 11 ---
# $14load_global.5 = global(len: <built-in function len>) :: Function(<built-in function len>)
# $18call_function.7 = call $14load_global.5(x, func=$14load_global.5, args=[Var(x, repro.py:9)], kws=(), vararg=None) :: (array(float32, 1d, C),) -> int64
# del $14load_global.5
# $20compare_op.8 = i < $18call_function.7 :: bool
# del $18call_function.7
# bool22 = global(bool: <class 'bool'>) :: Function(<class 'bool'>)
# $22pred = call bool22($20compare_op.8, func=bool22, args=(Var($20compare_op.8, repro.py:11),), kws=(), vararg=None) :: (bool,) -> bool
# del bool22
# del $20compare_op.8
# branch $22pred, 24, 96
if i < len(x):
# --- LINE 12 ---
# label 24
# del $22pred
# $24load_global.0 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
# $26load_method.1 = getattr(value=$24load_global.0, attr=sin) :: Function(<built-in function sin>)
# del $24load_global.0
# $32binary_subscr.4 = getitem(value=x, index=i) :: float32
# $34call_method.5 = call $26load_method.1($32binary_subscr.4, func=$26load_method.1, args=[Var($32binary_subscr.4, repro.py:12)], kws=(), vararg=None) :: (float32,) -> float32
# del $32binary_subscr.4
# del $26load_method.1
# s = $34call_method.5 :: float32
# del $34call_method.5
s = math.sin(x[i])
# --- LINE 13 ---
# $38load_global.6 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
# $40load_method.7 = getattr(value=$38load_global.6, attr=atan2) :: Function(<built-in function atan2>)
# del $38load_global.6
# $46binary_subscr.10 = getitem(value=x, index=i) :: float32
# $48load_global.11 = global(float32: float32) :: class(float32)
# $const50.12 = const(float, 0.5) :: float64
# $52call_function.13 = call $48load_global.11($const50.12, func=$48load_global.11, args=[Var($const50.12, repro.py:13)], kws=(), vararg=None) :: (float64,) -> float32
# del $const50.12
# del $48load_global.11
# $54call_method.14 = call $40load_method.7($46binary_subscr.10, $52call_function.13, func=$40load_method.7, args=[Var($46binary_subscr.10, repro.py:13), Var($52call_function.13, repro.py:13)], kws=(), vararg=None) :: (float32, float32) -> float32
# del $52call_function.13
# del $46binary_subscr.10
# del $40load_method.7
# a = $54call_method.14 :: float32
# del $54call_method.14
a = math.atan2(x[i], float32(0.5))
# --- LINE 14 ---
# $62binary_multiply.17 = s * a :: float32
# del a
# x[i] = $62binary_multiply.17 :: (array(float32, 1d, C), int64, float32) -> none
# del $62binary_multiply.17
x[i] = s * a
# --- LINE 15 ---
# --- LINE 16 ---
# $70load_global.20 = global(cmath: <module 'cmath' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/cmath.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'cmath' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/cmath.cpython-38-x86_64-linux-gnu.so'>)
# $72load_method.21 = getattr(value=$70load_global.20, attr=exp) :: Function(<built-in function exp>)
# del $70load_global.20
# $78binary_subscr.24 = getitem(value=x, index=i) :: float32
# del x
# $80call_method.25 = call $72load_method.21($78binary_subscr.24, func=$72load_method.21, args=[Var($78binary_subscr.24, repro.py:16)], kws=(), vararg=None) :: (complex64,) -> complex64
# del $78binary_subscr.24
# del $72load_method.21
# e = $80call_method.25 :: complex64
# del $80call_method.25
e = cmath.exp(x[i])
# --- LINE 17 ---
# $88binary_multiply.28 = e * s :: complex64
# del s
# del e
# xc[i] = $88binary_multiply.28 :: (array(complex64, 1d, C), int64, complex64) -> none
# del xc
# del i
# del $88binary_multiply.28
# jump 96
# label 96
# del xc
# del x
# del i
# del $22pred
# $const96.0 = const(NoneType, None) :: none
# $98return_value.1 = cast(value=$const96.0) :: none
# del $const96.0
# return $98return_value.1
xc[i] = e * s
```