I’m trying to use a custom type with Numba, following the “Interval” tutorial in the docs. My type looks like:
class Pose:
"""Representation of a SE(3) pose."""
__slots__ = ('r', 't')
def __init__(self, r: NDArray, t: NDArray):
"""Pose constructor."""
self.r = r
self.t = t
@property
def t_(self) -> NDArray:
"""Get the 3D translation of the pose."""
return self.t[:3]
def __repr__(self) -> str:
return f'Pose(r={self.r}, t={self.t_})'
Following the aforementioned tutorial, I’ve added code to register the type and lower/box/unbox:
class PoseT(types.Type):
def __init__(self):
self.r = types.Array(types.float64, 1, 'C')
self.t = types.Array(types.float64, 1, 'C')
super(PoseT, self).__init__(name = 'Pose')
pose_t = PoseT()
@typeof_impl.register(Pose)
def typeof_index(val, c):
return pose_t
as_numba_type.register(Pose, pose_t)
@type_callable(Pose)
def type_interval(context):
def typer(r, t):
if isinstance(r, types.Array) and isinstance(t, types.Array):
return pose_t
return typer
@register_model(PoseT)
class PoseModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [('r', fe_type.r), ('t', fe_type.t)]
models.StructModel.__init__(self, dmm, fe_type, members)
make_attribute_wrapper(PoseT, 'r', 'r')
make_attribute_wrapper(PoseT, 't', 't')
@overload_attribute(PoseT, 't_')
def get_t_(pose):
def getter(pose):
return pose.t[:3]
return getter
@lower_builtin(Pose, types.Array(types.float64, 1, 'C'), types.Array(types.float64, 1, 'C'))
def impl_pose(context, builder, sig, args):
typ = sig.return_type
r, t = args
pose = cgutils.create_struct_proxy(typ)(context, builder)
pose.r = r
pose.t = t
return pose._getvalue()
@unbox(PoseT)
def unbox_pose(typ, obj, c):
r_obj = c.pyapi.object_getattr_string(obj, 'r')
t_obj = c.pyapi.object_getattr_string(obj, 't')
pose = cgutils.create_struct_proxy(typ)(c.context, c.builder)
pose.r = unbox_array(types.Array(types.float64, 1, 'C'), r_obj, c).value
pose.t = unbox_array(types.Array(types.float64, 1, 'C'), t_obj, c).value
c.pyapi.decref(r_obj)
c.pyapi.decref(t_obj)
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(pose._getvalue(), is_error = is_error)
@box(PoseT)
def box_pose(typ, val, c):
pose = cgutils.create_struct_proxy(typ)(c.context, c.builder, value = val)
r_obj = box_array(types.Array(types.float64, 1, 'C'), pose.r, c)
t_obj = box_array(types.Array(types.float64, 1, 'C'), pose.t, c)
class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Pose))
res = c.pyapi.call_function_objargs(class_obj, (r_obj, t_obj))
c.pyapi.decref(r_obj)
c.pyapi.decref(t_obj)
c.pyapi.decref(class_obj)
return res
I’m trying to use this type as a return type only, e.g.:
@njit
def test_pose(r, t):
return Pose(r, t)
I’m encountering two errors that imply I’m doing something wrong with boxing (probably)…
- I can call one JIT’d function that returns
Pose
, and it will work as expected. However, the second time I call it or any other JIT’d function that returnsPose
, I get an errormalloc(): unaligned tcache chunk detected
. - When I try to JIT my actual use-case function (which does a large amount of math, then takes the results and passes them to an expression like
return Pose(np.array([x, y, z, w]), np.array([tx, ty, tz]))
), the returned values are garbage - every time, I get back exactlyPose(r=[-9.86830992e+148 -9.86830992e+148 -9.86830992e+148 -9.86830992e+148], t=[-9.86830992e+148 -9.86830992e+148 -9.86830992e+148])
. I have tested this function without the JIT, and it works fine in that setting.
I’m taking this approach over a jitclass
or StructRef
because I would like to keep the main Pose
type independent of Numba (I intend to also support other JIT backends, such as Jax).
I’m at a loss for how to proceed with debugging either of these issues - any help would be greatly appreciated. Thank you!