Exposing StructRef methods to the Python side

Hello,

I am testing a possible migration away from jitclass to structref so that I can take advantage of caching and AOT compilation. I’m wondering if this makes sense from a Numba roadmap perspective? Is there a plan for jitclass to support caching? Is structref here to stay?

Also, in my testing of structref I’ve encountered an issue that I’m not sure how to solve. I’ve managed to expose methods that are accessible in jit’d functions via @overload_method, but I can’t figure out how to expose them to the Python side.

Thanks for your help!


I’m missing something… do you have a minimal example that shows what you’re trying to do?

In the documentation example, isn’t the name() method callable from plain Python?

The maintainers can comment on the future of jitclass but I converted from jitclass to structref a couple of years ago and haven’t regretted it.

Sure, here is an example. I did manage to figure out a solution but I’m not sure it is the best way to do it since the method is jitted twice.

My goal is to create a StructRef template for converting all of my jitclasses to structrefs such that:

  1. They emulate standard class functionality
  2. They can be used inside other structrefs
  3. They can be wrapped with @njit(cache=True)
  4. Methods called from pure Python are @njit’d and cached

Hoping to get some feedback on what I’ve come up with so far.

import numpy as np
from numba.core import types
from numba import njit
from numba.experimental import structref
from numba.core.extending import overload_method

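# Register the type class so Numba has a data model and attribute access for it
@structref.register
class MyClassType(types.StructRef):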
    def preprocess_fields(self, fields):
        return tuple((name, types.unliteral(typ)) for name, typ in fields)

class MyClass(structref.StructRefProxy):
    def __new__(cls, x):
        y = 2*x
        z = 4
        return structref.StructRefProxy.__new__(cls, x, y, z)
    
    @property
    def x(self):
        return get_x(self)
    
    @property
    def y(self):
        return get_y(self)
    
    @property
    def z(self):
        return get_z(self)
    
    # This method is available from pure Python
    def acc(self, a=1):
        return _acc(self, a)

@njit
def get_x(self):
    return self.x

@njit
def get_y(self):
    return self.y

@njit
def get_z(self):
    return self.z

# njit and cache the method implementation
@njit(cache=True)
def _acc(self, a=1):
    return self.x.sum() + a + self.y + self.z

# njit and expose when called from jitted code
@overload_method(MyClassType, "acc")
def overload_acc(self, a=1):
    return _acc.py_func

structref.define_proxy(MyClass, MyClassType, ["x", "y", "z"])

# Create an instance of MyClass
my_instance = MyClass(np.array([1.,2.,3.]))

# test that we can njit and cache a structref instance
@njit(cache=True)
def add(instance):
    return instance.acc(3)

print(add(my_instance)) # [15. 17. 19.]

# test access from Python
print(my_instance.y) # [2. 4. 6.]
print(my_instance.x) # [1. 2. 3.]
print(my_instance.acc(3)) # [15. 17. 19.]

What you have looks pretty reasonable to me. @DannyWeitekamp’s CRE contains an absolute trove of great examples, including a structref generator that I used as a starting point for my own generator.


Yes, looking around CRE is definitely a good way to find good patterns for using structref; many others have found it helpful. The devs can speak to their long-term plans for structref / jitclass, which I’m eager to hear, but the current state of things leaves something to be desired. Structref pretty much gives you free rein in terms of customizing functionality on both the Python and jit side, but requires contending with an annoying amount of boilerplate that isn’t well documented, something that @nelson2005 and I have worked around in our own projects with specialized structref definition machinery. Jitclass has a bit of a friendlier syntax, but isn’t AOT or cache compatible, which is a complete deal-breaker in larger projects. I’d love to see a solution that captures the best of both, or at least one that is easy to pick up, not limited in terms of customization, and AOT/cache friendly.

A couple of conceptual notes to keep in mind going forward which I think the docs really ought to cover, because I’ve found the way the docs encourage you to use structref a bit limiting.

  • There are two pieces you need to define: 1) The subclass of types.StructRef (MyClassType in your case), which is like a TypeClass or TypeTemplate, basically a meta-class for instantiating grounded types with fixed fields. You can often use this directly when writing @overload functions. 2) The subclass of structref.StructRefProxy, which is the class for Python-side “Proxy” objects.
  • When you call define_proxy you are 1) internally calling define_constructor(), which lets your Proxy work as a constructor, and 2) calling define_boxing(), which defines how your Proxy can be converted to a numba structref and vice versa. In many places in CRE I call one or both of these directly because I don’t need both, or want to customize one or the other.
  • The docs don’t actually cover how to specialize a TypeClass into a grounded type, but doing so can save you a lot of headaches, since you can control the data layout of the type instead of having numba infer it from a constructor call, which can have you scratching your head when it mysteriously generates multiple struct types when you thought you just had one. Basically just pass the fields to it in the same format you would use to define a jitclass:
    MyType = MyTypeClass([ ('name', unicode_type), ('x', f8), ('y', f8) ] )
  • Then you can have a bit more control, as in jitclass, over the precise types you’re using, and can even write constructors that are a bit more verbose. For instance, here is a jitted constructor that I wrote in CRE:
@njit(GenericFactIteratorType(MemSetType,i8[::1]),cache=True)
def generic_fact_iterator_ctor(ms, t_ids):
    st = new(GenericFactIteratorType)
    st.memset = ms
    st.t_ids = t_ids
    st.curr_ind = 0
    st.curr_t_id_ind = 0
    return st

Note the use of new, which is from numba.experimental.structref. It makes an empty structref instance. This method allows us to be explicit about the signature (helpful for AOT) since the output type is well defined, and we can cache the jitted constructor. A word of warning: if we didn’t set all of the members in this constructor then we could in principle dereference an empty object field, which would cause a null pointer dereference (which you would see as a segfault). This is why the docs don’t encourage this usage pattern.

One last note. I don’t know if @nelson2005 has found a better way, but I’ve found that in practice the only reliable way to cache your structref definitions is to always have them written to a file so the source is always well defined. So if you were planning on making some kind of custom structref generator, and you want that code to cache or AOT properly, you’ll also need some kind of file cache machinery so that your generated code always lives somewhere concrete. Feel free to snag what I have in CRE for that. Sorry all of this isn’t simpler! I’d really love to see some more attention given to this in the future. The good news is all of the pieces you’ll need exist if you’re willing to piece them together. (I think this covers 2,3,4 on your list but maybe not 1)
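As a rough illustration of the "always write generated code to a file" idea, the sketch below uses only the standard library. The helper name materialize_module, the template, and the content-hash naming scheme are all assumptions made for this example. This is not how CRE does it, just one way such file-cache machinery could look.

```python
import hashlib
import importlib.util
import sys
from pathlib import Path

# Hypothetical template for a generated structref type class.
# Importing a module produced from it requires Numba to be installed.
STRUCTREF_TEMPLATE = '''
from numba.core import types
from numba.experimental import structref

@structref.register
class {name}Type(types.StructRef):
    pass
'''


def materialize_module(module_name, source, cache_dir):
    """Write `source` to a content-addressed file and import it from there."""
    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Identical source always maps to the same file, so the path stays stable
    digest = hashlib.sha256(source.encode("utf-8")).hexdigest()[:16]
    path = cache_dir / f"{module_name}_{digest}.py"
    if not path.exists():
        path.write_text(source, encoding="utf-8")
    # Import from the real file so the code has a concrete source location
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
```

A call such as materialize_module("my_class_types", STRUCTREF_TEMPLATE.format(name="MyClass"), "generated") would then return a module exposing MyClassType, backed by a file on disk rather than by an exec() call.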


I haven’t found a better way but it hasn’t really been an issue for my use case. My generated structref definitions aren’t particularly dynamic so I can generate them once and then keep them in the git repo like any other file.

This is extremely helpful - thank you!

Ok, so trying out the method you suggested for concrete types but for some reason I can’t use the constructor on the Python side…what am I doing wrong?

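import numpy as np
from numba import njit, f8
from numba.core import types
from numba.core.extending import overload
from numba.experimental import structref

@structref.register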
class MyClassType(types.StructRef):
    pass

MyType = MyClassType([("x",f8[:])])
@njit(MyType(f8[:]))
def my_class_ctor(x):
    st = structref.new(MyType)
    st.x = x
    return st

class MyClass(structref.StructRefProxy):
    def __new__(cls, x):
        self = my_class_ctor(x)
        return self

    @property
    def x(self):
        return _x(self)

@njit(cache=True)
def _x(self):
    return self.x

@overload(MyClass)
def overload_MyClass(x):
    def impl(x):
        return my_class_ctor(x)
    return impl

structref.define_boxing(MyClassType, MyClass)

# This works...
@njit
def test():
    my_instance = MyClass(np.array([1.0, 2.0, 3.0]))
    print(my_instance.x) # prints [1. 2. 3.]
test()

# But this gets:
# TypeError: cannot convert native 
# numba.MyClassType(('x', array(float64, 1d, A)),) to Python object
my_instance = MyClass(np.array([1.0, 2.0, 3.0]))

If you are going to give your _ctor() function an explicit signature then you need to move it to after the call to define_boxing (it will still work fine in __new__() if you do that). The reason is that the boxing/unboxing machinery is compiled directly into the function at compile time, so it needs to be set up before the _ctor() function is compiled. When you provide a signature, the function is compiled with that signature at definition time instead of just-in-time.


Do you mind posting your final demo here once you get it going? That could be pretty helpful to others walking this road.

Absolutely!

I’m still working on it, but here is what I have so far that works.

import numpy as np
from numba.pycc import CC
from numba import njit, f8
from numba.experimental import structref
from numba.core.extending import overload_method, overload

# This is boilerplate that can be imported from elsewhere and re-used
def create_type_template(cls):
    source = f"""
from numba.core import types
from numba.experimental import structref
@structref.register
class {cls.__name__}Type(types.StructRef):
    pass
"""
    glbs = globals()
    exec(source, glbs)

    return glbs[f"{cls.__name__}Type"]

# Create a Python class that will be a proxy for the Numba class - the actual
# implementation is not defined here.
class MyClass(structref.StructRefProxy):
    def __new__(cls, x):
        self = my_class_constructor(x)
        return self

    @property
    def x(self):
        return _x(self)

    @property
    def y(self):
        return _y(self)

    def acc(self, a=1):
        return _acc(self, a)


# Create a type for the class and define Numba-to-Python interfacing (boxing)
MyClassTemplate = create_type_template(MyClass)
MyType = MyClassTemplate([("x", f8[:]), ("y", f8[:])])
structref.define_boxing(MyClassTemplate, MyClass)

# Define the typed constructor implementation
@njit(MyType(f8[:]), cache=True)
def my_class_constructor(x):
    self = structref.new(MyType)
    self.x = x
    self.y = 2 * x
    return self


# Overload the Python constructor with our Numba implementation
@overload(MyClass)
def overload_MyClass(x):
    def implementation(x):
        return my_class_constructor(x)

    return implementation


# Implementations of getters/setters/methods
@njit(cache=True)
def _x(self):
    return self.x


@njit(cache=True)
def _y(self):
    return self.y


@njit(cache=True)
def _acc(self, a=1):
    return self.x.sum() + a + self.y


# Extra step required to expose methods in jit-code (no @njit on the overload itself)
@overload_method(MyClassTemplate, "acc")
def overload_acc(self, a=1):
    python_implementation = _acc.py_func
    return python_implementation


# Test:
# running this script prints
# [12. 14. 16.]
# [2. 4. 6.]
# [1. 2. 3.]
# [11. 13. 15.]
# [12. 14. 16.]
# [2. 4. 6.]
# [1. 2. 3.]
# [11. 13. 15.]
# AOT compiling
# [12. 14. 16.]
if __name__ == "__main__":

    # Create a function that operates on an instance of our new class and
    # ahead-of-time (AOT) and JIT compile it
    cc = CC("my_module")

    @njit(f8[:](MyType), cache=True)
    @cc.export("add", f8[:](MyType))
    def add(instance):
        return instance.acc(4)

    # Create an instance of MyClass and test it in Python
    my_instance = MyClass(np.array([1.0, 2.0, 3.0]))
    print(add(my_instance))
    print(my_instance.y)
    print(my_instance.x)
    print(my_instance.acc(3))

    # Test in jit-code
    @njit(cache=True)
    def test_jit_code():
        my_instance = MyClass(np.array([1.0, 2.0, 3.0]))
        print(add(my_instance))
        print(my_instance.y)
        print(my_instance.x)
        print(my_instance.acc(3))

    test_jit_code()

    # Test it in AOT code
    print("AOT compiling")
    cc.compile()

    def test_aot_code():
        import my_module

        print(my_module.add(my_instance))

    test_aot_code()
