from numba import types
from numba.extending import typeof_impl
from numba.extending import as_numba_type
from numba.extending import type_callable
from numba.extending import models, register_model
from numba.extending import make_attribute_wrapper
from numba.extending import overload
from numba.extending import lower_builtin
from numba.core import cgutils
from numba.core.extending import overload_method
from numba.extending import unbox, NativeValue, box
from contextlib import ExitStack
from numba import njit
import operator
class Decimal(object):
"""
A half-open interval on the real number line.
"""
def __init__(self, f, e=0):
if f == 0:
self.f = 0
self.e = 0
elif e == 0:
e_ = int(math.log10(abs(f))//1)
f_ = f / 10 ** e_
self.f = f_
self.e = e_
else:
e_ = int(math.log10(abs(f))//1)
f_ = f / 10 ** e_
self.f = f_
self.e = e + e_
def __add__(self,other):
if isinstance(other, Decimal):
f1 = self.f
e1 = self.e
f2 = other.f
e2 = other.e
if e1 == e2:
summation = Decimal((f1 + f2), e1)
elif e1 > e2:
d = e1 - e2
f2 = f2 / 10 ** d
summation = Decimal((f1 + f2), e1)
else:
d = e2 - e1
f1 = f1 / 10 ** d
summation = Decimal((f1 + f2), e2)
if summation.f == 0:
return (0)
else:
return summation
elif isinstance(other, float) or isinstance(other, int):
other_ = Decimal(other)
return self + other_
else:
raise TypeError("Unsupported operand type")
def __radd__(self,other):
return self+other
def __sub__(self, other):
if isinstance(other, Decimal):
other_ = Decimal(-other.f, other.e)
return self + other_
elif isinstance(other, float) or isinstance(other, int):
other_ = Decimal(other)
return self - other_
else:
raise TypeError("Unsupported operand type")
def __mul__(self, other):
if isinstance(other, Decimal):
return(Decimal(self.f*other.f, self.e+other.e))
if isinstance(other, (int,float)):
return(Decimal(self.f*other, self.e))
def __rmul__(self, other):
return self * other
def __truediv__(self, other):
if other == 0:
raise TypeError("Divide by 0 occurred")
elif isinstance(other, Decimal):
other_ = Decimal(1 / other.f, -other.e)
return self * other_
elif isinstance(other, float):
return Decimal(self.f / other, self.e)
elif isinstance(other, int):
return Decimal(self.f / other, self.e)
else:
raise TypeError("Unsupported operand type")
def __rtruediv__(self, other):
f_ = 1 / self.f
e_ = -self.e
self_ = Decimal(f_, e_)
return self_ * other
def ln(self):
if self.f <= 0:
raise TypeError("Log is undefined for a negative value")
else:
log_f = math.log(self.f)
log_e = self.e * np.log(10)
return log_f + log_e
def __repr__(self):
return '%0.10fe%d' % (self.f, self.e)
class DecimalType(types.Type):
def __init__(self):
super(DecimalType, self).__init__(name='Decimal')
decimal_type = DecimalType()
@typeof_impl.register(Decimal)
def typeof_index(val, c):
return decimal_type
as_numba_type.register(Decimal, decimal_type)
@type_callable(Decimal)
def type_interval(context):
def typer(f, e):
if isinstance(f, types.Float) and isinstance(e, types.Integer):
return decimal_type
return typer
@register_model(DecimalType)
class DecimalModel(models.StructModel):
def __init__(self, dmm, fe_type):
members = [('f', types.float64),
('e', types.int64),]
models.StructModel.__init__(self, dmm, fe_type, members)
make_attribute_wrapper(DecimalType, 'f', 'f')
make_attribute_wrapper(DecimalType, 'e', 'e')
@lower_builtin(Decimal, types.Float, types.Integer)
def impl_interval(context, builder, sig, args):
typ = sig.return_type
f, e = args
deci = cgutils.create_struct_proxy(typ)(context, builder)
deci.f = f
deci.e = e
return deci._getvalue()
@lower_builtin('*')
def decimal_multiply(context, builder, sig, args):
[decimal_a, decimal_b] = args
f_a = context.make_helper(builder, DecimalModel, decimal_a).f
e_a = context.make_helper(builder, DecimalModel, decimal_a).e
f_b = context.make_helper(builder, DecimalModel, decimal_b).f
e_b = context.make_helper(builder, DecimalModel, decimal_b).e
f_result = builder.fmul(f_a, f_b)
e_result = builder.add(e_a, e_b)
result = context.make_helper(builder, DecimalModel, decimal_a)
result.f = f_result
result.e = e_result
return result._getvalue()
@overload(operator.mul)
def dec_mul(self,other):
if isinstance(other, Decimal):
def impl(self,other):
return(Decimal(self.f*other.f, self.e+other.e))
return impl
if isinstance(other, (int,float)):
def impl(self,other):
return(Decimal(self.f*other, self.e))
return impl
@unbox(DecimalType)
def unbox_interval(typ, obj, c):
#Convert a Decimal object to a native structure.
is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit)
deci = cgutils.create_struct_proxy(typ)(c.context, c.builder)
with ExitStack() as stack:
f_obj = c.pyapi.object_getattr_string(obj, "f")
with cgutils.early_exit_if_null(c.builder, stack, f_obj):
c.builder.store(cgutils.true_bit, is_error_ptr)
f_native = c.unbox(types.float64, f_obj)
c.pyapi.decref(f_obj)
with cgutils.early_exit_if(c.builder, stack, f_native.is_error):
c.builder.store(cgutils.true_bit, is_error_ptr)
e_obj = c.pyapi.object_getattr_string(obj, "e")
if e_obj is not None:
e_native = c.unbox(types.int64, e_obj)
else:
e_native = NativeValue(0.0, is_error=cgutils.false_bit)
c.pyapi.decref(e_obj)
with cgutils.early_exit_if(c.builder, stack, e_native.is_error):
c.builder.store(cgutils.true_bit, is_error_ptr)
deci.f = f_native.value
deci.e = e_native.value
return NativeValue(deci._getvalue(), is_error=c.builder.load(is_error_ptr))
@box(DecimalType)
def box_interval(typ, val, c):
#Convert a native structure to an Decimal object.
ret_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj)
fail_obj = c.pyapi.get_null_object()
with ExitStack() as stack:
deci = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
f_obj = c.box(types.float64, deci.f)
with cgutils.early_exit_if_null(c.builder, stack, f_obj):
c.builder.store(fail_obj, ret_ptr)
e_obj = c.box(types.int64, deci.e)
with cgutils.early_exit_if_null(c.builder, stack, e_obj):
c.pyapi.decref(f_obj)
c.builder.store(fail_obj, ret_ptr)
class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Decimal))
with cgutils.early_exit_if_null(c.builder, stack, class_obj):
c.pyapi.decref(f_obj)
c.pyapi.decref(e_obj)
c.builder.store(fail_obj, ret_ptr)
# NOTE: The result of this call is not checked as the clean up
# has to occur regardless of whether it is successful. If it
# fails `res` is set to NULL and a Python exception is set.
res = c.pyapi.call_function_objargs(class_obj, (f_obj, e_obj))
c.pyapi.decref(f_obj)
c.pyapi.decref(e_obj)
c.pyapi.decref(class_obj)
c.builder.store(res, ret_ptr)
return c.builder.load(ret_ptr)
d1 = Decimal(6.3242)
d2 = Decimal(5.3484)
ans=d1*d1 + 3*d2 + d1*3.52345 - d2
print(ans)
from numba import jit, njit
@jit(forceobj=True)
#@njit
def test(d1,d2):
ans = d1*d1 + 3*d2 + d1*3.52345 - d2
return ans
d1 = Decimal(6.3242)
d2 = Decimal(5.3484)
ans = test(d1,d2)
print(ans, type(ans))```
The code work fine for python compilation (@jit(forceobj=True)) but with @njit it cannot perform any operation. I am not sure why this is happening, Maybe I am overloading the operators wrong. Can someone please help with this?
The error I get with @njit is:
TypingError Traceback (most recent call last)
/tmp/ipykernel_100289/17093929.py in
252 d1 = Decimal(6.3242)
253 d2 = Decimal(5.3484)
→ 254 ans = test(d1,d2)
255 print(ans, type(ans))
~/.local/lib/python3.10/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
466 e.patch_message(msg)
467
→ 468 error_rewrite(e, ‘typing’)
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
~/.local/lib/python3.10/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
407 raise e
408 else:
→ 409 raise e.with_traceback(None)
410
411 argtypes =
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function() found for signature:
mul(Decimal, Decimal)
There are 24 candidate implementations:
- Of which 22 did not match due to:
Overload of function ‘mul’: File: : Line N/A.
With argument(s): ‘(Decimal, Decimal)’:
No match.
- Of which 2 did not match due to:
Operator Overload in function ‘mul’: File: unknown: Line unknown.
With argument(s): ‘(Decimal, Decimal)’:
No match for registered cases:
* (int64, int64) → int64
* (int64, uint64) → int64
* (uint64, int64) → int64
* (uint64, uint64) → uint64
* (float32, float32) → float32
* (float64, float64) → float64
* (complex64, complex64) → complex64
* (complex128, complex128) → complex128
During: typing of intrinsic-call at /tmp/ipykernel_100289/17093929.py (249)
File “…/…/…/…/…/…/tmp/ipykernel_100289/17093929.py”, line 249: