It is possible to have “const” arguments to kernels. Because a kernel with a const argument contains distinct code for each const value, the way to do this is to define a separate function specialized for the required const values. The const values are captured from the enclosing environment, usually by defining the kernel inside a closure. For example:
```python
from numba import cuda
import numpy as np

# Data size - something reasonable for local array capacity
N = 10


def generate_kernel(index):
    """Generate a kernel specialized for a given index value.

    The N and index values are hardcoded into the kernel as constants, as
    their values are captured from the scope outside the kernel."""

    @cuda.jit
    def kernel(data):
        local_data = cuda.local.array(N, dtype=np.int32)

        # Copy all data into local array
        for i in range(N):
            local_data[i] = data[i]

        # Compute on all elements of local array
        for i in range(N):
            local_data[i] = local_data[i] * 2

        # Print out only the data for the given index. Only this entry of the
        # local array is live - all others are dead
        print(local_data[index])

    return kernel


# Generate two different instances of the kernel and launch them with some data
index_2_kernel = generate_kernel(2)
index_5_kernel = generate_kernel(5)

# np.arange(N) is int64 on most 64-bit platforms, i.e. 8 bytes per element
data = np.arange(N)
index_2_kernel[1, 1](data)
index_5_kernel[1, 1](data)

# Synchronization is needed here because launches are asynchronous - the
# program can exit before the kernel print buffer is flushed, if we don't wait
# for completion of all device operations.
cuda.synchronize()
```
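The capture mechanism that `generate_kernel` relies on is ordinary Python closure behaviour, so it can be sketched without a GPU. In the plain-Python example below, `make_scaler` is a hypothetical stand-in for `generate_kernel`: each call creates a new scope, so each returned function carries its own captured value in its closure cell, and it is these captured values that Numba hardcodes as constants when it compiles the kernel. The sketch also shows the usual pitfall of creating closures directly in a loop, where they all share one variable; calling a factory once per value avoids that.

```python
def make_scaler(factor):
    """Hypothetical factory: returns a function with `factor` captured."""
    def scale(x):
        # `factor` comes from the enclosing scope of this particular call
        return x * factor
    return scale


double = make_scaler(2)
triple = make_scaler(3)
print(double(10), triple(10))  # 20 30

# The captured value lives in the function's closure cell
print(double.__closure__[0].cell_contents)  # 2

# Pitfall: closures created directly in a loop share the loop variable,
# so every function sees its final value (k == 2).
shared = [lambda x: x * k for k in range(3)]
print([f(10) for f in shared])  # [20, 20, 20]

# Calling the factory once per value gives each closure its own binding.
separate = [make_scaler(k) for k in range(3)]
print([f(10) for f in separate])  # [0, 10, 20]
```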
When this is executed with

```shell
NUMBA_DUMP_ASSEMBLY=1 python repro.py
```

we can see that both variants were optimized for their const values at the PTX level: the copies into local memory and the computations on the unaccessed elements have been eliminated, and the load and arithmetic for the live value (element 2 or 5, depending on which version is launched) are done entirely in registers. For index 2, we see:
```
ld.global.u32 %r1, [%rd2+16];
shl.b32 %r2, %r1, 1;
```
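That is, a load of the value at index 2 from a byte offset of 2 * 8 = 16 (each element of `data` is 8 bytes wide; only the low 32 bits are loaded because the value is truncated into an `int32` local array anyway), followed by a shift left by one to multiply by 2. Similarly, for the const index 5 we get: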
```
ld.global.u32 %r1, [%rd2+40];
shl.b32 %r2, %r1, 1;
```
That is, a load from a byte offset of 5 * 8 = 40 from the base of the array.
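The offset arithmetic can be checked on the host. This sketch assumes `data` holds little-endian 64-bit integers, matching the 8-byte stride in the PTX and the byte order of CUDA devices; `simulate_load_and_double` is a hypothetical helper that reads the same 4 bytes an `ld.global.u32` would and doubles them with a 32-bit shift. It illustrates the arithmetic only, not what Numba actually executes.

```python
import numpy as np

N = 10
# Assumption: explicit little-endian int64 ('<i8'), matching the 8-byte
# stride seen in the PTX (np.arange(N) is int64 on most 64-bit platforms).
data = np.arange(N, dtype="<i8")


def simulate_load_and_double(arr, index):
    """Hypothetical helper mimicking `ld.global.u32` then `shl.b32 ..., 1`."""
    offset = index * arr.itemsize  # byte offset from the base of the array
    raw = arr.tobytes()
    # A u32 load reads 4 bytes: the low half of the little-endian int64
    # element, which is all an int32 destination keeps after truncation.
    low32 = int.from_bytes(raw[offset:offset + 4], "little")
    # Shift left by one == multiply by 2, wrapped to 32 bits like shl.b32.
    doubled = (low32 << 1) & 0xFFFFFFFF
    return offset, low32, doubled


for index in (2, 5):
    print(index, simulate_load_and_double(data, index))
# 2 (16, 2, 4)
# 5 (40, 5, 10)
```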
You should also be able to observe that there is no longer any reference to local memory (no `.local` declarations or `ld.local`/`st.local` instructions), and no loop copying data or computing on unused values.
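Checking for local memory can be automated by scanning the dumped PTX text for the `.local` state space. The sketch below is a hypothetical helper, `local_memory_lines`, that returns every line mentioning `.local`; the "unoptimized" PTX snippet is an illustrative example, not real output from the kernel above.

```python
import re

# Matches the PTX `.local` state space in declarations (`.local .align ...`)
# and in instructions such as `ld.local`, `st.local` or `cvta.local`.
LOCAL_MEMORY_RE = re.compile(r"\.local\b")


def local_memory_lines(ptx: str) -> list[str]:
    """Return the (stripped) PTX lines that reference local memory."""
    return [line.strip() for line in ptx.splitlines() if LOCAL_MEMORY_RE.search(line)]


# The optimized PTX for index 2 from the dump above
optimized_ptx = """\
ld.global.u32 %r1, [%rd2+16];
shl.b32 %r2, %r1, 1;
"""

# Illustrative (hypothetical) PTX that still uses a local array
unoptimized_ptx = """\
.local .align 4 .b8 __local_depot0[40];
ld.global.u32 %r1, [%rd2];
st.local.u32 [%rd3], %r1;
"""

print(local_memory_lines(optimized_ptx))    # []
print(local_memory_lines(unoptimized_ptx))  # ['.local .align 4 .b8 __local_depot0[40];', 'st.local.u32 [%rd3], %r1;']
```

The PTX could come from the `NUMBA_DUMP_ASSEMBLY` output, or, assuming a Numba version whose CUDA dispatcher provides `inspect_asm()`, from something like `index_2_kernel.inspect_asm()` after the kernel has been launched, which returns a dict mapping signatures to PTX.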