In the Numba Gitter I was asked if there was a workaround for Issue #6051, “Cube root intrinsic for Numba” - how can one implement a cube root function that can be called by a CUDA kernel?
It turns out that this makes for a nice example of writing Numba extensions to access libdevice functions. Some of them are already made available by Numba (e.g. through functions like math.cos), but not all of them are. There is a full list of libdevice functions in its documentation.
For implementing a cube root, we’d like to use the __nv_cbrt and __nv_cbrtf functions. The example code demonstrates how to do this:
import numpy as np
from numba import cuda
from numba import types
from numba.core.extending import lower_builtin, type_callable
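# Note: llvmlite.llvmpy has been removed from recent llvmlite releases; with
# those, llvmlite.ir.FunctionType and numba.core.cgutils.get_or_insert_function
# are the documented replacements for the two calls used in impl_cbrt below.
from llvmlite.llvmpy.core import Type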
# A Python "reference" function
def cbrt(x):
    return x ** (1/3)
# Tell Numba how to type a call to the cbrt function:
# - If it's called with a float32 or float64 argument, then return the same
#   type as the argument.
# - Otherwise, return nothing - Numba will not be able to type the function,
#   and report an error to the user.
@type_callable(cbrt)
def type_cbrt(context):
    def typer(val):
        if val in (types.float32, types.float64):
            return val
    return typer
# See:
# https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrt.html#__nv_cbrt
# https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrtf.html#__nv_cbrtf
cbrt_funcs = {
    types.float32: '__nv_cbrtf',
    types.float64: '__nv_cbrt',
}
# Lowering for the implementation of cbrt - this tells Numba how to generate
# code for the cbrt function. We take the argument passed to the cbrt function,
# generate a call to the __nv_cbrt[f] function, and return its return value.
@lower_builtin(cbrt, types.float32)
@lower_builtin(cbrt, types.float64)
def impl_cbrt(context, builder, sig, args):
    ty = sig.return_type
    fname = cbrt_funcs[ty]
    fty = context.get_value_type(ty)
    lmod = builder.module
    fnty = Type.function(fty, [fty])
    fn = lmod.get_or_insert_function(fnty, name=fname)
    return builder.call(fn, args)
# A test kernel that computes the cube root of an array
@cuda.jit
def f(x, r):
    i = cuda.grid(1)
    if i < len(x):
        r[i] = cbrt(x[i])
# Generate some data and launch the kernel to demonstrate its use
arr = np.arange(32).astype(np.float64)
res = np.zeros_like(arr)
print(arr)
f[1, 32](arr, res)
print(res)
# Sanity check - did our kernel match the NumPy cbrt ufunc? Raise if not.
np.testing.assert_allclose(np.cbrt(arr), res)
When run, this outputs:
[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.
18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]
[0. 1. 1.25992105 1.44224957 1.58740105 1.70997595
1.81712059 1.91293118 2. 2.08008382 2.15443469 2.22398009
2.28942849 2.35133469 2.41014226 2.46621207 2.5198421 2.57128159
2.62074139 2.66840165 2.71441762 2.75892418 2.80203933 2.84386698
2.88449914 2.92401774 2.96249607 3. 3.03658897 3.07231683
3.10723251 3.14138065]
and no exception is raised.
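One detail worth keeping in mind: inside the kernel, the call to cbrt is lowered to the libdevice function, so the Python body `x ** (1/3)` is not what runs on the GPU. If the reference function is ever called from ordinary Python, it behaves differently from a true cube root for negative inputs. The sketch below illustrates this with the standard library only; the names `cbrt_reference` and `cbrt_real` are hypothetical and not part of the original code.

```python
import math


def cbrt_reference(x):
    """Same body as the reference function in the listing above."""
    return x ** (1 / 3)


def cbrt_real(x):
    """Real-valued cube root: take the root of |x| and restore the sign."""
    return math.copysign(abs(x) ** (1 / 3), x)


# In Python 3, a negative float raised to a fractional power is complex
print(type(cbrt_reference(-8.0)))

# The sign-preserving variant stays on the real line
print(cbrt_real(-8.0))
```

The test data in the post is non-negative, so this difference does not affect the comparison against np.cbrt shown above.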
The code above can serve as a starting point for writing other Numba extensions that call other libdevice functions.
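When adapting the listing, the part that changes most is the table mapping argument types to libdevice symbol names. As a rough sketch, the hypothetical helper below builds such a table from a base name, assuming the naming pattern seen for cbrt (`__nv_<name>` for double precision, `__nv_<name>f` for single precision). Not every libdevice function follows this pattern, so the names should be checked against the libdevice documentation. Plain strings stand in for the Numba types here so that the snippet runs without Numba installed.

```python
def libdevice_names(base):
    """Map type names to libdevice symbols for a unary floating-point function.

    Hypothetical helper: assumes the float64 symbol is __nv_<base> and the
    float32 symbol is __nv_<base>f, as is the case for cbrt.
    """
    return {
        "float32": f"__nv_{base}f",
        "float64": f"__nv_{base}",
    }


# Reproduces the names used in the cbrt_funcs table of the listing
print(libdevice_names("cbrt"))

# The same pattern applied to the reciprocal square root functions
print(libdevice_names("rsqrt"))
```

In the extension itself, the keys would be types.float32 and types.float64, as in cbrt_funcs, and the typing and lowering functions would be registered against the new Python reference function instead of cbrt.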