StructRef example guidance

Hello,

Based on the example provided by Danny Weitekamp here, I’ve been trying to wrap my head around using StructRef to create a cacheable Octree class, to cut down the compile times I was seeing with jitclass.

I think I understand most of Danny’s example, but the part I’m stuck on is using @overload so that the class can be used as a constructor inside a jitted function.

My minimal reproducible example is below. Based on Danny’s example, my understanding is that I should be able to create a new Cell instance inside the jitted create_new_cell function, but the compiler returns a type error. It works if I use new_cell = structref.new(CellType), but that doesn’t seem to match what Danny said @overload(Cell) should let me do.

I can call Cell() as a constructor outside a jitted function just fine; it’s only an issue inside a jitted function.

Any guidance to help me fix this and understand where I’ve gone wrong (or to correct my understanding if I’ve misunderstood something) would be greatly appreciated. Once it’s working, I’d be happy to provide a minimalistic example for the Help docs on using StructRef to create a Quadtree/Octree, if people would find that useful.

import numpy as np
from numba import njit, types
from numba.extending import overload
from numba.experimental import structref

class Cell(structref.StructRefProxy):
    def __new__(cls):
        self = new_cell()
        return self

    @property
    def center(self):
        return cell_get_center(self)
    
    @center.setter
    def center(self, center):
        return cell_set_center(self, center)
    
    @property
    def length(self):
        return cell_get_length(self)
    
    @length.setter
    def length(self, length):
        return cell_set_length(self, length)


@njit(cache=True)
def cell_get_center(self):
    return self.center

@njit(cache=True)
def cell_set_center(self, center):
    self.center = center

@njit(cache=True)
def cell_get_length(self):
    return self.length

@njit(cache=True)
def cell_set_length(self, length):
    self.length = length

@structref.register
class CellTypeTemplate(types.StructRef):
    def preprocess_fields(self, fields):
        return tuple((name, types.unliteral(typ)) for name, typ in fields)
    
cell_fields = [
    ('length', types.float64),
    ('center', types.float64[:]),
]

structref.define_boxing(CellTypeTemplate, Cell)

# This is the actual resolved type 
CellType = CellTypeTemplate(cell_fields)

@njit(cache=True)
def new_cell():
    cell = structref.new(CellType)
    return cell

@overload(Cell)
def overload_cell():
    return new_cell


@njit(cache=True)
def create_new_cell(center, length):
    new_cell = Cell()
    new_cell : Cell

    new_cell.center = center
    new_cell.length = length

    return new_cell


if __name__ == '__main__':
    center = np.array([1., 5.])
    length = 2.

    my_cell = Cell()
    my_cell.center = center
    my_cell.length = length

    my_cell_jitted = create_new_cell(center, length)

Nearly there. Your @overload should probably look something like this:

@overload(Cell)
def overload_cell():
    def impl():
        return new_cell()
    return impl

Functions decorated by @overload must return a function implementation (e.g. impl), and that implementation is what gets compiled. @overload is designed this way so that you can specialize the implementation of the function or class you are overloading (e.g. Cell) depending on the types of the arguments provided.

If you want to create your object by passing arguments instead of using setters, you could also do something like this:


# For the python side
class Cell(structref.StructRefProxy):
    def __new__(cls, center, length):
        self = new_cell(center, length)
        return self

#...and the rest

@njit(CellType(types.f8[::1], types.f8), cache=True)
def new_cell(center, length):
    cell = structref.new(CellType)
    cell.center = center
    cell.length = length
    return cell

# For the numba side
@overload(Cell)
def overload_cell(center, length):
    def impl(center, length):
        return new_cell(center, length)
    return impl


@njit(CellType(types.f8[::1], types.f8), cache=True)
def create_new_cell(center, length):
    new_cell = Cell(center, length)
    return new_cell


my_cell_jitted = create_new_cell(np.array([1., 5.]), 2.)

In the above I needed to give njit the type signatures explicitly. This is partly because array types can be finicky and throw the typing system off. I typically always declare array types as explicitly C-contiguous (i.e. [::1]) rather than as ambiguous "any layout" arrays (i.e. [:]). It’s usually good to do that in your structref field specifications as well:

cell_fields = [
    ('length', types.float64),
    ('center', types.float64[::1]),
]

Awesome, thanks very much Danny - that clears things up and works a treat.