CUDA device function pointers? Re-implementing `scipy.integrate.solve_ivp` for `numba.cuda`

Synopsis

I am considering porting/reimplementing scipy.integrate.solve_ivp, specifically the DOP853 method, to numba-compiled code for CUDA, in such a way that multiple solvers for similar-ish parameters can run in parallel. Just to be clear: I am not even sure this is a good idea to begin with. I suspect that I might run into stack size limits. I’d love to do this using a function decorated by guvectorize as an entry point, but that’s not a critical requirement. More critically, a user of my code must be able to pass more or less arbitrary functions as arguments to my solver. That’s the main problem, besides a few others.

Context

The idea is a result of my work on my NumFOCUS small development grant project “array types for scaling for poliastro”. The original developers of poliastro decided to archive the project in October 2023, unfortunately. In December 2023, I forked it under the name “hapsira” to continue my efforts.

One of the ideas of my work is to allow the user to decide if the package should do its work on the CPU (optionally in parallel) or on a GPU, via a setting / environment variable. Conceptually, I want all code to compile unaltered for both targets. The current mechanisms enabling this madness can be found here. So far, it’s working pretty well.
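To sketch the kind of mechanism I mean (a minimal illustration only - the names MYPKG_TARGET, hjit and twice are made up for this example and are not the actual package’s API): a thin decorator reads the setting and compiles each helper either as a CUDA device function or as a regular nopython function, from the same source.

import os

import numba as nb
from numba import cuda

TARGET = os.environ.get("MYPKG_TARGET", "cpu")  # hypothetical setting: "cpu" or "cuda"

def hjit(signature):
    # compile a helper for the configured target, same source for both
    def decorator(func):
        if TARGET == "cuda":
            return cuda.jit(signature, device = True, inline = True)(func)
        return nb.jit(signature, nopython = True)(func)
    return decorator

@hjit("f8(f8)")
def twice(x):
    return x * 2.0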

The package features a number of orbit propagation algorithms, the most complicated but also most interesting one being Cowell’s method. It is the last and by far most complicated bit of code left on my list for “array-ifying” this package. The algorithm allows perturbations, i.e. effects like atmospheric drag, to be included programmatically. It is essentially built around solve_ivp. Imagine you want to “predict”/“forecast” the orbits of a cloud of space debris: you would need to run solve_ivp once per piece of debris (easily hundreds of thousands of pieces). Imagine a user wants to alter the drag equations (and/or other effects) - this is what the current implementation easily allows by accepting function references/pointers.
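To give a feel for the callback pattern on the CPU side (a minimal, made-up sketch - this is not poliastro’s/hapsira’s actual API, and the numbers are arbitrary): a user-supplied perturbation is simply added to the two-body right-hand side that solve_ivp integrates.

import numpy as np
from scipy.integrate import solve_ivp

def user_drag(t, rv):  # user-supplied perturbation, hypothetical example
    return -1e-6 * rv[3:]  # some acceleration acting on the velocity components

def rhs(t, rv, k, perturbation):
    r, v = rv[:3], rv[3:]
    a = -k * r / np.linalg.norm(r) ** 3 + perturbation(t, rv)  # two-body term plus perturbation
    return np.concatenate((v, a))

rv0 = np.array([7000.0, 0.0, 0.0, 0.0, 7.5, 0.0])  # km, km/s
sol = solve_ivp(rhs, (0.0, 3600.0), rv0, method = "DOP853", args = (398600.4418, user_drag))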

scipy.integrate.solve_ivp is infamous for being awfully slow. I am absolutely not the first poor soul interested in speeding it up. Just google for any combination of solve_ivp and keywords like “numba”, “cuda”, “cupy”, “accelerate”, etc … Besides, poliastro’s developers investigated the adoption of numbakit-ode. numbakit-ode indeed provides a much faster solver accelerated by numba, but conceptually it would also need a ton of re-work to make it compile for CUDA.

Problem: User-provided callback functions

solve_ivp is built around the idea that a user can pass callback functions as arguments - not just one, but actually a “list” of them (“events”), which, for additional fun, can potentially change some internal state of theirs on each call.
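For reference, this is roughly what that looks like on the scipy side (a minimal sketch; terminal and direction are standard event attributes, the mutable dict is only there to illustrate the “internal state” point):

import numpy as np
from scipy.integrate import solve_ivp

def rhs(t, y):
    return -y

calls = {"n": 0}  # mutable state that a callback may touch

def crossing(t, y):
    calls["n"] += 1  # callbacks may mutate state on every evaluation
    return y[0] - 0.5  # a root of this function marks the event

crossing.terminal = False
crossing.direction = -1.0

sol = solve_ivp(rhs, (0.0, 10.0), np.array([1.0]), method = "DOP853", events = [crossing])
print(sol.t_events, calls["n"])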

Function pointers in CUDA appear to be a thing, though I have never used them in C/C++. numba appears to have limitations, see #3983, though I am not sure I understand them correctly. It appears that I can at least pass a numba-compiled device function as an argument to another numba-compiled device function. I have not yet figured out a way to pass a device function as an argument to a function decorated by guvectorize (or vectorize) or directly to a CUDA kernel function (compiled via numba).

If I understand the subject correctly, CUDA device functions are inline-compiled by default. In other words, the function “passed by the user” must be defined before any other function calling it gets compiled. This can be done, but it is causing some funky organizational issues …

Experiment 1, inline-compiled callbacks, successful

Just checking if device functions can be passed around like pointers - it works, in principle:

import numba as nb
import numpy as np
from numba import cuda

@cuda.jit("f8(f8)", device = True, inline = True)
def foo(x):
    return x * 20

@cuda.jit("f8(f8)", device = True, inline = True)
def bar(x):
    return x * 10

@cuda.jit(device = True, inline = True)  # TODO signature???
def wrapper(x, func):
    return func(x)

@nb.vectorize("f8(f8)", target = "cuda")
def demo(x):
    return wrapper(x, foo if x < 15 else bar)

y = np.arange(10., 20.)
print(demo(y))

One issue though: I have no idea how to type-annotate wrapper. Is this even possible with numba?

Experiment 2, dynamic approach, failed

My initial idea was to have something like a list or dict that a user can append callbacks to before the dispatcher function gets compiled:

from numba import cuda

_callbacks = []  # any type that works?

@cuda.jit("f8(f8)", device = True)
def user_callback(x):  # provided by user, probably in another module
    return x ** 2

_callbacks.append(user_callback)  # "register" callback

@cuda.jit("f8(f8,i8)", device = True)
def dispatcher(x, handle):
    return _callbacks[handle](x)  # refer to callback by handle, an integer

@cuda.jit("void8(f8)", device = True)
def solver():
    return dispatcher(2.0, 0)

Tuples are the only container type the CUDA compiler pipeline picks up here, so this approach is really a no-go - unless the tuple is rewritten with some funky Python interpreter magic each time a new callback gets registered. I am not happy about that. The next best thing would be to rewrite dispatcher itself directly … so on to experiment 3 …

Experiment 3, old-school templating approach, successful

The module’s core.py, exposing a main function acting as my future solver:

import numba as nb
from numba import cuda

assert cuda.is_available()

from plugin import register, get_dispatcher

_ = register(r"""
def demo_func_lib(x):
    return x * 2
""")

exec(get_dispatcher())  # defines the generated dispatcher in this module's namespace

@nb.vectorize("f8(f8,i8)", target="cuda")
def main(x, func):
    return dispatcher(x, func)

The plugin.py mechanism. It could also be based on jinja (see the sketch after this block):

from typing import Callable, Union
from inspect import getsource

_header = r"""
@cuda.jit("f8(f8,i8)", inline=True, device=True)
def dispatcher(x, func):
""".strip("\n")
_entry = r"""
    {prefix}if func == {handle:d}:
        return {name:s}(x)
""".strip("\n")
_footer = """
    else:
        raise ValueError("no known func")
""".strip("\n")
_jit = r'@cuda.jit("f8(f8)", inline=True, device=True)'
_plugins = {}

def register(src: Union[str, Callable]) -> int:
    if callable(src):
        src = getsource(src)
    src = src.strip("\n")
    name = src[4:].split("(")[0]
    assert name not in _plugins.keys()
    handle = len(_plugins)
    _plugins[name] = dict(handle = handle, src = src)
    return handle

def get_handle(name: str) -> int:
    return _plugins[name]["handle"]

def get_dispatcher() -> str:
    srcs = []
    for plugin in _plugins.values():
        srcs.append(f'{_jit:s}\n{plugin["src"]:s}\n')
    srcs.append(_header)
    for name, plugin in _plugins.items():
        srcs.append(_entry.format(
            prefix = "el" if plugin["handle"] != 0 else "",
            handle = plugin["handle"],
            name = name,
        ).strip("\n"))
    srcs.append(_footer)
    out = "\n".join(srcs)
    print(out)  # debug
    return out
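
Hypothetically, the same source generation could be done with jinja2 instead of str.format - a rough sketch assuming the same _plugins registry as above (get_dispatcher_jinja is a made-up name, not part of the code shown so far):

from jinja2 import Template

_template = Template("""
{% for name, plugin in plugins.items() %}
@cuda.jit("f8(f8)", inline=True, device=True)
{{ plugin["src"] }}
{% endfor %}
@cuda.jit("f8(f8,i8)", inline=True, device=True)
def dispatcher(x, func):
{% for name, plugin in plugins.items() %}
    {% if plugin["handle"] != 0 %}el{% endif %}if func == {{ plugin["handle"] }}:
        return {{ name }}(x)
{% endfor %}
    else:
        raise ValueError("no known func")
""")

def get_dispatcher_jinja() -> str:  # hypothetical drop-in replacement for get_dispatcher
    return _template.render(plugins = _plugins)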

And last but not least, what the user would need to do:

import numpy as np

from plugin import register, get_handle

demo_handle_1 = register(r"""
def param_func_1(x):
    return x * 3
""")

def param_func_2(x):
    return x * 4

demo_handle_2 = register(param_func_2)

from core import main

y = np.arange(10., 20.)
print(main(y, get_handle("demo_func_lib")))
print(main(y, demo_handle_1))
print(main(y, demo_handle_2))

Steps:

  1. Import the “plugin” mechanism
  2. Register the callbacks - either as a string containing source code OR as a function from which source code is extracted
  3. Import the actual “main” implementation from “core”, which triggers the compilation of the CUDA kernel including the user’s code.
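
For illustration, tracing the template code through the example above (param_func_1 and param_func_2 get handles 0 and 1 in the user script, demo_func_lib gets handle 2 when core is imported), the source that exec(get_dispatcher()) compiles should look roughly like this:

@cuda.jit("f8(f8)", inline=True, device=True)
def param_func_1(x):
    return x * 3

@cuda.jit("f8(f8)", inline=True, device=True)
def param_func_2(x):
    return x * 4

@cuda.jit("f8(f8)", inline=True, device=True)
def demo_func_lib(x):
    return x * 2

@cuda.jit("f8(f8,i8)", inline=True, device=True)
def dispatcher(x, func):
    if func == 0:
        return param_func_1(x)
    elif func == 1:
        return param_func_2(x)
    elif func == 2:
        return demo_func_lib(x)
    else:
        raise ValueError("no known func")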

Summary

What is your advice? What am I overlooking (if anything)? What alternative approaches exist?

Three more experiments and my likely solution. Imports first:

from typing import Callable, List

import numba as nb
from numba import cuda
import numpy as np

One custom-compiled kernel per function that is provided as an argument. For a kernel as large as mine is going to be, this is going to end up being slow, but it works:

def factory(func: Callable, y: float) -> Callable:

    @cuda.jit("f8(f8)", device = True, inline = True)
    def wrapper(x):
        return func(x) * 2

    @nb.vectorize("f8(f8)", target = "cuda", nopython = True)
    def prototype(x):
        return wrapper(x)

    return prototype

@cuda.jit("f8(f8)", device = True, inline = True)
def foo(x):
    return x + 1

@cuda.jit("f8(f8)", device = True, inline = True)
def bar(x):
    return x + 3

Second idea, a slight variation on the first, trying to pass the provided function around, also works:

def factory(func: Callable, y: float) -> Callable:

    @cuda.jit(device = True, inline = True)
    def demo(func_, x):
        return func_(x) * 2

    @nb.vectorize("f8(f8)", target = "cuda", nopython = True)
    def prototype(x):
        return demo(func, x)

    return prototype

@cuda.jit("f8(f8)", device = True, inline = True)
def foo(x):
    return x + 1

@cuda.jit("f8(f8)", device = True, inline = True)
def bar(x):
    return x + 3

Here, the only issue is that I cannot annotate demo. For numba.jit on the CPU, I could say "f8(FunctionType(f8(f8)),f8)", but this notation fails for cuda: KeyError: "Failed in cuda mode pipeline (step: native lowering)\n<class 'numba.core.types.function_type.FunctionType'>". Is this a bug or am I missing something?
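For comparison, a minimal CPU-only sketch of that notation (no CUDA involved), relying on numba’s first-class function support:

import numba as nb

@nb.njit("f8(f8)")
def foo_cpu(x):
    return x + 1.0

@nb.njit("f8(FunctionType(f8(f8)),f8)")  # this signature works on the CPU target
def demo_cpu(func_, x):
    return func_(x) * 2.0

print(demo_cpu(foo_cpu, 3.0))  # expected: 8.0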

Moving demo out of the factory also works:

@cuda.jit(device = True, inline = True)  # "f8(FunctionType(f8(f8)),f8)"
def demo(func_, x):
    return func_(x) * 2

def factory(func: Callable, y: float) -> Callable:

    @nb.vectorize("f8(f8)", target = "cuda", nopython = True)
    def prototype(x):
        return demo(func, x)

    return prototype

Test ideas with:

foo_vec = factory(foo, 2)
bar_vec = factory(bar, 4)

X = np.arange(10., 20.)

print(foo_vec(X)) # first round
print(bar_vec(X))

print(foo_vec(X)) # second round, just to be sure that both functions are independent
print(bar_vec(X))

A full implementation draft for solve_ivp and DOP853 (currently only compiling for targets cpu and parallel) can be found here.