CUDA: Building extensions to call libdevice functions - cube root example

On the Numba Gitter I was asked whether there is a workaround for Issue #6051, “Cube root intrinsic for Numba”: how can one implement a cube root function that can be called from a CUDA kernel?

It turns out that this makes for a nice example of writing Numba extensions that access libdevice functions. Numba already makes some of them available (e.g. through functions like math.cos), but not all of them. The libdevice documentation contains a full list of its functions.

To implement a cube root, we’d like to use the __nv_cbrt (float64) and __nv_cbrtf (float32) functions. The following example code demonstrates how to do this:

import numpy as np

from numba import cuda

from numba import types
from numba.core import cgutils
from numba.core.extending import lower_builtin, type_callable

from llvmlite import ir


# A Python "reference" function
def cbrt(x):
    return x ** (1/3)


# Tell Numba how to type a call to the cbrt function:
# - If it's called with a float32 or float64 argument, then return the same
#   type as the argument.
# - Otherwise, return nothing - Numba will not be able to type the function,
#   and report an error to the user.
@type_callable(cbrt)
def type_cbrt(context):
    def typer(val):
        if val in (types.float32, types.float64):
            return val
    return typer


# See:
# https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrt.html#__nv_cbrt
# https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrtf.html#__nv_cbrtf
cbrt_funcs = {
    types.float32: '__nv_cbrtf',
    types.float64: '__nv_cbrt',
}


# Lowering for the implementation of cbrt - this tells Numba how to generate
# code for the cbrt function. We take the argument passed to the cbrt function,
# generate a call to the __nv_cbrt[f] function, and return its return value.
@lower_builtin(cbrt, types.float32)
@lower_builtin(cbrt, types.float64)
def impl_cbrt(context, builder, sig, args):
    ty = sig.return_type
    fname = cbrt_funcs[ty]
    fty = context.get_value_type(ty)
    lmod = builder.module
    # Declare the libdevice function in the current LLVM module, or reuse
    # the declaration if it already exists
    fnty = ir.FunctionType(fty, [fty])
    fn = cgutils.get_or_insert_function(lmod, fnty, fname)
    return builder.call(fn, args)


# A test kernel that computes the cube root of an array
@cuda.jit
def f(x, r):
    i = cuda.grid(1)
    if i < len(x):
        r[i] = cbrt(x[i])


# Generate some data and launch the kernel to demonstrate its use
arr = np.arange(32).astype(np.float64)
res = np.zeros_like(arr)
print(arr)

f[1, 32](arr, res)

print(res)

# Sanity check - did our kernel match the NumPy cbrt ufunc? Raise if not.
np.testing.assert_allclose(np.cbrt(arr), res)

When run, this outputs:

[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.]
[0.         1.         1.25992105 1.44224957 1.58740105 1.70997595
 1.81712059 1.91293118 2.         2.08008382 2.15443469 2.22398009
 2.28942849 2.35133469 2.41014226 2.46621207 2.5198421  2.57128159
 2.62074139 2.66840165 2.71441762 2.75892418 2.80203933 2.84386698
 2.88449914 2.92401774 2.96249607 3.         3.03658897 3.07231683
 3.10723251 3.14138065]

and no exception is raised.
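
The launch in the example uses a single block of 32 threads, which exactly covers the 32-element array. For larger arrays the number of blocks has to be rounded up. The following is a minimal sketch of that calculation in plain Python; the helper name launch_config and the default of 128 threads per block are assumptions made for illustration, not part of the original example:

```python
def launch_config(n_elements, threads_per_block=128):
    """Return (blocks, threads_per_block) covering n_elements with a 1D grid."""
    if n_elements <= 0:
        raise ValueError("n_elements must be positive")
    # Round up so that a partially filled final block is still launched
    blocks = (n_elements + threads_per_block - 1) // threads_per_block
    return blocks, threads_per_block


# The 32-element array from the example fits in one block of 32 threads
print(launch_config(32, threads_per_block=32))  # (1, 32)

# A hypothetical array of one million elements needs several blocks
print(launch_config(1_000_000))  # (7813, 128)
```

The kernel would then be launched as f[blocks, threads](arr, res). The `if i < len(x)` check in the kernel is what makes the surplus threads in the last block harmless.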

The code above can be used as a starting point for writing Numba extensions that call other libdevice functions.
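
One thing to keep in mind when adapting the example: as far as I understand, the body of the Python reference function is not executed inside the compiled kernel, because the typing and lowering functions replace it with the libdevice call. It only matters if the function is called from ordinary Python, where x ** (1/3) behaves differently from a true cube root for negative inputs. The sketch below shows the difference in plain Python; cbrt_naive and cbrt_signed are hypothetical names used only for this illustration:

```python
import math


def cbrt_naive(x):
    # Same expression as the reference function in the example
    return x ** (1 / 3)


def cbrt_signed(x):
    # Hypothetical alternative: take the root of |x| and restore the sign,
    # so that negative inputs give a real, negative result
    return math.copysign(abs(x) ** (1 / 3), x)


# In plain Python, a negative float raised to a fractional power
# yields a complex number rather than the real cube root
print(type(cbrt_naive(-8.0)))

# The sign-preserving version stays real
print(cbrt_signed(-8.0))
```

The sanity check in the example only uses non-negative inputs, so it does not exercise this case. If the reference function is also meant to be called from Python, a sign-preserving version like the one above (or np.cbrt) should be closer to what the libdevice function computes.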
