How can I call a method of a @jitclass instance from another language through a function pointer?

Hello! I want to find out whether it's possible to pass a pointer to another language (Rust, in my case) to a function that calls a method of a @jitclass instance, gets the result, and returns it to Rust. Ideally (though I doubt it's possible), I would call the method directly from Rust.

Here is a snippet of my code:

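from numba import cfunc, int64
from numba.experimental import jitclass

spec = [
    ('value', int64)
]

@jitclass(spec)
class MyJitClass:
    def __init__(self, value):
        self.value = value

    def compute(self, x):
        return self.value * x

my_instance = MyJitClass(42)

@cfunc("int64(int64)")
def compute_wrapper(f):
    # Lowering this reference to the global jitclass instance is what fails
    res = my_instance.compute(f)
    return res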

I then get the address of that cfunc and pass it to Rust, but the code fails to compile with an error complaining about calling the jitclass from the cfunc:

numba.core.errors.NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)

<numba.core.base.OverloadSelector object at 0x73b13275af60>, (instance.jitclass.MyJitClass#73b157350d70<value:int64,another_value:int64>,)

During: lowering "$4load_global.0 = global(my_instance: <numba.experimental.jitclass.boxing.MyJitClass object at 0x73b14cdddf60>)" at …/Projects/rustback/python/numba_run.py (65)

During: Pass native_lowering

By contrast, this code works like a charm:

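from numba import cfunc, njit

@njit(nogil=True)
def njit_func(a):
    return a + 5

@cfunc("int64(int64)")
def compute_wrapper(f):
    # Calling a plain njit function from a cfunc compiles without problems
    return njit_func(f)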

But as soon as I introduce a jitclass, it falls apart. This approach worked with a Cython cclass, and I wonder whether it's possible to make it work with Numba as well.

Yeah, you can do this. By default, Numba treats globals as compile-time constants, but you can get a mutable reference through the NRT (Numba Runtime).

There are some examples floating around the net.

Something like this should work; at least, it does for structref.

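from numba import njit, types
from numba.extending import intrinsic

# get_or_make_global(context, builder, typ, name) is a helper not shown here;
# presumably it looks up (or creates) an LLVM global variable called `name`
# in the current module, able to hold a value of type `typ`.


@intrinsic
def _set_global(typingctx, _typ, _name, val):
    typ = _typ.instance_type        # the Numba type, passed in as a type reference
    name = _name.literal_value      # the global's name must be a string literal

    def codegen(context, builder, sig, args):
        val = args[-1]
        gv = get_or_make_global(context, builder, typ, name)
        # No incref: the global holds a borrowed reference, so the Python-side
        # object must stay alive for as long as the global is used.
        builder.store(val, gv)

    sig = types.void(_typ, _name, val)
    return sig, codegen


@intrinsic
def _get_global(typingctx, _typ, _name):
    typ = _typ.instance_type
    name = _name.literal_value

    def codegen(context, builder, sig, args):
        gv = get_or_make_global(context, builder, typ, name)
        v = builder.load(gv)
        # Returned values count as new references, so incref what was loaded.
        context.nrt.incref(builder, typ, v)
        return v

    sig = typ(_typ, _name)
    return sig, codegen


# ModelStaticType (presumably a structref type) and `static` (an instance of it)
# are defined elsewhere in the original code and are not shown here.
@njit
def set_global_container(container):
    _set_global(ModelStaticType, 'static_struct_bfe21224b463c22d1a8d12aa07a551bc', container)


set_global_container(static)


@njit
def get_static_struct():
    return _get_global(ModelStaticType, 'static_struct_bfe21224b463c22d1a8d12aa07a551bc')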
