This post is a response to a question posted on Gitter:
I have discovered that [my code] is compute bound, and most of the math is being done in double precision. This was a bit surprising to me since all the values being passed into the kernel were single precision. I have since discovered that the math and cmath libraries seem to always output doubles. I am using functions like sin, arctan2, cmath.exp. Are there single precision variants available somewhere?
Typing of math and cmath functions
If the math and cmath functions are presented with single-precision inputs, they should produce single-precision outputs. Consider the following example:

from numba import cuda, float32
import numpy as np
import math
import cmath


@cuda.jit
def f(x, xc):
    i = cuda.grid(1)

    if i < len(x):
        s = math.sin(x[i])
        a = math.atan2(x[i], float32(0.5))
        x[i] = s * a

        e = cmath.exp(x[i])
        xc[i] = e * s


x = np.random.random(10).astype(np.float32)
xc = np.zeros_like(x).astype(np.complex64)

print(x)
f[1, 32](x, xc)
print(x)
print(xc)

Running it with the Numba CLI tool to produce an annotated HTML file containing the typing:
$ numba --annotate-html mathtyping.html repro.py
shows that the functions produce float32 outputs when given float32 input. For example, for the call to math.sin, the relevant IR is:
$24load_global.0 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
$26load_method.1 = getattr(value=$24load_global.0, attr=sin) :: Function(<built-in function sin>)
$32binary_subscr.4 = getitem(value=x, index=i) :: float32
$34call_method.5 = call $26load_method.1($32binary_subscr.4, func=$26load_method.1, args=[Var($32binary_subscr.4, repro.py:12)], kws=(), vararg=None) :: (float32,) -> float32
s = $34call_method.5 :: float32
The method definition itself ($34call_method.5) is typed as taking a float32 and returning a float32 (:: (float32,) -> float32), and the result of calling the method (s = $34call_method.5 :: float32) shows the assignment of the float32 type to the variable s.
Similar behaviour occurs for the other functions, and for complex math operations - if the inputs are complex64, then the outputs should be complex64 too - I’ll attach a complete IR dump at the end of this post for further inspection.
Potential issue / solution
It may be that the root of your problem is that the functions you’re calling are already being presented with float64 inputs, which can happen if you have intermediate computations that result in a float64 being produced, or constants in the code that are not explicitly typed as float32.
Typing of constants
For example, if I change the call to math.atan2 such that it becomes:
a = math.atan2(x[i], 0.5)
then the IR is now:
$38load_global.6 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
$40load_method.7 = getattr(value=$38load_global.6, attr=atan2) :: Function(<built-in function atan2>)
$46binary_subscr.10 = getitem(value=x, index=i) :: float32
$const48.11 = const(float, 0.5) :: float64
$50call_method.12 = call $40load_method.7($46binary_subscr.10, $const48.11, func=$40load_method.7, args=[Var($46binary_subscr.10, repro.py:13), Var($const48.11, repro.py:13)], kws=(), vararg=None) :: (float64, float64) -> float64
a = $50call_method.12 :: float64
What happened: the constant 0.5 is typed as float64 ($const48.11 = const(float, 0.5) :: float64), and the typing of $50call_method.12 is such that the version that accepts two float64 operands and returns a float64 is selected (:: (float64, float64) -> float64). Therefore, the computation is done using float64, and the type of a is also float64 (a = $50call_method.12 :: float64).
Typing of arithmetic operations
Some arithmetic operations on float32 and other operands can also result in float64 results. In general I think most operations involving only float32 should also produce float32, but operations combining integers and float32 can result in float64. Suppose I add a line to the example (with int32 also imported from numba):
s = s + int32(1)
Then the IR for this is:
$40load_global.7 = global(int32: int32) :: class(int32)
$const42.8 = const(int, 1) :: Literal[int](1)
$44call_function.9 = call $40load_global.7($const42.8, func=$40load_global.7, args=[Var($const42.8, reproint.py:13)], kws=(), vararg=None) :: (int64,) -> int32
$46binary_add.10 = s + $44call_function.9 :: float64
s.1 = $46binary_add.10 :: float64
The addition of a float32 (s) and int32 ($44call_function.9) results in a float64 ($46binary_add.10 = s + $44call_function.9 :: float64). This is a policy decision embedded in Numba - a float32 only has about 7 significant figures of precision, which is not enough to represent values in the range of an int32, so the operation is computed with float64, with about 16 significant figures.
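The precision argument can be checked without Numba. The following is a minimal sketch using only the Python standard library; to_float32 is a helper name made up for this illustration, and it only demonstrates the representability limit, not Numba's typing rules themselves:

```python
import struct


def to_float32(value):
    """Round a number to the nearest float32 and return it as a Python float."""
    return struct.unpack("f", struct.pack("f", float(value)))[0]


# float32 has a 24-bit significand, so not every integer above 2**24 is representable.
exact = 2**24        # 16777216, representable in float32
inexact = 2**24 + 1  # 16777217, well within the int32 range, not representable

print(to_float32(exact) == exact)      # True
print(to_float32(inexact) == inexact)  # False: the value rounds to 16777216.0

# float64 has a 53-bit significand, so every int32 value survives the conversion.
int32_max = 2**31 - 1
print(float(int32_max) == int32_max)       # True
print(to_float32(int32_max) == int32_max)  # False
```

Any int32 value fits exactly in a float64, which is why widening to float64 loses nothing, whereas staying in float32 would silently round large integer operands.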
General strategy
The introduction of float64 early on in a function tends to propagate all the way through it, causing the majority of operations to be computed with double precision arithmetic. In general my strategy for dealing with this is to:
1. Make sure all constants are explicitly typed - float32 where possible.
2. Inspect the typing using the Numba CLI tool.
3. Make modifications and add casts to float32 where possible when the IR indicates that float64 values are being produced.
4. Repeat steps 2-3 as necessary, until there are no more float64 variables, or until it’s not possible to get rid of any more.
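The inspection step can be partly automated. The following is a minimal sketch and not part of Numba: find_float64_lines is a hypothetical helper that scans the text produced by numba --annotate for IR lines whose result type is float64. It assumes the " :: type" annotation format shown in this post, which may differ between Numba versions.

```python
def find_float64_lines(annotation_text):
    """Return (line_number, line) pairs for IR lines whose result type is float64.

    Hypothetical helper: assumes the " :: <type>" suffix format shown in this post.
    """
    hits = []
    for number, line in enumerate(annotation_text.splitlines(), start=1):
        _, sep, annotation = line.rpartition(" :: ")
        if not sep:
            continue  # source line or comment without a type annotation
        if "= const(" in line:
            continue  # constants are caught where they are used instead
        # For calls the annotation is "(args) -> result"; otherwise it is the type itself.
        result_type = annotation.rpartition("->")[2].strip()
        if result_type == "float64":
            hits.append((number, line.strip()))
    return hits


# IR lines taken from the math.atan2(x[i], 0.5) example above (call arguments shortened).
sample = "\n".join([
    "$46binary_subscr.10 = getitem(value=x, index=i) :: float32",
    "$const48.11 = const(float, 0.5) :: float64",
    "$50call_method.12 = call $40load_method.7(...) :: (float64, float64) -> float64",
    "a = $50call_method.12 :: float64",
])

for number, line in find_float64_lines(sample):
    print(number, line)
```

Constants are skipped on purpose: a constant wrapped in float32(...) still appears in the IR as a float64 const followed by a cast, so reporting it would be noise, while an uncast constant is flagged at the operation that consumes it.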
For floating point arithmetic, I can usually ensure that all arithmetic happens in single precision. For integer arithmetic (not really covered in this post or the question, but see Numba Enhancement Proposal 1 for more details), it is harder to ensure that all operations are done on 32-bit operands - sometimes it is possible, but integer arithmetic is very prone to widening in Numba. A future addition (currently a “nice-to-have” / “wishlist” item) would provide an alternative typing mode that tries to keep operand widths narrow, but the work to do this is presently unscheduled.
Full annotated IR dump
This is the output of numba --annotate repro.py for the example code above, but using the --annotate-html option is generally preferable as you can explore the IR line-by-line more easily in the browser.
-----------------------------------ANNOTATION-----------------------------------
# File: repro.py
# --- LINE 7 ---
@cuda.jit
# --- LINE 8 ---
def f(x, xc):
# --- LINE 9 ---
# label 0
# x = arg(0, name=x) :: array(float32, 1d, C)
# xc = arg(1, name=xc) :: array(complex64, 1d, C)
# $2load_global.0 = global(cuda: <module 'numba.cuda' from '/home/gmarkall/numbadev/numba/numba/cuda/__init__.py'>) :: Module(<module 'numba.cuda' from '/home/gmarkall/numbadev/numba/numba/cuda/__init__.py'>)
# $4load_method.1 = getattr(value=$2load_global.0, attr=grid) :: Function(<class 'numba.cuda.stubs.grid'>)
# del $2load_global.0
# $const6.2 = const(int, 1) :: Literal[int](1)
# $8call_method.3 = call $4load_method.1($const6.2, func=$4load_method.1, args=[Var($const6.2, repro.py:9)], kws=(), vararg=None) :: (int32,) -> int32
# del $const6.2
# del $4load_method.1
# i = $8call_method.3 :: int32
# del $8call_method.3
i = cuda.grid(1)
# --- LINE 10 ---
# --- LINE 11 ---
# $14load_global.5 = global(len: <built-in function len>) :: Function(<built-in function len>)
# $18call_function.7 = call $14load_global.5(x, func=$14load_global.5, args=[Var(x, repro.py:9)], kws=(), vararg=None) :: (array(float32, 1d, C),) -> int64
# del $14load_global.5
# $20compare_op.8 = i < $18call_function.7 :: bool
# del $18call_function.7
# bool22 = global(bool: <class 'bool'>) :: Function(<class 'bool'>)
# $22pred = call bool22($20compare_op.8, func=bool22, args=(Var($20compare_op.8, repro.py:11),), kws=(), vararg=None) :: (bool,) -> bool
# del bool22
# del $20compare_op.8
# branch $22pred, 24, 96
if i < len(x):
# --- LINE 12 ---
# label 24
# del $22pred
# $24load_global.0 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
# $26load_method.1 = getattr(value=$24load_global.0, attr=sin) :: Function(<built-in function sin>)
# del $24load_global.0
# $32binary_subscr.4 = getitem(value=x, index=i) :: float32
# $34call_method.5 = call $26load_method.1($32binary_subscr.4, func=$26load_method.1, args=[Var($32binary_subscr.4, repro.py:12)], kws=(), vararg=None) :: (float32,) -> float32
# del $32binary_subscr.4
# del $26load_method.1
# s = $34call_method.5 :: float32
# del $34call_method.5
s = math.sin(x[i])
# --- LINE 13 ---
# $38load_global.6 = global(math: <module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'math' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/math.cpython-38-x86_64-linux-gnu.so'>)
# $40load_method.7 = getattr(value=$38load_global.6, attr=atan2) :: Function(<built-in function atan2>)
# del $38load_global.6
# $46binary_subscr.10 = getitem(value=x, index=i) :: float32
# $48load_global.11 = global(float32: float32) :: class(float32)
# $const50.12 = const(float, 0.5) :: float64
# $52call_function.13 = call $48load_global.11($const50.12, func=$48load_global.11, args=[Var($const50.12, repro.py:13)], kws=(), vararg=None) :: (float64,) -> float32
# del $const50.12
# del $48load_global.11
# $54call_method.14 = call $40load_method.7($46binary_subscr.10, $52call_function.13, func=$40load_method.7, args=[Var($46binary_subscr.10, repro.py:13), Var($52call_function.13, repro.py:13)], kws=(), vararg=None) :: (float32, float32) -> float32
# del $52call_function.13
# del $46binary_subscr.10
# del $40load_method.7
# a = $54call_method.14 :: float32
# del $54call_method.14
a = math.atan2(x[i], float32(0.5))
# --- LINE 14 ---
# $62binary_multiply.17 = s * a :: float32
# del a
# x[i] = $62binary_multiply.17 :: (array(float32, 1d, C), int64, float32) -> none
# del $62binary_multiply.17
x[i] = s * a
# --- LINE 15 ---
# --- LINE 16 ---
# $70load_global.20 = global(cmath: <module 'cmath' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/cmath.cpython-38-x86_64-linux-gnu.so'>) :: Module(<module 'cmath' from '/home/gmarkall/miniconda3/envs/numba/lib/python3.8/lib-dynload/cmath.cpython-38-x86_64-linux-gnu.so'>)
# $72load_method.21 = getattr(value=$70load_global.20, attr=exp) :: Function(<built-in function exp>)
# del $70load_global.20
# $78binary_subscr.24 = getitem(value=x, index=i) :: float32
# del x
# $80call_method.25 = call $72load_method.21($78binary_subscr.24, func=$72load_method.21, args=[Var($78binary_subscr.24, repro.py:16)], kws=(), vararg=None) :: (complex64,) -> complex64
# del $78binary_subscr.24
# del $72load_method.21
# e = $80call_method.25 :: complex64
# del $80call_method.25
e = cmath.exp(x[i])
# --- LINE 17 ---
# $88binary_multiply.28 = e * s :: complex64
# del s
# del e
# xc[i] = $88binary_multiply.28 :: (array(complex64, 1d, C), int64, complex64) -> none
# del xc
# del i
# del $88binary_multiply.28
# jump 96
# label 96
# del xc
# del x
# del i
# del $22pred
# $const96.0 = const(NoneType, None) :: none
# $98return_value.1 = cast(value=$const96.0) :: none
# del $const96.0
# return $98return_value.1
xc[i] = e * s