Hello,
Based on the example provided by Danny Weitekamp here, I’ve been trying to wrap my head around using StructRef to create a cacheable Octree class, to reduce the compile time I was seeing with jitclass.
I think I’ve understood most of Danny’s example reasonably well, but the one part I’m stuck on is using `@overload` to make the class usable as a constructor inside a jitted function.
My minimal reproducible example is below. Based on Danny’s example, my understanding is that I should be able to create a new instance of a `Cell` object in the jitted `create_new_cell` function, but the compiler returns a type error. It works if I use `new_cell = structref.new(CellType)`, but that doesn’t seem to match what Danny said `@overload(Cell)` should let me do.
I can call `Cell()` as a constructor outside a jitted function just fine; the issue only arises inside a jitted function.
Any guidance to help me fix this and understand where I’ve gone wrong (or to correct my understanding if I’ve misunderstood something) would be greatly appreciated. Once it’s working, I’d be happy to provide a minimalistic example for the help docs on using StructRef to create a Quadtree/Octree, if people would find that useful.
```python
import numpy as np
from numba import njit, types
from numba.extending import overload
from numba.experimental import structref


class Cell(structref.StructRefProxy):
    def __new__(cls):
        self = new_cell()
        return self

    @property
    def center(self):
        return cell_get_center(self)

    @center.setter
    def center(self, center):
        return cell_set_center(self, center)

    @property
    def length(self):
        return cell_get_length(self)

    @length.setter
    def length(self, length):
        return cell_set_length(self, length)


@njit(cache=True)
def cell_get_center(self):
    return self.center


@njit(cache=True)
def cell_set_center(self, center):
    self.center = center


@njit(cache=True)
def cell_get_length(self):
    return self.length


@njit(cache=True)
def cell_set_length(self, length):
    self.length = length


@structref.register
class CellTypeTemplate(types.StructRef):
    def preprocess_fields(self, fields):
        return tuple((name, types.unliteral(typ)) for name, typ in fields)


cell_fields = [
    ('length', types.float64),
    ('center', types.float64[:]),
]

structref.define_boxing(CellTypeTemplate, Cell)

# This is the actual resolved type
CellType = CellTypeTemplate(cell_fields)


@njit(cache=True)
def new_cell():
    cell = structref.new(CellType)
    return cell


@overload(Cell)
def overload_cell():
    return new_cell


@njit(cache=True)
def create_new_cell(center, length):
    new_cell = Cell()
    new_cell : Cell
    new_cell.center = center
    new_cell.length = length
    return new_cell


if __name__ == '__main__':
    center = np.array([1., 5.])
    length = 2.
    my_cell = Cell()
    my_cell.center = center
    my_cell.length = length
    my_cell_jitted = create_new_cell(center, length)
```