CUDA: how to run kernels concurrently using multiprocessing?

Hi!
Can anyone suggest how to run async kernels concurrently using multiprocessing?

Based on the CUDA docs, it says:
> A kernel from one CUDA context cannot execute concurrently with a kernel from another CUDA context.

It seems I need to import the same CUDA context into all processes, but I'm really stuck…

Any help much appreciated.

This sounds like quite out-of-date documentation - is this in the Numba documentation?

Kernels can run concurrently in different streams in the same context, or in different contexts - this was a limitation only in very early versions of CUDA.

https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#asynchronous-concurrent-execution
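
For example, here's a minimal sketch of what I mean (the kernel, array sizes, and launch configuration are purely illustrative) - two launches on two different streams in the same context are allowed to overlap on the device, resources permitting:

import numpy as np
from numba import cuda


@cuda.jit
def add_one(arr):
    # Illustrative kernel: add one to every element
    i = cuda.grid(1)
    if i < arr.shape[0]:
        arr[i] += 1.0


n = 1 << 20
blocks = (n + 255) // 256

s1, s2 = cuda.stream(), cuda.stream()
a = cuda.to_device(np.zeros(n), stream=s1)
b = cuda.to_device(np.zeros(n), stream=s2)

add_one[blocks, 256, s1](a)  # launched on stream s1
add_one[blocks, 256, s2](b)  # launched on stream s2 - may overlap with s1
cuda.synchronize()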

My device is CC 8.6, but no matter what I do I can't make the kernels run concurrently, so I thought different contexts were the main bottleneck.

That does surprise me - I’m enquiring as to the interpretation of this sentence and will get back to you about that.

In the meantime can you post an example code illustrating the pattern you’re using to try to execute kernels concurrently, so we can see what might be blocking concurrent execution?

OK, it looks like I had a long-standing misunderstanding about the nature of concurrent execution from different processes - kernels from multiple contexts / processes won't overlap, they will only be interleaved.

Can your application use CUDA streams to overlap kernel execution within one process instead?

Unfortunately no… And we are talking about the Ampere architecture, which was obviously made to process really huge amounts of data. A single process leads to a CPU bottleneck, unfortunately.

Can you suggest any workaround?

Perhaps it would work if I found a way to fork the same context into all subprocesses.

You may be able to use MPS for this: Multi-Process Service :: GPU Deployment and Management Documentation
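
As a rough sketch (not taken from the MPS documentation), assuming the MPS control daemon has already been started separately (e.g. with nvidia-cuda-mps-control -d), each process just uses CUDA as normal and MPS then allows kernels from the separate processes to overlap:

import multiprocessing as mp

from numba import cuda


@cuda.jit
def spin(arr, n):
    # Illustrative long-running kernel
    for _ in range(n):
        arr[0] += 1.0


def worker(proc_id):
    arr = cuda.device_array(1)
    spin[16, 16](arr, 10_000_000)
    cuda.synchronize()
    print('PROCESS FINISHED: %s' % proc_id)


if __name__ == '__main__':
    # Use 'spawn' so each child initializes its own CUDA context cleanly
    ctx = mp.get_context('spawn')
    procs = [ctx.Process(target=worker, args=(i,)) for i in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()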

I find MPS quite buggy - it often freezes threads for no reason and there is no easy way to debug it.
Being able to export/import the GPU context directly in Numba would really help and avoid the use of MPS.

For another workaround, can you write your code such that it uses threads instead of processes? (It’s a bit hard to really give good concrete suggestions with such a general description of your application)

@gmarkall

I switched to threads; however, it's clear that none of the threads can submit a kernel into its own stream unless there is already a running kernel… So no concurrency at all.
Please take a look at this super basic code:

from numba import cuda
import threading

@cuda.jit(cache=True)
def longrun_kernel(arr):
    for i in range(999999999):
        arr[0][0] += 1.


def worker(thread_id):
    stream = cuda.stream()
    arr = cuda.device_array((10000, 10), stream=stream)
    longrun_kernel[16, 16, stream](arr)
    print('KERNEL SUBMITTED: %s' % thread_id)


threads = []
for i in range(4):
    t = threading.Thread(target=worker, args=(i,))
    t.start()
    threads.append(t)
#
for t in threads:
    t.join()

I played with NUMBA_CUDA_ARRAY_INTERFACE_SYNC and CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM, but it didn't help at all.
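
(For reference, I was setting them roughly like this, before numba is imported - the exact values shown are just illustrative:)

import os

# Set before numba is imported so the settings are picked up at import time
os.environ['NUMBA_CUDA_ARRAY_INTERFACE_SYNC'] = '0'
os.environ['CUDA_PYTHON_CUDA_PER_THREAD_DEFAULT_STREAM'] = '1'

from numba import cuda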

Please suggest.

(For the benefit of other readers, the code in @gizmo’s previous post gives a traceback that culminates in numba.cuda.cudadrv.driver.CudaAPIError: [3] Call to cuCtxGetCurrent results in CUDA_ERROR_NOT_INITIALIZED)

It seems that there’s an implicit requirement in Numba to initialize CUDA in the main thread first (this is probably a bug in Numba, but it likely hasn’t been spotted because most code will do something with CUDA on the main thread before doing anything from another thread). If I create a device array on the main thread first to force initialization of CUDA, then the following code completes successfully:

from numba import cuda
import threading

# Example output:
#
# KERNEL SUBMITTED: 0
# KERNEL FINISHED: 0
# KERNEL SUBMITTED: 3
# KERNEL SUBMITTED: 2
# KERNEL SUBMITTED: 1
# KERNEL FINISHED: 3
# KERNEL FINISHED: 2
# KERNEL FINISHED: 1


@cuda.jit(cache=True)
def longrun_kernel(arr):
    for i in range(99999999):
        arr[0][0] += 1.


def worker(thread_id):
    stream = cuda.stream()
    arr = cuda.device_array((10000, 10), stream=stream)
    longrun_kernel[16, 16, stream](arr)
    print('KERNEL SUBMITTED: %s' % thread_id)
    stream.synchronize()
    print('KERNEL FINISHED: %s' % thread_id)


# Force initialization of CUDA before using threads
x = cuda.device_array(1)

threads = []
for i in range(4):
    t = threading.Thread(target=worker, args=(i,))
    t.start()
    threads.append(t)


for t in threads:
    t.join()

I do find that the interleaving is not always perfect (see example output) and I’m not sure to what extent this is due to the granularity of CPU scheduling - you might find the interleaving works better in an actual application with a greater workload.

I also tried doing the same with streams in a single thread, where it is easier to ensure interleaving:

from numba import cuda

# Example output:
#
# KERNEL SUBMITTED: <CUDA stream 94190282511440 on <CUDA context c_void_p(94190276040272) of device 0>>
# KERNEL SUBMITTED: <CUDA stream 94190281519744 on <CUDA context c_void_p(94190276040272) of device 0>>
# KERNEL SUBMITTED: <CUDA stream 94190280466320 on <CUDA context c_void_p(94190276040272) of device 0>>
# KERNEL SUBMITTED: <CUDA stream 94190283213168 on <CUDA context c_void_p(94190276040272) of device 0>>
# KERNEL FINISHED: <CUDA stream 94190282511440 on <CUDA context c_void_p(94190276040272) of device 0>>
# KERNEL FINISHED: <CUDA stream 94190281519744 on <CUDA context c_void_p(94190276040272) of device 0>>
# KERNEL FINISHED: <CUDA stream 94190280466320 on <CUDA context c_void_p(94190276040272) of device 0>>
# KERNEL FINISHED: <CUDA stream 94190283213168 on <CUDA context c_void_p(94190276040272) of device 0>>


@cuda.jit(cache=True)
def longrun_kernel():
    cuda.nanosleep(100000000000000)


def done_callback(stream, status, arg):
    print('KERNEL FINISHED: %s' % stream)


N = 4
streams = [cuda.stream() for _ in range(N)]

for stream in streams:
    longrun_kernel[1, 1, stream]()
    # Call the `done_callback()` function once the kernel in this stream
    # is finished
    stream.add_callback(done_callback, None)
    print('KERNEL SUBMITTED: %s' % stream)

cuda.synchronize()

If you can distribute your CPU work across threads and use a single thread with multiple streams for launching GPU work, that might give better control over interleaving.
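
Something like this, perhaps (a rough sketch - the queue-based handoff and all the names are just my own illustration, not a specific recommendation): CPU threads prepare work and put it on a queue, and the main thread drains the queue and launches each kernel on its own stream so that the GPU work can overlap:

import queue
import threading

import numpy as np
from numba import cuda


@cuda.jit
def process(arr):
    # Illustrative kernel standing in for the real GPU work
    i = cuda.grid(1)
    if i < arr.shape[0]:
        arr[i] += 1.0


def cpu_worker(work_queue):
    # Stand-in for the real CPU-side preparation work
    data = np.random.rand(1_000_000)
    work_queue.put(data)


N_WORKERS = 8
N_STREAMS = 4

work_queue = queue.Queue()
workers = [threading.Thread(target=cpu_worker, args=(work_queue,))
           for _ in range(N_WORKERS)]
for w in workers:
    w.start()

# Launch all GPU work from the main thread, round-robin across streams
streams = [cuda.stream() for _ in range(N_STREAMS)]
for n in range(N_WORKERS):
    data = work_queue.get()
    stream = streams[n % N_STREAMS]
    d_arr = cuda.to_device(data, stream=stream)
    blocks = (data.shape[0] + 255) // 256
    process[blocks, 256, stream](d_arr)

cuda.synchronize()

for w in workers:
    w.join()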