Random array generation: Numba CUDA slower than CuPy?

Hi all,
I am looking to optimize the random number generation in my Brownian dynamics simulation code.
I quickly turned to GPU computing since my code is highly parallelizable.
At the moment I get good performance by generating the random numbers with CuPy and then using Numba to handle the boundary conditions (among other things). I would like to homogenize everything and do the random number generation with Numba as well, but I can’t reach the same performance. Does anyone have an idea?

import numpy as np
import cupy as cp
from numba import cuda
from numba.cuda.random import create_xoroshiro128p_states, xoroshiro128p_normal_float32
size = (1024, 2, 1000)

# With numpy
%timeit np.random.normal(size=size)
# 59.1 ms ± 1.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# With cupy
%timeit cp.random.normal(size=size, dtype=cp.float32); cp.cuda.Device().synchronize()
# 48.4 µs ± 112 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

# With numba.cuda
@cuda.jit
def numba_cuda_normal(rng_states, out):
    pos = cuda.grid(1)
    for i in range(out.shape[2]):
        out[pos, 0, i] = xoroshiro128p_normal_float32(rng_states, pos)
        out[pos, 1, i] = xoroshiro128p_normal_float32(rng_states, pos)

threads_per_block = 8 # Best performance reached with threads_per_block=8 on an RTX 3080
blocks = size[0]//threads_per_block
rng_states = create_xoroshiro128p_states(threads_per_block * blocks, seed=1)
out = np.zeros(size, dtype=np.float32)
out_gpu = cuda.to_device(out)
numba_cuda_normal[blocks, threads_per_block](rng_states, out_gpu) # warmup
%timeit numba_cuda_normal[blocks, threads_per_block](rng_states, out_gpu); cuda.synchronize()
# 701 µs ± 869 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Thank you in advance 🙂
G.M

Hi Geoffrey,

In general I’d suggest that a combination of CuPy and Numba is the best route to go - they each have their strengths - in particular, CuPy’s arrays support a lot more NumPy operations than Numba’s device arrays. So I’d tend to recommend using CuPy arrays and array operations, and then use Numba kernels when you need to do something more “custom”.
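For example, a minimal sketch of that split could look like the following - the clamp_to_box kernel and the box bounds here are made up for illustration, they're not from your code:

import cupy as cp
from numba import cuda

# Generate the random displacements with CuPy (cuRAND under the hood)...
noise = cp.random.normal(size=(1024, 2, 1000), dtype=cp.float32)

# ...then handle the "custom" part (e.g. boundary conditions) with a Numba kernel.
# The CuPy array can be passed straight to the kernel - no conversion needed.
@cuda.jit
def clamp_to_box(x, lo, hi):
    i = cuda.grid(1)
    if i < x.shape[0]:
        for j in range(x.shape[1]):
            for k in range(x.shape[2]):
                x[i, j, k] = min(max(x[i, j, k], lo), hi)

threads = 128
blocks = (noise.shape[0] + threads - 1) // threads
clamp_to_box[blocks, threads](noise, -1.0, 1.0)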

That said, there are a number of improvements that can be made to the Numba CUDA version here:

  • The distribution of work to threads is less than ideal - rather than having each thread iterate over two dimensions of the output array, we can use more threads to cover the last dimension.
  • The number of threads per block is very small - at least two or three warps (64-96 threads) per block is better for utilization. With only eight threads per block, 24 threads of every warp are always idle, which wastes a significant amount of the GPU's capacity.
  • The timing might include synchronization at each iteration - in order to time more carefully, we should launch all iterations then synchronize only once.

The notebook Discourse 815.ipynb in the gmarkall/numba-discourse-815 repository on GitHub improves upon the performance. The changes are:

  • The kernel uses a 2D grid instead of a 1D grid, with the for loop over i distributed across threads instead of being iterated by one thread.
  • 256 threads per block are used, so there are plenty of warps.
  • The grid size calculation is modified to ensure that enough threads are launched for the size of the output array.
  • Timing is done using a loop followed by cuda.synchronize() surrounded by calls to perf_counter(), with the time computed from the start and end times.

The results I get on my machine are:

  • NumPy: 62.8 ms
  • CuPy: 135 µs
  • Original Numba CUDA: 745 µs
  • Modified Numba CUDA: 118 µs

Modified version follows:

# Modified version for numba.cuda
from time import perf_counter  # needed for the manual timing below

# This kernel maps thread index x to shape[0] and y to shape[2], so that more
# threads can be launched in parallel. The loop over shape[2] is replaced with
# the second thread index
@cuda.jit
def numba_cuda_normal_2(rng_states, out):
    pos, i = cuda.grid(2)
    
    # Ensure our thread is within the bounds of the array
    if pos < out.shape[0] and i < out.shape[2]:
        out[pos, 0, i] = xoroshiro128p_normal_float32(rng_states, pos + out.shape[0] * i)
        out[pos, 1, i] = xoroshiro128p_normal_float32(rng_states, pos + out.shape[0] * i)

# 256 threads / 8 warps per block - a reasonable, slightly arbitrary choice
threads_per_block = (16, 16) 

# Launch enough blocks for all data points.
# Sometimes this will launch slightly more blocks than needed -
# the calculation could be improved slightly
blocks = ((size[0] // threads_per_block[0]) + 1, (size[2] // threads_per_block[1]) + 1)

# RNG state initialization
rng_states = create_xoroshiro128p_states(size[0] * size[2], seed=1)

# Create output array on GPU and warm up JIT
out = np.zeros(size, dtype=np.float32)
out_gpu = cuda.to_device(out)
numba_cuda_normal_2[blocks, threads_per_block](rng_states, out_gpu) # warmup

# How many iterations to loop through when timing?
N_ITERATIONS = 10000

# Timing: we launch all the kernels, then synchronize only once before recording
# the end time - this way we avoid including one synchronization per iteration
# in the measurement.

start = perf_counter()

for i in range(N_ITERATIONS):
    numba_cuda_normal_2[blocks, threads_per_block](rng_states, out_gpu)

cuda.synchronize()
end = perf_counter()


# Iteration time is total time divided by the number of iterations.
# perf_counter() returns seconds, so multiply to get microseconds.
iteration_time = ((end - start) / N_ITERATIONS) * 1_000_000

print(f"Time is {iteration_time} µs per iteration")

Thank you for this very complete answer!
Indeed, in the example I gave, the dimension that the kernel loops over is parallelizable.
Surprisingly, on my GTX 1070 Ti, when I run the notebook you created, I get the following results:

  • NumPy: 61.7 ms
  • CuPy: 98.6 µs
  • Original Numba CUDA: 1.13 ms
  • Modified Numba CUDA: 309.13 µs

But if we really want to compare the Numba results with the CuPy results, we should put the synchronization step inside the loop (sketched after the result below). In this case:

  • Modified Numba CUDA (synchronization inside the loop): 473.13 µs
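
For reference, the in-loop variant I timed is roughly the following sketch, reusing the kernel, launch configuration and N_ITERATIONS from your notebook:

start = perf_counter()

for i in range(N_ITERATIONS):
    numba_cuda_normal_2[blocks, threads_per_block](rng_states, out_gpu)
    cuda.synchronize()  # synchronize after every launch, so each iteration's full cost is counted

end = perf_counter()

iteration_time = ((end - start) / N_ITERATIONS) * 1_000_000
print(f"Time is {iteration_time} µs per iteration")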

In both cases there is still a gap between numba.cuda and CuPy. Maybe this is because Numba does not use the cuRAND library?

Anyway, your answer taught me a lot, and on my side I realized that what was slowing down my program was the step of storing data in device memory between CuPy and numba.cuda, not the random number generation itself.

Surprisingly, on my GTX 1070 Ti, when I run the notebook you created, I get the following results:

I have an RTX 8000 and might have had different versions of CuPy / Numba, etc. - a lot of things can influence the times - there might be some modifications to the Numba kernel that would make it perform relatively better on your setup than on mine.

But if we really want to compare the Numba results with the CuPy results, we should put the synchronization step inside the loop.

If you’re trying to understand what the performance would be in the context of a larger program, it doesn’t make sense to synchronize at every iteration - you end up timing the synchronization more than anything else for a microbenchmark like this.

A normal program that invokes many kernels sequentially would do so without a sync, so the best way to understand what performance you’re likely to get in the context of a larger program is to not sync at each iteration.

Out of interest, does your CuPy time change if you explicitly synchronize? (with cupy.cuda.stream.get_current_stream().synchronize())
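i.e. something like this variant of your original %timeit line (assuming size is still defined as before):

%timeit cp.random.normal(size=size, dtype=cp.float32); cp.cuda.stream.get_current_stream().synchronize()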

In both cases there is still a gap between numba.cuda and CuPy. Maybe this is because Numba does not use the cuRAND library?

That will be one of the reasons, yes.

I realized that what was slowing down my program was the step of storing data in device memory between CuPy and numba.cuda, not the random number generation itself.

I’m not sure what you mean by this - you can pass CuPy arrays directly to Numba kernels with no copying, because of the CUDA Array Interface.
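
For example, something along these lines works with no copying at all (add_one is just a made-up illustrative kernel):

import cupy as cp
from numba import cuda

@cuda.jit
def add_one(arr):
    i = cuda.grid(1)
    if i < arr.shape[0]:
        arr[i] += 1.0

a = cp.zeros(1024, dtype=cp.float32)  # CuPy device array
add_one[8, 128](a)                    # passed directly to the Numba kernel via the CUDA Array Interface
print(a.sum())                        # CuPy still owns the same memory; prints 1024.0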