Thank you for your reply! I can post a complete program.
First, here is my definition of MyStruct and MyStructType. It is adapted so that the type of each field of MyStruct can be given explicitly (I couldn't find a simpler way to do this in the documentation):
import numpy as np
from numba import njit, float64, int32
from numba.experimental.structref import new, register, StructRefProxy, define_boxing
from numba.core.extending import overload
from numba.core.types import StructRef, Array
def my_struct(name, fields):
    ### step 1: register the new class of name MyStructType.
    source = f'''
@register
class {name}Type(StructRef):
    pass
'''
    # print(source)
    glbs = globals()
    exec(source, glbs)
    ### step 2: declare the corresponding class MyStruct, in the same way as done at
    ### https://numba.pydata.org/numba-doc/dev/extending/high-level.html : we need to redefine
    ### __new__ to better match the required fields, as well as manually define getters for the fields.
    pars = ', '.join([f"{f[0]}" for f in fields])
    def proper(par):
        return f'''
    @property
    def {par}(self):
        return get_{par}(self)
'''
    def njiter(par):
        return f'''
@njit
def get_{par}(self):
    return self.{par}
'''
    propers = '\n'.join([proper(f[0]) for f in fields])
    njiters = '\n'.join([njiter(f[0]) for f in fields])
    source = f'''
class {name}(StructRefProxy):
    def __new__(cls, {pars}):
        return StructRefProxy.__new__(cls, {pars})
{propers}
{njiters}
'''
    # print(source)
    exec(source, glbs)
    ### step 3: we write the new constructor function for MyStruct, which calls structref.new to
    ### instantiate a new struct of the type {name}Type registered in step 1. This essentially
    ### replaces structref.define_constructor.
    ### We use extending.overload to set the new function as the constructor for MyStruct.
    ### Finally, we call define_boxing to make sure that the new struct can be used in both njit and
    ### normal python code. This is important because StructRef requires the types to be numba types,
    ### while we might want to instantiate a new struct in python code with numpy types.
    def ster(par):
        return f'''
        st.{par} = {par}
'''
    sters = '\n'.join([ster(f[0]) for f in fields])
    source = f"""
def ctor({pars}):
    struct_type = {name}Type({fields})
    def impl({pars}):
        st = new(struct_type)
{sters}
        return st
    return impl
"""
    # print(source)
    exec(source, glbs)
    new = glbs[name]
    newtype = glbs[name+"Type"]
    overload(new)(glbs['ctor'])
    define_boxing(newtype, new)
    return new, newtype
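In case the string templating above is hard to follow, here is the same code-generation pattern in isolation, without Numba. This is only a minimal sketch: `make_record` and `Point` are hypothetical names I made up for illustration, and it stores fields in plain attributes instead of a StructRef.

```python
def make_record(name, fields):
    """Build a class with one read-only property per field by exec-ing source text."""
    pars = ", ".join(fields)
    # Render the attribute assignments for __init__, already indented.
    assigns = "\n".join(f"        self._{f} = {f}" for f in fields)
    # Render one property getter per field, indented to class-body level.
    props = "\n".join(
        f"    @property\n    def {f}(self):\n        return self._{f}"
        for f in fields
    )
    source = (
        f"class {name}:\n"
        f"    def __init__(self, {pars}):\n"
        f"{assigns}\n"
        f"{props}\n"
    )
    # Assumption: a private namespace is used here instead of globals().
    namespace = {}
    exec(source, namespace)
    return namespace[name]


Point = make_record("Point", ["a", "b"])
p = Point(1.0, 2.0)
total = p.a + p.b
print(total)
```

The sketch executes the generated source in a private dict, whereas my_struct writes into globals() so that the generated getters and ctor can find `njit`, `new` and the registered type by name.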
Then, by calling
import numba.typed
from numba.core import types
MyStruct, MyStructType = my_struct("MyStruct", [
    ("a", float64),
    ("b", float64),
])
@njit
def f1(struct):
    return struct.a + struct.b
struct = MyStruct(1,2)
# f1 works as expected
print(f1(struct))
f_list = numba.typed.List.empty_list(types.float64(MyStructType).as_type())
f_list.append(f1)
we see that MyStruct and f1 work as expected, but when trying to create the list I get
/usr/local/lib/python3.10/dist-packages/numba/core/types/scalars.py in cast_python_value(self, value)
125
126 def cast_python_value(self, value):
--> 127 return getattr(np, self.name)(value)
128
129 def __lt__(self, other):
TypeError: float() argument must be a string or a real number, not '_TypeMetaclass'
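Reading the last line literally, the object that reached the float conversion has type `_TypeMetaclass`, which would mean a class object was passed rather than an instance. I can reproduce the same kind of message without Numba. In this sketch both classes are stand-ins I made up (only the metaclass name is copied from the error message), so it illustrates the message and nothing about Numba's internals:

```python
class _TypeMetaclass(type):
    """Stand-in metaclass, named after the one in the error message."""


class FakeStructType(metaclass=_TypeMetaclass):
    """Hypothetical class built with that metaclass."""


def try_float(value):
    """Return float(value), or the TypeError text if the conversion fails."""
    try:
        return float(value)
    except TypeError as exc:
        return str(exc)


# Passing the class itself (not an instance) triggers the TypeError,
# and the message names the metaclass as the offending type.
message = try_float(FakeStructType)
print(message)
```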
After your comment, I tried putting the signature in the @njit decorator explicitly, and that also didn't work, so I guess it is just a matter of knowing the right way to express the StructRef in the signature. I have no idea how to do that, though. I'm also not sure how to print the nopython signatures of the function as you suggested.