Regarding the function you posted (took me a little while to look into): as well as adding the stride loops, you also need to index the out array with i, j, and k. Also, I find that I get better performance if I use x for the innermost loop and z for the outermost loop. So I think your function should look like this:
@cuda.jit
def numba_stride_seg(arr, t1, t2, out):
    x, y, z = cuda.grid(3)
    stride_x, stride_y, stride_z = cuda.gridsize(3)
    for i in range(z, arr.shape[0], stride_z):
        for j in range(y, arr.shape[1], stride_y):
            for k in range(x, arr.shape[2], stride_x):
                value = arr[i, j, k]
                if value > t2:
                    out[i, j, k] = 255
                elif value < t1:
                    out[i, j, k] = 0
                else:
                    out[i, j, k] = 128
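To see why out has to be indexed with i, j, and k rather than the thread's starting position, here is a small CPU-only sketch that mimics the stride loops in plain Python. The function name simulate_stride_loops and the tiny shapes are made up for illustration; it only models the loop structure, not real GPU execution.

```python
from itertools import product


def simulate_stride_loops(shape, grid_size):
    """Count visits per element for CPU-simulated 3D grid-stride loops.

    shape     -- (n0, n1, n2), the array shape
    grid_size -- (stride_x, stride_y, stride_z), standing in for cuda.gridsize(3)
    """
    stride_x, stride_y, stride_z = grid_size
    visits = {}
    # Each (x, y, z) plays the role of one thread's cuda.grid(3) result
    for x, y, z in product(range(stride_x), range(stride_y), range(stride_z)):
        for i in range(z, shape[0], stride_z):          # outermost loop uses z
            for j in range(y, shape[1], stride_y):
                for k in range(x, shape[2], stride_x):  # innermost loop uses x
                    visits[(i, j, k)] = visits.get((i, j, k), 0) + 1
    return visits


# A 5 x 7 x 3 array covered by only 2 x 2 x 2 simulated threads
shape = (5, 7, 3)
visits = simulate_stride_loops(shape, grid_size=(2, 2, 2))
assert len(visits) == 5 * 7 * 3      # every element is reached
assert set(visits.values()) == {1}   # and each one exactly once
```

Each simulated thread touches several (i, j, k) positions, so writing to out at the thread's own x, y, z would only ever fill in the first element it handles.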
A short example that runs both the unstrided and strided versions and compares their output and performance follows:
from numba import cuda
from numba.cuda.random import (create_xoroshiro128p_states,
                               xoroshiro128p_uniform_float32)
from time import perf_counter
import numpy as np
import math
@cuda.jit
def random_3d(arr, rng_states):
    # Per-dimension thread indices and strides
    startx, starty, startz = cuda.grid(3)
    stridex, stridey, stridez = cuda.gridsize(3)
    # Linearized thread index
    tid = (startz * stridey * stridex) + (starty * stridex) + startx
    # Use strided loops over the array to assign a random value to each entry
    for i in range(startz, arr.shape[0], stridez):
        for j in range(starty, arr.shape[1], stridey):
            for k in range(startx, arr.shape[2], stridex):
                arr[i, j, k] = xoroshiro128p_uniform_float32(rng_states, tid)
@cuda.jit
def numba_seg(arr, t1, t2, out):
    x, y, z = cuda.grid(3)
    # Bounds check: threads that fall outside the array do nothing
    if x < arr.shape[0] and y < arr.shape[1] and z < arr.shape[2]:
        value = arr[x, y, z]
        if value > t2:
            out[x, y, z] = 255
        elif value < t1:
            out[x, y, z] = 0
        else:
            out[x, y, z] = 128
@cuda.jit
def numba_stride_seg(arr, t1, t2, out):
    x, y, z = cuda.grid(3)
    stride_x, stride_y, stride_z = cuda.gridsize(3)
    for i in range(z, arr.shape[0], stride_z):
        for j in range(y, arr.shape[1], stride_y):
            for k in range(x, arr.shape[2], stride_x):
                value = arr[i, j, k]
                if value > t2:
                    out[i, j, k] = 255
                elif value < t1:
                    out[i, j, k] = 0
                else:
                    out[i, j, k] = 128
# Array dimensions
X, Y, Z = 701, 900, 719
# Block and grid dimensions
bx, by, bz = 8, 8, 8
# For the one-element-per-thread version
gx, gy, gz = math.ceil(X / bx), math.ceil(Y / by), math.ceil(Z / bz)
# For the stride-loop version
gxs, gys, gzs = 16, 16, 16
# Total number of threads
nthreads = bx * by * bz * gxs * gys * gzs
# Initialize a state for each thread
print("Initializing RNG states")
rng_states = create_xoroshiro128p_states(nthreads, seed=1)
# Generate random numbers
print("Generating random numbers")
arr = cuda.device_array((X, Y, Z), dtype=np.float32)
random_3d[(gxs, gys, gzs), (bx, by, bz)](arr, rng_states)
# Test versions
out = cuda.device_array_like(arr)
out_stride = cuda.device_array_like(arr)
cuda.synchronize()
print("Running one element per thread version")
start = perf_counter()
numba_seg[(gx, gy, gz), (bx, by, bz)](arr, 0.25, 0.75, out)
cuda.synchronize()
end = perf_counter()
one_element_time = end - start
print("Running strided version")
start = perf_counter()
numba_stride_seg[(gxs, gys, gzs), (bx, by, bz)](arr, 0.25, 0.75, out_stride)
cuda.synchronize()
end = perf_counter()
strided_time = end - start
print("Copying results to host")
out_host = out.copy_to_host()
out_stride_host = out_stride.copy_to_host()
print("Sanity checking output")
np.testing.assert_equal(out_host, out_stride_host)
print("Sanity check OK!")
print(f"One element per thread time: {one_element_time}")
print(f"Strided loop time: {strided_time}")
Note that the cuda.synchronize() calls are not necessary for correctness; I just have them there to ensure that I’m timing the whole kernel execution (and not just the launch). With this benchmark I find that the strided version actually has similar performance to (or is maybe a few percent slower than) the one-thread-per-element version. The output for me is as follows:
Initializing RNG states
Generating random numbers
Running one element per thread version
Running strided version
Copying results to host
Sanity checking output
Sanity check OK!
One element per thread time: 0.14170214999467134
Strided loop time: 0.14913468800659757
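One thing to keep in mind with these numbers: a @cuda.jit kernel without an explicit signature is compiled on its first call, so timing the first launch of each kernel includes compilation as well as execution. A possible way to separate the two is to do one untimed warm-up launch first. The helper below is only a sketch; the name time_after_warmup and its parameters are my own invention, not part of Numba.

```python
from time import perf_counter


def time_after_warmup(launch, sync=lambda: None, repeats=3):
    """Best wall-clock time of launch() over `repeats` runs, after one warm-up.

    launch -- zero-argument callable that starts the work (e.g. a kernel launch)
    sync   -- zero-argument callable that waits for the work to finish;
              the default does nothing, which suits ordinary Python functions
    """
    launch()  # warm-up: for a Numba kernel, JIT compilation happens here
    sync()
    best = float("inf")
    for _ in range(repeats):
        start = perf_counter()
        launch()
        sync()  # include the full execution, not just the launch
        best = min(best, perf_counter() - start)
    return best


# Stand-in workload so the sketch runs without a GPU
calls = []
elapsed = time_after_warmup(lambda: calls.append(1), repeats=3)
assert len(calls) == 4  # one warm-up call plus three timed calls
assert elapsed >= 0.0
```

With the benchmark above, this could be called as time_after_warmup(lambda: numba_seg[(gx, gy, gz), (bx, by, bz)](arr, 0.25, 0.75, out), sync=cuda.synchronize), and likewise for numba_stride_seg. I have not rerun the comparison this way, so I can't say whether it changes the relative result.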