Calling functions from a function list/array inside a numba cuda kernel

I have a problem that needs to be addressed by calling a huge number of different functions, and eventually even combinations of them.

I have quite some experience writing numba cuda kernels, and I want to solve this problem with that approach, because in the final stage of this project it needs to handle billions of combinations of function calls.

As an example:

Each thread of the kernel needs to call a different function, depending on the nonce.
(Eventually, the nonce will be split with modulus logic into different combination patterns of function calls to reproduce quite long sequential call chains, but that is not part of this post and I can solve it without any problem.)

Here is a working example
(please ignore what the functions do; they’re just dummy examples):

import numpy as np
from numba import cuda, jit

@jit(nopython=True)
def exampleFunc_1(t_in):
    return t_in

@jit(nopython=True)
def exampleFunc_2(t_in):
    return 2*t_in

@cuda.jit
def do_simulation(ret_array):

    tx = cuda.threadIdx.x
    ty = cuda.blockIdx.x
    bw = cuda.blockDim.x
    nonce = ((tx + ty * bw) + 0) & 0xFFFFFFFF

    dummy_value = 11

    if nonce == 0:
        ret_array[nonce] = exampleFunc_1(dummy_value)
    if nonce == 1:
        ret_array[nonce] = exampleFunc_2(dummy_value)

    return 


if __name__ == "__main__":

    threadsperblock = 2 
    blockspergrid = 1

    ret_array = np.zeros((threadsperblock * blockspergrid * 1), dtype=np.float32)
    do_simulation[blockspergrid, threadsperblock](ret_array)
    print(ret_array)

As I don’t need to handle only two functions but many, the “if nonce == xy” approach is very cumbersome, results in unbelievably long kernel code, and gets even more (too) complex when I later introduce the combinations of functions in the same thread.
The complexity is so large that with this approach I would have to write Python code that automatically generates the cuda kernel code for me (ending up at hundreds, if not thousands, of lines of “if/then” statements).

So what I really would like to use is something like the following:

import numpy as np
from numba import cuda, jit

@jit(nopython=True)
def exampleFunc_1(t_in):
    return t_in

@jit(nopython=True)
def exampleFunc_2(t_in):
    return 2*t_in

fn_list = [exampleFunc_1, exampleFunc_2]

@cuda.jit
def do_simulation(ret_array):

    tx = cuda.threadIdx.x
    ty = cuda.blockIdx.x
    bw = cuda.blockDim.x
    nonce = ((tx + ty * bw) + 0) & 0xFFFFFFFF

    dummy_value = 11

    # if nonce == 0:
    #     ret_array[nonce] = exampleFunc_1(dummy_value)
    # if nonce == 1:
    #     ret_array[nonce] = exampleFunc_2(dummy_value)

    ret_array[nonce] = fn_list[nonce](dummy_value)

    return 


if __name__ == "__main__":

    threadsperblock = 2 
    blockspergrid = 1

    ret_array = np.zeros((threadsperblock * blockspergrid * 1), dtype=np.float32)
    do_simulation[blockspergrid, threadsperblock](ret_array)
    print(ret_array)


This approach would spare me the unholy experience of writing code that writes code,
with hundreds to thousands of “if / then” statements that would be a horror to write and to debug,
even though I’m sure that approach would work.

I’ve tried to use fn_list as a global, to pass it as an argument to do_simulation,
as a list, as an np.array with dtype=object, and even to pass only the function name as a string and then use eval inside the kernel. But none of this is supported.

Does any expert see a clever/elegant way to make this work?
Or another, better approach than writing code that writes my kernel code with thousands of “if/then” statements?

Cheers, Stoney

Unfortunately I can’t really see a way to make this approach work, as the CUDA target doesn’t really have support for passing functions around and looking them up in this way.

The least-unholy way I can see to write this would be to use some kind of string templating to generate the if / else (or match / case) part of your kernel, iterating over each index and function to generate the individual case and call to the implementing function.
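To make the templating idea concrete, here is a minimal sketch (pure Python, so it runs without a GPU; all names are illustrative, not gmarkall’s exact proposal). It generates the if / elif dispatch body as source text from a list of function names and materialises it with exec; in the real setup the generated function would be decorated with @cuda.jit(device=True) and dispatch to device functions:

```python
# Stand-ins for the real device functions (illustrative only).
def exampleFunc_1(t_in):
    return t_in

def exampleFunc_2(t_in):
    return 2 * t_in

fn_names = ["exampleFunc_1", "exampleFunc_2"]

# Build the if/elif dispatch body as source text, one branch per function.
lines = ["def fn_proxy(selector, t_in):"]
for i, name in enumerate(fn_names):
    kw = "if" if i == 0 else "elif"
    lines.append(f"    {kw} selector == {i}:")
    lines.append(f"        return {name}(t_in)")
lines.append("    return 0.0  # fallback for an out-of-range selector")
src = "\n".join(lines)

# Materialise fn_proxy in a namespace that already holds the target functions.
ns = {"exampleFunc_1": exampleFunc_1, "exampleFunc_2": exampleFunc_2}
exec(src, ns)
fn_proxy = ns["fn_proxy"]

print(fn_proxy(0, 11))  # 11
print(fn_proxy(1, 11))  # 22
```

Writing src out to a .py file instead of exec-ing it keeps the generated dispatcher inspectable and debuggable, which matters once it grows to hundreds of branches.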

Thanks @gmarkall

I’ve been able to address this in a more or less “clean” way by putting the functions to call into 9 external modules inside subfolders (one module per category, as I have 9 categories of functions to be combined).
A “compile” script then auto-creates a proxy function for each category, containing the “if / then / else” statements, in the “__init__.py” file of each module.
This way I can consume the functions inside the kernel with a one-liner:

ret_value = fn_proxy(fn_selector, args)

With this approach, my kernel code stays human-readable and lean.
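For reference, a generated category proxy might look roughly like this (a sketch with illustrative names, not the author’s exact generated code; the commented-out decorator shows where @cuda.jit(device=True) would go in the generated “__init__.py”, and the fallback value is an assumption):

```python
# Stand-ins for the real category functions (illustrative only).
def exampleFunc_1(t_in):
    return t_in

def exampleFunc_2(t_in):
    return 2 * t_in

# @cuda.jit(device=True)  # enabled in the generated module
def fn_proxy(fn_selector, t_in):
    # Auto-generated if/elif dispatch, one branch per function in the category.
    if fn_selector == 0:
        return exampleFunc_1(t_in)
    elif fn_selector == 1:
        return exampleFunc_2(t_in)
    return 0.0  # fallback for an unknown selector

print(fn_proxy(1, 11))  # 22
```

Inside the kernel, each thread then derives fn_selector from its nonce and calls the category’s fn_proxy, so the kernel itself never grows with the number of functions.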