Hi,
I have extended a custom class for numba and I’m trying to use this custom class as a key for numba typed Dict. This is my custom class extended for numba:
class Interval:
"""
No support for open, closedopen, openclosed
"""
def __init__(self, left, lower, upper, right):
self._lower = lower
self._upper = upper
self._left = left
self._right = right
@property
def lower(self):
return self._lower
@property
def upper(self):
return self._upper
@property
def name(self):
return hash(f'{self._left}{self._lower},{self._upper}{self._right}')
def to_str(self):
interval = f'{self._left}{self._lower},{self._upper}{self._right}'
return interval
def intersection(self, interval):
lower = max(self._lower, interval.lower)
upper = min(self._upper, interval.upper)
return Interval('[', lower, upper, ']')
def __contains__(self, item):
if self._lower <= item.lower and self._upper >= item.upper:
return True
else:
return False
def __eq__(self, interval):
if self.lower == interval.lower and self.upper == interval.upper:
return True
else:
return False
def __repr__(self):
return self.to_str()
def __lt__(self, other):
if self.upper < other.lower:
return True
else:
return False
def __le__(self, other):
if self.upper <= other.upper:
return True
else:
return False
def __gt__(self, other):
if self.lower > other.upper:
return True
else:
return False
def __ge__(self, other):
if self.lower >= other.lower:
return True
else:
return False
# def closed(lower, upper):
# return Interval('[', lower, upper, ']')
# def open(lower, upper):
# return Interval('(', lower, upper, ')')
# def closedopen(lower, upper):
# return Interval('[', lower, upper, ')')
# def openclosed(lower, upper):
# return Interval('(', lower, upper, ']')
from numba import types
from numba.extending import typeof_impl
from numba.extending import type_callable
from numba.extending import models, register_model
from numba.extending import make_attribute_wrapper
from numba.extending import overload_method, overload_attribute
from numba.extending import lower_builtin
from numba.core import cgutils
from numba.extending import unbox, NativeValue, box
# Create new numba type
class IntervalType(types.Type):
def __init__(self):
super(IntervalType, self).__init__(name='Interval')
interval_type = IntervalType()
# Type inference
@typeof_impl.register(Interval)
def typeof_node(val, c):
return interval_type
# Construct object from Numba functions
@type_callable(Interval)
def type_node(context):
def typer(left, low, up, right):
if isinstance(low, types.Float) and isinstance(up, types.Float) and isinstance(left, types.UnicodeType) and isinstance(right, types.UnicodeType):
return interval_type
return typer
# Define native representation: datamodel
@register_model(IntervalType)
class IntervalModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [
('low', types.float64),
('up', types.float64),
('left', types.string),
('right', types.string),
]
models.StructModel.__init__(self, dmm, fe_type, members)
# Expose datamodel attributes
make_attribute_wrapper(IntervalType, 'low', 'low')
make_attribute_wrapper(IntervalType, 'up', 'up')
make_attribute_wrapper(IntervalType, 'left', 'left')
make_attribute_wrapper(IntervalType, 'right', 'right')
# Implement constructor
@lower_builtin(Interval, types.UnicodeType, types.Float, types.Float, types.UnicodeType)
def impl_node(context, builder, sig, args):
typ = sig.return_type
left, low, up, right = args
interval = cgutils.create_struct_proxy(typ)(context, builder)
interval.low = low
interval.up = up
interval.left = left
interval.right = right
return interval._getvalue()
# Expose properties
@overload_attribute(IntervalType, "lower")
def get_lower(interval):
def getter(interval):
return interval.low
return getter
@overload_attribute(IntervalType, "upper")
def get_upper(interval):
def getter(interval):
return interval.up
return getter
@overload_attribute(IntervalType, "name")
def get_upper(interval):
def getter(interval):
return hash(f'{interval.left}{interval.low},{interval.up}{interval.right}')
return getter
# Tell numba how to make native
@unbox(IntervalType)
def unbox_interval(typ, obj, c):
left_obj = c.pyapi.object_getattr_string(obj, "_left")
lower_obj = c.pyapi.object_getattr_string(obj, "_lower")
upper_obj = c.pyapi.object_getattr_string(obj, "_upper")
right_obj = c.pyapi.object_getattr_string(obj, "_right")
interval = cgutils.create_struct_proxy(typ)(c.context, c.builder)
interval.left = c.unbox(types.string, left_obj).value
interval.low = c.unbox(types.float64, lower_obj).value
interval.up = c.unbox(types.float64, upper_obj).value
interval.right = c.unbox(types.string, right_obj).value
c.pyapi.decref(left_obj)
c.pyapi.decref(lower_obj)
c.pyapi.decref(upper_obj)
c.pyapi.decref(right_obj)
is_error = cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return NativeValue(interval._getvalue(), is_error=is_error)
@box(IntervalType)
def box_node(typ, val, c):
interval = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Interval))
left_obj = c.box(types.string, interval.left)
lower_obj = c.box(types.float64, interval.low)
upper_obj = c.box(types.float64, interval.up)
right_obj = c.box(types.string, interval.right)
res = c.pyapi.call_function_objargs(class_obj, (left_obj, lower_obj, upper_obj, right_obj))
c.pyapi.decref(left_obj)
c.pyapi.decref(lower_obj)
c.pyapi.decref(upper_obj)
c.pyapi.decref(right_obj)
c.pyapi.decref(class_obj)
return res
When I run the following code
import numba
d = numba.typed.Dict.empty(
key_type=interval_type,
value_type=numba.types.float64,
)
I get this error:
Traceback (most recent call last):
File "interval_type.py", line 240, in <module>
d = numba.typed.Dict.empty(
File "/home/dyuman/.local/lib/python3.8/site-packages/numba/typed/typeddict.py", line 101, in empty
return cls(dcttype=DictType(key_type, value_type))
File "/home/dyuman/.local/lib/python3.8/site-packages/numba/typed/typeddict.py", line 116, in __init__
self._dict_type, self._opaque = self._parse_arg(**kwargs)
File "/home/dyuman/.local/lib/python3.8/site-packages/numba/typed/typeddict.py", line 127, in _parse_arg
opaque = _make_dict(dcttype.key_type, dcttype.value_type)
File "/home/dyuman/.local/lib/python3.8/site-packages/numba/core/dispatcher.py", line 468, in _compile_for_args
error_rewrite(e, 'typing')
File "/home/dyuman/.local/lib/python3.8/site-packages/numba/core/dispatcher.py", line 409, in error_rewrite
raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function new_dict at 0x7f7a2f461280>) found for signature:
>>> new_dict(typeref[Interval], class(float64))
There are 2 candidate implementations:
- Of which 1 did not match due to:
Overload in function 'impl_new_dict': File: numba/typed/dictobject.py: Line 639.
With argument(s): '(typeref[Interval], class(float64))':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: native lowering)
No implementation of function Function(<built-in function eq>) found for signature:
>>> eq(Interval, Interval)
There are 30 candidate implementations:
- Of which 28 did not match due to:
Overload of function 'eq': File: <numerous>: Line N/A.
With argument(s): '(Interval, Interval)':
No match.
- Of which 2 did not match due to:
Operator Overload in function 'eq': File: unknown: Line unknown.
With argument(s): '(Interval, Interval)':
No match for registered cases:
* (bool, bool) -> bool
* (int8, int8) -> bool
* (int16, int16) -> bool
* (int32, int32) -> bool
* (int64, int64) -> bool
* (uint8, uint8) -> bool
* (uint16, uint16) -> bool
* (uint32, uint32) -> bool
* (uint64, uint64) -> bool
* (float32, float32) -> bool
* (float64, float64) -> bool
* (complex64, complex64) -> bool
* (complex128, complex128) -> bool
During: lowering "$20call_function.8 = call $12load_global.4(dp, $16load_deref.6, $18load_deref.7, func=$12load_global.4, args=[Var(dp, dictobject.py:653), Var($16load_deref.6, dictobject.py:654), Var($18load_deref.7, dictobject.py:654)], kws=(), vararg=None, target=None)" at /home/dyuman/.local/lib/python3.8/site-packages/numba/typed/dictobject.py (654)
raised from /home/dyuman/.local/lib/python3.8/site-packages/numba/core/types/functions.py:227
- Of which 1 did not match due to:
Overload in function 'impl_new_dict': File: numba/typed/dictobject.py: Line 639.
With argument(s): '(typeref[Interval], class(float64))':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: native lowering)
No implementation of function Function(<built-in function eq>) found for signature:
>>> eq(Interval, Interval)
There are 30 candidate implementations:
- Of which 28 did not match due to:
Overload of function 'eq': File: <numerous>: Line N/A.
With argument(s): '(Interval, Interval)':
No match.
- Of which 2 did not match due to:
Operator Overload in function 'eq': File: unknown: Line unknown.
With argument(s): '(Interval, Interval)':
No match for registered cases:
* (bool, bool) -> bool
* (int8, int8) -> bool
* (int16, int16) -> bool
* (int32, int32) -> bool
* (int64, int64) -> bool
* (uint8, uint8) -> bool
* (uint16, uint16) -> bool
* (uint32, uint32) -> bool
* (uint64, uint64) -> bool
* (float32, float32) -> bool
* (float64, float64) -> bool
* (complex64, complex64) -> bool
* (complex128, complex128) -> bool
During: lowering "$20call_function.8 = call $12load_global.4(dp, $16load_deref.6, $18load_deref.7, func=$12load_global.4, args=[Var(dp, dictobject.py:653), Var($16load_deref.6, dictobject.py:654), Var($18load_deref.7, dictobject.py:654)], kws=(), vararg=None, target=None)" at /home/dyuman/.local/lib/python3.8/site-packages/numba/typed/dictobject.py (654)
raised from /home/dyuman/.local/lib/python3.8/site-packages/numba/core/types/functions.py:227
During: resolving callee type: Function(<function new_dict at 0x7f7a2f461280>)
During: typing of call at /home/dyuman/.local/lib/python3.8/site-packages/numba/typed/typeddict.py (23)
File "../../../../../../../.local/lib/python3.8/site-packages/numba/typed/typeddict.py", line 23:
def _make_dict(keyty, valty):
return dictobject._as_meminfo(dictobject.new_dict(keyty, valty))
How can I solve this and use my custom type as a key for the typed Dict?