Call a `jitclass` method from C/C++?

Hi!
Really happy to use this great project, thanks for that!

My use case is about calling a jitclass (non static) method from C/C++. I’m using pybind11 to create extensions of a C++ lib in Python and would like to rely on numba to make such extensions much more efficient by jitting those parts.

I managed quite easily to use a cfunc to call a free function from C++, but I can’t figure out whether it’s possible to do the same with a class method, with the self argument already bound to the C struct counterpart of the jitclass. I would like to avoid going through the Python interpreter for performance reasons (including acquiring the GIL).
Does it look feasible to get pointers to both the C struct and the class method and make the call from C/C++ code? Does anybody have a pointer to where something similar is done in numba’s codebase?

I feel a bit lost in my attempts to understand how it works and how to do such a thing, sorry if the description is not crystal clear…
Thanks!

Adrien


There’s probably a path forward for you… have a look at this as a starting point
I think you’ll want to make standalone entry functions on the numba side, but you can use global StructRefs, as jitclass doesn’t cache well.

Hi @nelson2005 thanks a lot for the quick answer!
This is exactly the article I used for the cfunc approach, but I can’t figure out how to extend it to class methods…
I’ll take a look at the other link! So it looks like composing with cfunc and StructRefs would be a better path than using jitclass?

‘Better’ is relative :slight_smile: but it’s the approach I’ve taken, calling methods of a global StructRef (or one in a container if multiple are needed) from a cfunc (or a njit-func called from a cfunc, njit-funcs and cfuncs have some subtle differences)
This kind of stuff has a bit of a learning curve, so it’s best to make sure to get the foundation solid before moving on too much. Particularly around error handling/exceptions.

Hi!

I finally managed to make it work for my use case (with a lot of time reading the numba source code and inspecting LLVM IR!).
The main idea is to build a jitclass as a dataclass, and handle functions separately rather than relying on the default machinery of jitclass for class methods.

To do so, you need to get a cfunc for each of these functions, given a signature built from the jitclass type:

def get_cfunc_ptr(fn, signature):
    jitted = nb.njit(fn)
    compiled_fn = jitted.get_compile_result(signature)

    return compiled_fn.library.get_pointer_to_function(
        compiled_fn.fndesc.llvm_cfunc_wrapper_name
    )
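
The value returned above is a plain integer address, which the C++ side later casts to a function pointer. For anyone who wants to try that "integer address to callable" round trip without compiling anything, here is a minimal sketch using only ctypes. The callback `_double_it` and the prototype `PROTO` are stand-ins made up for illustration; they are not part of numba and do not use the jitclass calling convention.

```python
import ctypes

# Hypothetical stand-in for a compiled function: C signature double(double).
PROTO = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double)


def _double_it(x):
    return 2.0 * x


# Keep a reference to the callback object so the underlying pointer stays valid.
callback = PROTO(_double_it)

# This integer is the kind of value that gets handed to C/C++.
address = ctypes.cast(callback, ctypes.c_void_p).value

# Rebuild a callable from the bare integer, like the C-style cast in C++.
rebuilt = PROTO(address)
result = rebuilt(21.0)
```

The important point is that the prototype used to rebuild the callable must match the real signature, since nothing checks it at runtime.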

Then the jitclass is passed to the C++ world as a pack of 2 pointers, the meminfo and the data pointers:

struct NumbaJitclass
{
    void* meminfo_ptr;
    void* data_ptr;
};
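
As a quick sanity check of that layout from Python, the same two-pointer pack can be mirrored with a ctypes structure. This is only a sketch: the helper `pack_jitclass` and the addresses used below are hypothetical placeholders, not real meminfo or data pointers.

```python
import ctypes


class NumbaJitclass(ctypes.Structure):
    # Mirrors the C++ struct: two raw pointers, meminfo first, then data.
    _fields_ = [
        ("meminfo_ptr", ctypes.c_void_p),
        ("data_ptr", ctypes.c_void_p),
    ]


def pack_jitclass(meminfo_addr, data_addr):
    """Hypothetical helper: bundle two integer addresses into the struct."""
    return NumbaJitclass(meminfo_addr, data_addr)


# Placeholder addresses, for illustration only.
pair = pack_jitclass(0x1000, 0x2000)
pointer_size = ctypes.sizeof(ctypes.c_void_p)
```

With this layout the struct is exactly two pointers wide and the data pointer sits one pointer-width after the meminfo pointer, which is what the C++ struct assumes.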

A function taking a jitclass as its single argument and returning nothing will have the following signature: void (*)(NumbaJitclass).
The meminfo and data pointers of the jitclass are retrieved using:

from numba.experimental.jitclass import _box

meminfo = _box.box_get_meminfoptr(d)
data = _box.box_get_dataptr(d)

If one also needs to manage data on the C++ side, without calling into CPython, this can be achieved by jitting a very simple function that returns a new instance of the jitclass, and passing a cfunc pointer to this function, to be cast as NumbaJitclass(*)().

@nb.experimental.jitclass(spec=[("a", nb.float64[::1])])
class TestData:
    def __init__(self):
        self.a = np.zeros(10)

def create():
    return TestData()

Don’t forget to delete the data: to decrement the NRT reference counter of the jitclass one can use:

delete_data_fn_ptr = nb.core.runtime.nrt.rtsys.library.get_pointer_to_function(
    "NRT_decref"
)

All of this can be done without the GIL.

Here is a full but still very simplified example.

  • it requires a C++ compiler, cmake (+ your favorite generator), pybind11 and python.
  • it has a few prints to help understand what’s going on

You only need:

  • CMakeLists.txt
    cmake_minimum_required(VERSION 3.28)
    
    project(test)
    
    set(CMAKE_CXX_STANDARD 11)
    
    find_package(pybind11 CONFIG REQUIRED)
    
    pybind11_add_module(testpy
        main.cpp
    )
    install(TARGETS testpy DESTINATION ${CMAKE_SOURCE_DIR})
    
  • main.cpp
    #include <pybind11/pybind11.h>
    
    #include <cstdint>  // std::uintptr_t, std::int64_t
    #include <iostream>
    
    namespace py = pybind11;
    
    struct NumbaJitclass
    {
        void* meminfo_ptr;
        void* data_ptr;
    };
    
    struct PyNumbaJitclass
    {
        std::uintptr_t meminfo_ptr;
        std::uintptr_t data_ptr;
    };
    
    struct ArrayModel
    {
        void* meminfo;
        void* parent;
        std::int64_t nitems;
        std::int64_t itemsize;
        double* val;
    };
    
    PYBIND11_MODULE(testpy, m)
    {
        py::class_<PyNumbaJitclass>(m, "JitClass")
            .def(py::init<>())
            .def_readwrite("meminfo", &PyNumbaJitclass::meminfo_ptr)
            .def_readwrite("data", &PyNumbaJitclass::data_ptr);
    
        m.def("call_fn_on_existing_data",
              [](std::uintptr_t fn_ptr, PyNumbaJitclass& py_data) -> int
              {
                  // Release Python GIL
                  py::gil_scoped_release release;
    
                  auto fn = (void (*)(NumbaJitclass)) fn_ptr;
                  auto data = reinterpret_cast<NumbaJitclass&>(py_data);
    
                  std::cout << "nitems: " << ((ArrayModel*) data.data_ptr)->nitems << std::endl;
                  std::cout << "itemsize (bytes): " << ((ArrayModel*) data.data_ptr)->itemsize
                            << std::endl;
    
                  std::cout << "first value before call: " << *((ArrayModel*) data.data_ptr)->val
                            << std::endl;
                  fn(data);
                  std::cout << "first value after call: " << *((ArrayModel*) data.data_ptr)->val
                            << std::endl;
    
                  return 0;
              });
    
        m.def("call_fn_on_created_data",
              [](std::uintptr_t create_data_fn, std::uintptr_t delete_data_fn, std::uintptr_t fn) -> int
              {
                  // Release Python GIL
                  py::gil_scoped_release release;
    
                  auto create_data = (NumbaJitclass(*)()) create_data_fn;
                  auto delete_data = (void (*)(NumbaJitclass)) delete_data_fn;
                  auto func = (void (*)(NumbaJitclass)) fn;
    
                  auto data = create_data();
    
                  std::cout << "nitems: " << ((ArrayModel*) data.data_ptr)->nitems << std::endl;
                  std::cout << "itemsize (bytes): " << ((ArrayModel*) data.data_ptr)->itemsize
                            << std::endl;
    
                  std::cout << "first value before call: " << *((ArrayModel*) data.data_ptr)->val
                            << std::endl;
                  func(data);
                  std::cout << "first value after call: " << *((ArrayModel*) data.data_ptr)->val
                            << std::endl;
                  delete_data(data);
    
                  return 0;
              });
    }
    
  • test.py
    import numpy as np
    import numba as nb
    from numba.experimental.jitclass import _box
    
    from testpy import call_fn_on_existing_data, call_fn_on_created_data, JitClass
    
    
    @nb.experimental.jitclass(spec=[("a", nb.float64[::1])])
    class TestData:
        def __init__(self):
            self.a = np.zeros(10)
    
    
    def increment(data):
        print(data.a[0])
        data.a = np.ones(10)
        print(data.a[0])
    
        return 0
    
    
    def create_data():
        return TestData()
    
    
    def get_cfunc(fn, signature):
        jitted = nb.njit(fn)
        compiled_fn = jitted.get_compile_result(signature)
    
        return compiled_fn.library.get_pointer_to_function(
            compiled_fn.fndesc.llvm_cfunc_wrapper_name
        )
    
    
    compiled_fn_ptr = get_cfunc(
        increment,
        nb.core.typing.Signature(
            None,
            (TestData.class_type.instance_type,),
            None,
        ),
    )
    
    d = TestData()
    
    jitclass = JitClass()
    jitclass.meminfo = _box.box_get_meminfoptr(d)
    jitclass.data = _box.box_get_dataptr(d)
    
    call_fn_on_existing_data(compiled_fn_ptr, jitclass)
    print("jitclass instance, 'a' member:", d.a)
    
    create_data_fn_ptr = get_cfunc(
        create_data,
        nb.core.typing.Signature(
            TestData.class_type.instance_type,
            (),
            None,
        ),
    )
    
    delete_data_fn_ptr = nb.core.runtime.nrt.rtsys.library.get_pointer_to_function(
        "NRT_decref"
    )
    
    call_fn_on_created_data(create_data_fn_ptr, delete_data_fn_ptr, compiled_fn_ptr)
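
To make the `ArrayModel` reads in main.cpp a bit less magic, here is a small ctypes sketch of the same idea: the C++ side only receives an address and reinterprets the memory behind it as an array header. The field list is copied from the `ArrayModel` struct in main.cpp; the buffer and header built here are stand-ins created by hand, not memory owned by numba.

```python
import ctypes


class ArrayModel(ctypes.Structure):
    # Same leading fields, in the same order, as ArrayModel in main.cpp.
    _fields_ = [
        ("meminfo", ctypes.c_void_p),
        ("parent", ctypes.c_void_p),
        ("nitems", ctypes.c_int64),
        ("itemsize", ctypes.c_int64),
        ("val", ctypes.POINTER(ctypes.c_double)),
    ]


# Stand-in payload: ten doubles initialised to zero, plus a hand-built header.
buffer = (ctypes.c_double * 10)()
header = ArrayModel(
    None,
    None,
    10,
    ctypes.sizeof(ctypes.c_double),
    ctypes.cast(buffer, ctypes.POINTER(ctypes.c_double)),
)

# Only the integer address crosses the language boundary.
data_addr = ctypes.addressof(header)

# Reinterpret the address as a header, like (ArrayModel*) data.data_ptr in C++.
view = ArrayModel.from_address(data_addr)
first_before = view.val[0]

# Writing through the original buffer is visible through the reinterpreted view.
buffer[0] = 1.0
first_after = view.val[0]
```

This only works because the array is the first member of the jitclass, so the data pointer and the array header share the same address, as in the example above.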
    

Cheers


It’s really a wonderful scheme! I think it should be integrated into the Numba reference!
By the way, I also found that in this snippet:

compiled_fn_ptr = get_cfunc(
    increment,
    nb.core.typing.Signature(
        None,
        (TestData.class_type.instance_type,),
        None,
    ),
)

getting the cfunc of a method function of the class is also OK.