CUDA: Explainer of a kernel with 2D blocks, shared memory, atomics

This post is a writeup of a question asked outside of Discourse, which walks through a function line-by-line, and provides some introductory explanation about the grid and indexing, shared memory, and atomic operations.

I was recently asked what the atomic_add3 function does in the Numba testsuite - it is defined in test_atomics.py. The code of this function is:

def atomic_add3(ary):
    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    sm = cuda.shared.array((4, 8), uint32)
    sm[tx, ty] = ary[tx, ty]
    cuda.syncthreads()
    cuda.atomic.add(sm, (tx, uint64(ty)), 1)
    cuda.syncthreads()
    ary[tx, ty] = sm[tx, ty]

Thread indices

The assignments of tx and ty are:

tx = cuda.threadIdx.x
ty = cuda.threadIdx.y

Kernels are launched with a grid of blocks of threads that is specified at launch time. You will see in test_atomic_add3 the call cuda_atomic_add3[1, (4, 8)](ary) - the subscript [1, (4, 8)] specifies a grid of one block, with each block being 4 (x dimension) by 8 (y dimension) threads. Reading cuda.threadIdx.{x,y,z} gives the index of the current thread within its block for the given dimension, so thread (0, 0) gets tx = 0, ty = 0, thread (0, 1) gets tx = 0, ty = 1, and so on. This is idiomatic in CUDA for distributing parallel work between threads - the thread index is usually used to construct indices into arrays so that each thread operates on a different element. (This holds for simple, embarrassingly parallel kernels - you can do more complex things to enable threads to collaborate on a result, share intermediate results, etc.)
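
As a small illustration (a hypothetical example, not taken from the testsuite), the following kernel uses the same [1, (4, 8)] launch configuration and uses the thread indices to give each thread its own element to write:

from numba import cuda
import numpy as np

@cuda.jit
def write_flat_index(out):
    # Each of the 4x8 threads writes its flattened thread index into the
    # element of out corresponding to its position in the block
    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    out[tx, ty] = tx * out.shape[1] + ty

out = np.zeros((4, 8), dtype=np.int32)
write_flat_index[1, (4, 8)](out)   # one block of 4x8 threads
print(out)                         # 0..31 laid out row by row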

Shared memory declaration

sm = cuda.shared.array((4, 8), uint32)

Shared memory can be thought of as a software-controlled cache on the processor - each Streaming Multiprocessor has a small amount of shared memory (e.g. 32, 48, 64, 96, or 128KB, depending on the GPU and its current configuration), and each block can use a chunk of it by declaring a shared memory array. All threads within one block see the same shared memory array, but different blocks do not share it (in other words, shared memory is shared between the threads within a block, not between blocks). The above line declares a 4x8 array of uint32, which can be thought of as mapping 1:1 between threads and shared memory elements (although it needn’t necessarily be used that way).
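
For example (a hypothetical sketch, not from the testsuite), the shape and dtype passed to cuda.shared.array must be compile-time constants, and each block gets its own independent copy of the array:

from numba import cuda, int32
import numpy as np

@cuda.jit
def fill_with_block_id(out):
    # Each block gets its own private 4x8 shared array
    scratch = cuda.shared.array((4, 8), int32)
    tx = cuda.threadIdx.x
    ty = cuda.threadIdx.y
    # Writes here are visible only to threads in the same block
    scratch[tx, ty] = cuda.blockIdx.x
    out[cuda.blockIdx.x, tx, ty] = scratch[tx, ty]

out = np.zeros((2, 4, 8), dtype=np.int32)
fill_with_block_id[2, (4, 8)](out)   # block 0 fills 0s, block 1 fills 1s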

Load into shared memory

Next, data is loaded from global memory into shared memory:

sm[tx, ty] = ary[tx, ty]

We have 4x8 threads loading 4x8 items from global memory into shared memory using their thread indices - each thread loads exactly one element.

Thread synchronization

After the load, we need to synchronize threads within the block:

cuda.syncthreads()

Writes to shared memory by a thread are immediately visible to that same thread. However, writes by a thread A are not immediately visible to another thread B - it can take several cycles before that is the case. Calling syncthreads() places a barrier: no thread in the block proceeds past it until every thread has reached it, and writes made by all threads before the barrier are visible to all other threads in the block after it. (There is also syncwarp() if you only need writes to be visible between threads within a warp rather than within the whole block.)
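
Here is a small sketch (a hypothetical example, not from the testsuite) where the barrier really matters: reversing a 32-element array within a block, where each thread reads an element written by a different thread:

from numba import cuda, float32
import numpy as np

@cuda.jit
def reverse_in_block(arr):
    sm = cuda.shared.array(32, float32)
    tx = cuda.threadIdx.x
    sm[tx] = arr[tx]
    # Without this barrier, sm[31 - tx] might not yet contain the value
    # written by the thread with index 31 - tx
    cuda.syncthreads()
    arr[tx] = sm[31 - tx]

arr = np.arange(32, dtype=np.float32)
reverse_in_block[1, 32](arr)
print(arr)   # 31.0, 30.0, ..., 0.0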

The general shared memory strategy

In the atomic_add3 kernel, the syncthreads() call does not appear to be technically necessary, because each thread only ever reads and writes the (tx, ty) element of sm, so it looks a little odd. But in general the pattern for using shared memory, which the function appears to follow, is:

<declare shared memory>
<threads cooperate to stage data into shared memory>
cuda.syncthreads()
<threads cooperate on computations on data in shared memory>
cuda.syncthreads()
<threads cooperate to write the results back to global memory>

The staging of data into or out of shared memory could also be omitted - in some cases it may only make sense to load input into shared memory but write results straight out to global memory, or to work on intermediate results in shared memory using input read directly from global memory.
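
As a concrete (hypothetical) instance of this pattern, the sketch below stages a 32-element tile per block, computes the tile's mean on the staged data, and writes back the mean-subtracted values - the reduction is done naively by thread 0 just to keep the sketch short:

from numba import cuda, float32
import numpy as np

@cuda.jit
def subtract_block_mean(arr):
    tile = cuda.shared.array(32, float32)
    mean = cuda.shared.array(1, float32)
    tx = cuda.threadIdx.x
    i = cuda.blockIdx.x * cuda.blockDim.x + tx

    # Stage data into shared memory
    tile[tx] = arr[i]
    cuda.syncthreads()

    # Compute on the data in shared memory (naive reduction by one thread)
    if tx == 0:
        total = 0.0
        for j in range(32):
            total += tile[j]
        mean[0] = total / 32.0
    cuda.syncthreads()

    # Write the results back to global memory
    arr[i] = tile[tx] - mean[0]

arr = np.arange(64, dtype=np.float32)
subtract_block_mean[2, 32](arr)   # two blocks of 32 threads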

Atomic addition

Now threads add to their element in shared memory:

cuda.atomic.add(sm, (tx, uint64(ty)), 1)

The atomic add operates on the array sm at element (tx, uint64(ty)) and adds the value 1. It’s unclear to me why ty is cast to uint64 in this test. Technically the atomic wouldn’t be required in this kernel, since each element is only written by one thread, but as this is a test of atomics, it at least tests something about the code generation for atomics - though I think an issue with this test is that it would still pass if atomic addition generated the code for a “normal” (non-atomic) addition.
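
For contrast, here is a (hypothetical) sketch where the atomic is genuinely required - a histogram in which many threads may increment the same bin, so a plain read-modify-write would lose updates:

from numba import cuda
import numpy as np

@cuda.jit
def histogram(values, bins):
    i = cuda.grid(1)
    if i < values.shape[0]:
        # Multiple threads may hit the same bin, so the update must be atomic
        cuda.atomic.add(bins, values[i] % bins.shape[0], 1)

values = np.random.randint(0, 256, size=10000).astype(np.int32)
bins = np.zeros(16, dtype=np.int32)
histogram[40, 256](values, bins)   # 40 blocks of 256 threads covers all values
print(bins.sum())                  # 10000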

Thread synchronization (after computation in shared memory)

Having finished computations on the result in shared memory, we have another syncthreads(), as per the general shared memory strategy described above:

cuda.syncthreads()

Writing results back to global memory

Finally, threads cooperate to write the result back to global memory:

ary[tx, ty] = sm[tx, ty]

Questions

Please do respond to this post if you have any questions or would like further explanations of this or other code!


@gmarkall thanks for sharing this stuff. Apropos of atomics: I was wondering how to use atomic operations with a cpu target instead of cuda. The only reference I found is this gist from @sklam: https://gist.github.com/sklam/40f25167351832fe55b64232785d036d - there’s nothing in the official docs…

I’m not as familiar with the CPU target, but I’d guess that if there are no atomics supported in it, then you’d have to implement an extension using the low-level extension API to build atomicrmw instructions. It looks like llvm_call doesn’t exist in Numba anymore, so the example from @sklam would need adapting to work with current Numba.

For some atomic operations, the atomicrmw instruction could be used - these can be constructed with the llvmlite IR builder’s atomic_rmw method.
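
As a rough illustration of the building block involved (just llvmlite IR construction, not wired into Numba's CPU target), an atomicrmw add could be emitted like this:

from llvmlite import ir

# Build a tiny function: i64 atomic_add_i64(i64* ptr, i64 val) that performs
# an atomic fetch-and-add on *ptr and returns the old value
i64 = ir.IntType(64)
module = ir.Module(name="atomic_example")
fnty = ir.FunctionType(i64, [i64.as_pointer(), i64])
fn = ir.Function(module, fnty, name="atomic_add_i64")
builder = ir.IRBuilder(fn.append_basic_block(name="entry"))
ptr, val = fn.args
old = builder.atomic_rmw("add", ptr, val, ordering="monotonic")
builder.ret(old)
print(module)   # the generated LLVM IR, containing an atomicrmw add instruction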