Hi!
I finally managed to make it work for my use case (after a lot of time spent reading the Numba source code and inspecting LLVM IR!).
The main idea is to use the jitclass as a plain data container (like a dataclass) and to handle functions separately, rather than relying on the default jitclass machinery for methods.
This requires obtaining a cfunc pointer for each of these functions, using a signature built from the jitclass instance type:
def get_cfunc_ptr(fn, signature):
    """Return the address of the C-callable wrapper Numba generates for ``fn``.

    ``fn`` is jitted with ``nb.njit`` and compiled for ``signature``. The
    returned integer address can be cast on the C++ side to a function pointer
    with the matching C signature.

    NOTE(review): Numba presumably selects the compiled overload from the
    argument types only, so the return type stored in ``signature`` may be
    ignored (the actual one is inferred) — confirm.
    """
    result = nb.njit(fn).get_compile_result(signature)
    wrapper_name = result.fndesc.llvm_cfunc_wrapper_name
    return result.library.get_pointer_to_function(wrapper_name)
The jitclass instance is then passed to the C++ world as a pack of two pointers, the meminfo pointer and the data pointer:
// C++ view of a jitclass instance as a pair of pointers.
// NOTE(review): assumed to match the {meminfo*, data*} struct Numba uses for
// class instances — field order and types must not change.
struct NumbaJitclass
{
    // NRT meminfo of the instance (the reference-counted handle).
    void* meminfo_ptr;
    // Pointer to the instance's field storage (the jitclass data struct).
    void* data_ptr;
};
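A function taking a jitclass instance as its single argument and returning nothing then has the C++ type `void (*)(NumbaJitclass)`. (Passing the struct by value relies on the platform ABI placing both pointers in argument registers. This holds on x86-64 System V, i.e. Linux/macOS, but may not hold on Windows x64 and should be verified there.)
The meminfo and data pointers of a jitclass instance `d` are retrieved using: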
from numba.experimental.jitclass import _box
# Raw integer addresses of the instance's NRT meminfo and of its data struct.
meminfo, data = _box.box_get_meminfoptr(d), _box.box_get_dataptr(d)
If you also need to manage data on the C++ side without calling into CPython, jit a very simple function that returns a new instance of the jitclass. Then pass a cfunc pointer to this function to C++, where it is cast to `NumbaJitclass (*)()`:
@nb.experimental.jitclass(spec=[("a", nb.float64[::1])])
class TestData:
    """Jitclass used as a plain data container with a single field.

    ``a`` is a C-contiguous 1-D float64 array (``nb.float64[::1]``).
    """

    def __init__(self):
        # Ten zeros; native code may later read or replace this array.
        self.a = np.zeros(10)
def create():
    """Build and return a fresh ``TestData`` instance.

    Meant to be compiled into a cfunc so C++ can create instances without
    calling into CPython.
    """
    instance = TestData()
    return instance
Don’t forget to release the data: to decrement the NRT reference count of the jitclass instance, use Numba's `NRT_decref`, which takes the instance's meminfo pointer:
# Address of Numba's runtime function NRT_decref, which drops one reference
# from a meminfo. NOTE(review): it takes a single meminfo pointer (not the whole
# {meminfo, data} pair) — confirm against numba/core/runtime/nrtdynmod.py.
delete_data_fn_ptr = nb.core.runtime.nrt.rtsys.library.get_pointer_to_function(
    "NRT_decref"
)
All of this can be done without holding the GIL.
Here is a complete, but still very simplified, example:
- it requires a C++ compiler, CMake (+ your favorite generator), pybind11 and Python;
- it has a few prints to help understand what’s going on.
You only need the following three files:
# CMakeLists.txt
cmake_minimum_required(VERSION 3.28)
project(test)

# Make C++11 a hard requirement instead of letting CMake silently fall back
# to an older standard when the compiler does not support it.
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(pybind11 CONFIG REQUIRED)

# Python extension module `testpy`, built from main.cpp.
pybind11_add_module(testpy
    main.cpp
)

# Install the extension into the source directory (presumably next to
# test.py, so that `import testpy` works from there).
install(TARGETS testpy DESTINATION ${CMAKE_SOURCE_DIR})
// main.cpp
// pybind11 (and therefore Python.h) must be included before standard headers.
#include <pybind11/pybind11.h>

#include <cstdint>
#include <iostream>

namespace py = pybind11;
// C++ view of a jitclass instance as the pair of pointers Numba-compiled code
// passes around. NOTE(review): assumed to match Numba's {meminfo*, data*}
// class-instance struct — field order and types must not change.
struct NumbaJitclass
{
    // NRT meminfo of the instance (the reference-counted handle).
    void* meminfo_ptr;
    // Pointer to the instance's field storage (the jitclass data struct).
    void* data_ptr;
};
// Python-facing copy of NumbaJitclass. The two addresses are stored as
// integers so Python can assign the values returned by
// _box.box_get_meminfoptr / _box.box_get_dataptr.
// Non-owning: nothing here increments or decrements the NRT reference count,
// so the Python jitclass object must outlive any use of these addresses.
// Keep the field order in sync with NumbaJitclass in case one is
// reinterpreted as the other.
struct PyNumbaJitclass
{
    std::uintptr_t meminfo_ptr;  // address of the instance's NRT meminfo
    std::uintptr_t data_ptr;     // address of the instance's data struct
};
// Leading fields of the native representation of a 1-D float64 array, as read
// from a jitclass data struct.
// NOTE(review): assumed to mirror the first fields of Numba's array data model
// (shape/strides, which follow, are not declared here) — confirm against
// numba.core.datamodel.models.ArrayModel.
struct ArrayModel
{
    void* meminfo;          // meminfo of the array buffer
    void* parent;           // presumably the owning Python object, if any
    std::int64_t nitems;    // number of elements
    std::int64_t itemsize;  // size of one element in bytes
    double* val;            // pointer to the first element
};
PYBIND11_MODULE(testpy, m)
{
py::class_<PyNumbaJitclass>(m, "JitClass")
.def(py::init<>())
.def_readwrite("meminfo", &PyNumbaJitclass::meminfo_ptr)
.def_readwrite("data", &PyNumbaJitclass::data_ptr);
m.def("call_fn_on_existing_data",
[](std::uintptr_t fn_ptr, PyNumbaJitclass& py_data) -> int
{
// Release Python GIL
py::gil_scoped_release release;
auto fn = (void (*)(NumbaJitclass)) fn_ptr;
auto data = reinterpret_cast<NumbaJitclass&>(py_data);
std::cout << "nitems: " << ((ArrayModel*) data.data_ptr)->nitems << std::endl;
std::cout << "itemsize (bytes): " << ((ArrayModel*) data.data_ptr)->itemsize
<< std::endl;
std::cout << "first value before call: " << *((ArrayModel*) data.data_ptr)->val
<< std::endl;
fn(data);
std::cout << "first value after call: " << *((ArrayModel*) data.data_ptr)->val
<< std::endl;
return 0;
});
m.def("call_fn_on_created_data",
[](std::uintptr_t create_data_fn, std::uintptr_t delete_data_fn, std::uintptr_t fn) -> int
{
// Release Python GIL
py::gil_scoped_release release;
auto create_data = (NumbaJitclass(*)()) create_data_fn;
auto delete_data = (void (*)(NumbaJitclass)) delete_data_fn;
auto func = (void (*)(NumbaJitclass)) fn;
auto data = create_data();
std::cout << "nitems: " << ((ArrayModel*) data.data_ptr)->nitems << std::endl;
std::cout << "itemsize (bytes): " << ((ArrayModel*) data.data_ptr)->itemsize
<< std::endl;
std::cout << "first value before call: " << *((ArrayModel*) data.data_ptr)->val
<< std::endl;
func(data);
std::cout << "first value after call: " << *((ArrayModel*) data.data_ptr)->val
<< std::endl;
delete_data(data);
return 0;
});
}
# test.py
import numpy as np
import numba as nb
# NOTE(review): `_box` is a private Numba module; its API may change between
# Numba releases.
from numba.experimental.jitclass import _box
# Extension module built from main.cpp.
from testpy import call_fn_on_existing_data, call_fn_on_created_data, JitClass


@nb.experimental.jitclass(spec=[("a", nb.float64[::1])])
class TestData:
    """Jitclass holding a single C-contiguous 1-D float64 array ``a``.

    ``a`` is the only field, so it sits at the start of the instance's data
    struct; main.cpp relies on this when it reads the data pointer as an array.
    """

    def __init__(self):
        # Ten zeros; the native call is expected to replace them.
        self.a = np.zeros(10)
def increment(data):
    """Replace ``data.a`` with ten ones, printing its first element before and after.

    This function is compiled to a cfunc and called from C++ through a
    ``void (*)(NumbaJitclass)`` pointer. It is compiled with
    ``Signature(None, ...)``, so it deliberately returns nothing; the original
    ``return 0`` contradicted both.

    Args:
        data: a ``TestData`` instance when jitted. Any object with an array
            attribute ``a`` works when called from plain Python.
    """
    print(data.a[0])
    data.a = np.ones(10)
    print(data.a[0])
def create_data():
    """Build and return a fresh ``TestData`` instance.

    Compiled into a cfunc so that C++ can create instances without calling
    into CPython.
    """
    instance = TestData()
    return instance
def get_cfunc(fn, signature):
    """Jit ``fn``, compile it for ``signature`` and return its C wrapper's address.

    The returned integer is what the ``testpy`` functions cast to a C function
    pointer.

    NOTE(review): Numba presumably picks the overload from the argument types
    only, so the return type in ``signature`` may be ignored — confirm.
    """
    compile_result = nb.njit(fn).get_compile_result(signature)
    wrapper_name = compile_result.fndesc.llvm_cfunc_wrapper_name
    return compile_result.library.get_pointer_to_function(wrapper_name)
# Address of the C wrapper of `increment`, compiled for a TestData instance.
# NOTE(review): the first Signature argument (return type) is None; Numba
# presumably infers the actual return type — confirm.
compiled_fn_ptr = get_cfunc(
    increment,
    nb.core.typing.Signature(
        None,
        (TestData.class_type.instance_type,),
        None,
    ),
)

# 1) Call the cfunc on an instance owned by Python by passing its raw pointers.
# JitClass only stores integers and holds no reference, so `d` must stay alive
# for the duration of the call.
d = TestData()
jitclass = JitClass()
jitclass.meminfo = _box.box_get_meminfoptr(d)
jitclass.data = _box.box_get_dataptr(d)
call_fn_on_existing_data(compiled_fn_ptr, jitclass)
# The array assigned from native code is visible on the Python object.
print("jitclass instance, 'a' member:", d.a)

# 2) Let C++ create, use and release an instance without touching CPython.
create_data_fn_ptr = get_cfunc(
    create_data,
    nb.core.typing.Signature(
        TestData.class_type.instance_type,
        (),
        None,
    ),
)
# Address of Numba's NRT_decref, used to drop the reference returned by
# create_data once C++ is done with the instance.
delete_data_fn_ptr = nb.core.runtime.nrt.rtsys.library.get_pointer_to_function(
    "NRT_decref"
)
call_fn_on_created_data(create_data_fn_ptr, delete_data_fn_ptr, compiled_fn_ptr)
Cheers