Defining custom data type for @njit

from numba import types
from numba.extending import typeof_impl
from numba.extending import as_numba_type
from numba.extending import type_callable
from numba.extending import models, register_model
from numba.extending import make_attribute_wrapper
from numba.extending import overload
from numba.extending import lower_builtin
from numba.core import cgutils
from numba.core.extending import overload_method
from numba.extending import unbox, NativeValue, box
from contextlib import ExitStack
from numba import njit
import operator

class Decimal(object):
    """
    A half-open interval on the real number line.
    """
    def __init__(self, f, e=0):
        if f == 0:
            self.f = 0
            self.e = 0
        elif e == 0:
            e_ = int(math.log10(abs(f))//1)
            f_ = f / 10 ** e_
            self.f = f_
            self.e = e_
        else:
            e_ = int(math.log10(abs(f))//1)
            f_ = f / 10 ** e_
            self.f = f_
            self.e = e + e_   
       
    def __add__(self,other):        
        if isinstance(other, Decimal):
            f1 = self.f
            e1 = self.e
            f2 = other.f
            e2 = other.e
            if e1 == e2:
                summation = Decimal((f1 + f2), e1)
            elif e1 > e2:
                d = e1 - e2
                f2 = f2 / 10 ** d
                summation = Decimal((f1 + f2), e1)
            else:
                d = e2 - e1
                f1 = f1 / 10 ** d
                summation = Decimal((f1 + f2), e2)
            if summation.f == 0:
                return (0)
            else:
                return summation
        elif isinstance(other, float) or isinstance(other, int):
            other_ = Decimal(other)
            return self + other_
        else:
            raise TypeError("Unsupported operand type")

    def __radd__(self,other):
        return self+other
    
    def __sub__(self, other):
        if isinstance(other, Decimal):
            other_ = Decimal(-other.f, other.e)
            return self + other_
        elif isinstance(other, float) or isinstance(other, int):
            other_ = Decimal(other)
            return self - other_
        else:
            raise TypeError("Unsupported operand type")
    
    def __mul__(self, other):
        if isinstance(other, Decimal):
            return(Decimal(self.f*other.f, self.e+other.e))
        if isinstance(other, (int,float)):
            return(Decimal(self.f*other, self.e))
    
    def __rmul__(self, other):
        return self * other    
   
    def __truediv__(self, other):
        if other == 0:
            raise TypeError("Divide by 0 occurred")
        elif isinstance(other, Decimal):
            other_ = Decimal(1 / other.f, -other.e)
            return self * other_
        elif isinstance(other, float):
            return Decimal(self.f / other, self.e)
        elif isinstance(other, int):
            return Decimal(self.f / other, self.e)
        else:
            raise TypeError("Unsupported operand type")
    
    def __rtruediv__(self, other):
        f_ = 1 / self.f
        e_ = -self.e
        self_ = Decimal(f_, e_)
        return self_ * other

    def ln(self):
        if self.f <= 0:
            raise TypeError("Log is undefined for a negative value")
        else:
            log_f = math.log(self.f)
            log_e = self.e * np.log(10)
            return log_f + log_e
    
    def __repr__(self):
        return '%0.10fe%d' % (self.f, self.e)    

class DecimalType(types.Type):
    def __init__(self):
        super(DecimalType, self).__init__(name='Decimal')

decimal_type = DecimalType()

@typeof_impl.register(Decimal)
def typeof_index(val, c):
    return decimal_type

as_numba_type.register(Decimal, decimal_type)

@type_callable(Decimal)
def type_interval(context):
    def typer(f, e):
        if isinstance(f, types.Float) and isinstance(e, types.Integer):
            return decimal_type
    return typer

@register_model(DecimalType)
class DecimalModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [('f', types.float64),
                   ('e', types.int64),]
        models.StructModel.__init__(self, dmm, fe_type, members)
        
make_attribute_wrapper(DecimalType, 'f', 'f')
make_attribute_wrapper(DecimalType, 'e', 'e')

@lower_builtin(Decimal, types.Float, types.Integer)
def impl_interval(context, builder, sig, args):
    typ = sig.return_type
    f, e = args
    deci = cgutils.create_struct_proxy(typ)(context, builder)
    deci.f = f
    deci.e = e
    return deci._getvalue() 

@lower_builtin('*')
def decimal_multiply(context, builder, sig, args):
    [decimal_a, decimal_b] = args
    f_a = context.make_helper(builder, DecimalModel, decimal_a).f
    e_a = context.make_helper(builder, DecimalModel, decimal_a).e
    f_b = context.make_helper(builder, DecimalModel, decimal_b).f
    e_b = context.make_helper(builder, DecimalModel, decimal_b).e
    f_result = builder.fmul(f_a, f_b)
    e_result = builder.add(e_a, e_b)
    result = context.make_helper(builder, DecimalModel, decimal_a)
    result.f = f_result
    result.e = e_result
    return result._getvalue()

@overload(operator.mul)
def dec_mul(self,other):    
    if isinstance(other, Decimal):
        def impl(self,other):
            return(Decimal(self.f*other.f, self.e+other.e))
        return impl
    if isinstance(other, (int,float)):
        def impl(self,other):
            return(Decimal(self.f*other, self.e))
        return impl

@unbox(DecimalType)
def unbox_interval(typ, obj, c):    
    #Convert a Decimal object to a native structure.    
    is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit)
    deci = cgutils.create_struct_proxy(typ)(c.context, c.builder)

    with ExitStack() as stack:
        f_obj = c.pyapi.object_getattr_string(obj, "f")
        with cgutils.early_exit_if_null(c.builder, stack, f_obj):
            c.builder.store(cgutils.true_bit, is_error_ptr)
        f_native = c.unbox(types.float64, f_obj)
        c.pyapi.decref(f_obj)
        with cgutils.early_exit_if(c.builder, stack, f_native.is_error):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        e_obj = c.pyapi.object_getattr_string(obj, "e")        
        if e_obj is not None:
            e_native = c.unbox(types.int64, e_obj)
        else:
            e_native = NativeValue(0.0, is_error=cgutils.false_bit)
        c.pyapi.decref(e_obj)
        with cgutils.early_exit_if(c.builder, stack, e_native.is_error):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        deci.f = f_native.value
        deci.e = e_native.value

    return NativeValue(deci._getvalue(), is_error=c.builder.load(is_error_ptr))

@box(DecimalType)
def box_interval(typ, val, c):    
    #Convert a native structure to an Decimal object.    
    ret_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj)
    fail_obj = c.pyapi.get_null_object()
    
    with ExitStack() as stack:
        deci = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
        f_obj = c.box(types.float64, deci.f)
        with cgutils.early_exit_if_null(c.builder, stack, f_obj):
            c.builder.store(fail_obj, ret_ptr)

        e_obj = c.box(types.int64, deci.e)
        with cgutils.early_exit_if_null(c.builder, stack, e_obj):
            c.pyapi.decref(f_obj)
            c.builder.store(fail_obj, ret_ptr)

        class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Decimal))
        with cgutils.early_exit_if_null(c.builder, stack, class_obj):
            c.pyapi.decref(f_obj)
            c.pyapi.decref(e_obj)
            c.builder.store(fail_obj, ret_ptr)

        # NOTE: The result of this call is not checked as the clean up
        # has to occur regardless of whether it is successful. If it
        # fails `res` is set to NULL and a Python exception is set.
        res = c.pyapi.call_function_objargs(class_obj, (f_obj, e_obj))
        c.pyapi.decref(f_obj)
        c.pyapi.decref(e_obj)
        c.pyapi.decref(class_obj)
        c.builder.store(res, ret_ptr)

    return c.builder.load(ret_ptr)

d1 = Decimal(6.3242)
d2 = Decimal(5.3484) 

ans=d1*d1 + 3*d2 + d1*3.52345 - d2
print(ans)

from numba import jit, njit
@jit(forceobj=True)
#@njit
def test(d1,d2):
    ans = d1*d1 + 3*d2 + d1*3.52345 - d2
    return ans
        
d1 = Decimal(6.3242)
d2 = Decimal(5.3484)           
ans = test(d1,d2)
print(ans, type(ans))```

The code work fine for python compilation (@jit(forceobj=True)) but with @njit it cannot perform any operation. I am not sure why this is happening, Maybe I am overloading the operators wrong. Can someone please help with this?

The error I get with @njit is: 

TypingError Traceback (most recent call last)
/tmp/ipykernel_100289/17093929.py in
252 d1 = Decimal(6.3242)
253 d2 = Decimal(5.3484)
→ 254 ans = test(d1,d2)
255 print(ans, type(ans))

~/.local/lib/python3.10/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws)
466 e.patch_message(msg)
467
→ 468 error_rewrite(e, ‘typing’)
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info

~/.local/lib/python3.10/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type)
407 raise e
408 else:
→ 409 raise e.with_traceback(None)
410
411 argtypes =

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function() found for signature:

mul(Decimal, Decimal)

There are 24 candidate implementations:
- Of which 22 did not match due to:
Overload of function ‘mul’: File: : Line N/A.
With argument(s): ‘(Decimal, Decimal)’:
No match.
- Of which 2 did not match due to:
Operator Overload in function ‘mul’: File: unknown: Line unknown.
With argument(s): ‘(Decimal, Decimal)’:
No match for registered cases:
* (int64, int64) → int64
* (int64, uint64) → int64
* (uint64, int64) → int64
* (uint64, uint64) → uint64
* (float32, float32) → float32
* (float64, float64) → float64
* (complex64, complex64) → complex64
* (complex128, complex128) → complex128

During: typing of intrinsic-call at /tmp/ipykernel_100289/17093929.py (249)

File “…/…/…/…/…/…/tmp/ipykernel_100289/17093929.py”, line 249:

```

Your example is quite involved so I don’t have time to fully fix it up now, but I think the problem lies in your dec_mul() function not handling all cases correctly - I have:

@overload(operator.mul)
def dec_mul(self, other):
    if isinstance(self, DecimalType):
        if isinstance(other, DecimalType):
            def impl(self, other):
                return (Decimal(self.f*other.f, self.e+other.e))
            return impl
        elif isinstance(other, (types.Integer, types.Float)):
            def impl(self, other):
                return (Decimal(self.f*other, self.e))
            return impl
    elif isinstance(other, DecimalType) and isinstance(self, (types.Integer,
                                                              types.Float)):
        def impl(self, other):
            return (Decimal(other.f*self, other.e))
        return impl

to handle the Decimal * Decimal, Decimal * scalar and scalar * Decimal cases, and it seems to get past that - similar would need doing for operator.add and operator.sub to complete the implementation of arithmetic used in your function.

Note also your code as listed above doesn’t run - there are a couple of missing imports, so I took a guess at the fixups.

There might be other issues too, but I haven’t looked into whether there might be.

Thanks for the reply, but I think operator overloading is quite a headache for any custom data type at this moment if you want to use this with @cuda.jit or other functionalities. I have found another way to deal with it, leaving it here if someone needs it.

from numba import types
from numba.extending import typeof_impl
from numba.extending import as_numba_type
from numba.extending import type_callable
from numba.extending import models, register_model
from numba.extending import make_attribute_wrapper
from numba.extending import overload
from numba.extending import lower_builtin
from numba.core import cgutils
from numba.core.extending import overload_method
from numba.extending import unbox, NativeValue, box
from contextlib import ExitStack
from numba import njit
import math

class Decimal(object):    
    def __init__(self, f, e=None):
        if e == None:
            e = 0
        if abs(f) < 1e-10:
            self.f = 0
            self.e = 0
        elif e == 0:
            e_ = int(math.log10(abs(f))//1)
            f_ = f / 10 ** e_
            self.f = f_
            self.e = e_
        else:
            e_ = int(math.log10(abs(f))//1)
            f_ = f / 10 ** e_
            self.f = f_
            self.e = e + e_      
            
    def __repr__(self):
        return '%0.5fe%d' % (self.f, self.e)    

class DecimalType(types.Type):
    def __init__(self):
        super(DecimalType, self).__init__(name='Decimal')

@typeof_impl.register(Decimal)
def typeof_index(val, c):
    return DecimalType()

as_numba_type.register(Decimal, DecimalType())

@type_callable(Decimal)
def type_interval(context):
    def typer(f, e):
        if isinstance(f, types.Float) and isinstance(e, types.Integer):
            return DecimalType()
    return typer

@register_model(DecimalType)
class DecimalModel(models.StructModel):
    def __init__(self, dmm, fe_type):
        members = [('f', types.float64),
                   ('e', types.int64),]
        models.StructModel.__init__(self, dmm, fe_type, members)
        
make_attribute_wrapper(DecimalType, 'f', 'f')
make_attribute_wrapper(DecimalType, 'e', 'e')

@lower_builtin(Decimal, types.Float, types.Integer)
def impl_decimal(context, builder, sig, args):
    typ = sig.return_type
    f, e = args
    deci = cgutils.create_struct_proxy(typ)(context, builder)
    deci.f = f
    deci.e = e
    return deci._getvalue() 
   
@unbox(DecimalType)
def unbox_interval(typ, obj, c):    
    #Convert a Decimal object to a native structure.    
    is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit)
    deci = cgutils.create_struct_proxy(typ)(c.context, c.builder)

    with ExitStack() as stack:
        f_obj = c.pyapi.object_getattr_string(obj, "f")
        with cgutils.early_exit_if_null(c.builder, stack, f_obj):
            c.builder.store(cgutils.true_bit, is_error_ptr)
        f_native = c.unbox(types.float64, f_obj)
        c.pyapi.decref(f_obj)
        with cgutils.early_exit_if(c.builder, stack, f_native.is_error):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        e_obj = c.pyapi.object_getattr_string(obj, "e")        
        if e_obj is not None:
            e_native = c.unbox(types.int64, e_obj)
        else:
            e_native = NativeValue(0.0, is_error=cgutils.false_bit)
        c.pyapi.decref(e_obj)
        with cgutils.early_exit_if(c.builder, stack, e_native.is_error):
            c.builder.store(cgutils.true_bit, is_error_ptr)

        deci.f = f_native.value
        deci.e = e_native.value

    return NativeValue(deci._getvalue(), is_error=c.builder.load(is_error_ptr))

@box(DecimalType)
def box_interval(typ, val, c):    
    #Convert a native structure to an Decimal object.    
    ret_ptr = cgutils.alloca_once(c.builder, c.pyapi.pyobj)
    fail_obj = c.pyapi.get_null_object()
    
    with ExitStack() as stack:
        deci = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val)
        f_obj = c.box(types.float64, deci.f)
        with cgutils.early_exit_if_null(c.builder, stack, f_obj):
            c.builder.store(fail_obj, ret_ptr)

        e_obj = c.box(types.int64, deci.e)
        with cgutils.early_exit_if_null(c.builder, stack, e_obj):
            c.pyapi.decref(f_obj)
            c.builder.store(fail_obj, ret_ptr)

        class_obj = c.pyapi.unserialize(c.pyapi.serialize_object(Decimal))
        with cgutils.early_exit_if_null(c.builder, stack, class_obj):
            c.pyapi.decref(f_obj)
            c.pyapi.decref(e_obj)
            c.builder.store(fail_obj, ret_ptr)

        # NOTE: The result of this call is not checked as the clean up
        # has to occur regardless of whether it is successful. If it
        # fails `res` is set to NULL and a Python exception is set.
        res = c.pyapi.call_function_objargs(class_obj, (f_obj, e_obj))
        c.pyapi.decref(f_obj)
        c.pyapi.decref(e_obj)
        c.pyapi.decref(class_obj)
        c.builder.store(res, ret_ptr)

    return c.builder.load(ret_ptr)

@njit
def add(self:Decimal, other:Decimal) -> Decimal:
    f1 = self.f
    e1 = self.e
    f2 = other.f
    e2 = other.e
    d = e1-e2
    if d == 0:
        f1 = f1
        summation = Decimal((f1 + f2), e1)
    elif d > 0 and d < 50:        
        f2 = f2/10**d
        summation = Decimal((f1 + f2), e1)
    elif d>=50:
        summation = Decimal(f1,e1)
    elif d < 0 and d > -50:        
        f2 = f2
        f1 = f1/10**(-d)
        summation = Decimal((f1 + f2), e2)
    elif d <=-50:
        summation = Decimal(f2,e2)    
    #print(f1,e1,f2,e2,d)
    return summation

@njit
def sub(self:Decimal, other:Decimal) -> Decimal:
    other_ = Decimal(-other.f, other.e)
    return add(self, other_)
        
@njit
def mul(self:Decimal, other:Decimal) -> Decimal:
    return(Decimal(self.f*other.f, self.e+other.e))   

@njit
def div(self:Decimal, other:Decimal) -> Decimal:
    if other.f == 0:
        raise TypeError("Divide by 0 occurred")
    else:
        other_ = Decimal(1 / other.f, -other.e)
        return mul(self, other_)

@njit
def ln(self:Decimal) -> Decimal:
    if self.f <= 0:
        raise TypeError("Log is undefined for a negative value")
    else:
        log_f = math.log(self.f)
        log_e = self.e * math.log(10)
        log = log_f + log_e
        return Decimal(log,0)

#@njit
#def float(self:Decimal) -> float:
#    return self.f*10**self.e

@njit
def int_frac(f):
    integer = int(f//1)
    fraction = f - f//1
    return integer , fraction

@njit
def len_num(N):
    count = 0
    while N!=0:
        N//=10
        count+=1
    return count

@njit
def _div_nearest(a, b):
    """Closest integer to a/b, a and b positive integers; rounds to even
    in the case of a tie.
    """
    q = a//b 
    r = a%b
    if (q % 2) == 1:
        sum = q + (2*r + 1 > b) 
    else:
        sum = q + (2*r > b)
    return sum

@njit
def bit_length_numba(x):   
    if x == 0:
        return 0    
    length = 0
    while x != 0:
        length += 1
        x >>= 1  # right shift by 1 bit
    return length

@njit
def _iexp(x, M, L=8):
    """Given integers x and M, M > 0, such that x/M is small in absolute
    value, compute an integer approximation to M*exp(x/M).  For 0 <=
    x/M <= 2.4, the absolute error in the result is bounded by 60 (and
    is usually much smaller)."""

    # Algorithm: to compute exp(z) for a real number z, first divide z
    # by a suitable power R of 2 so that |z/2**R| < 2**-L.  Then
    # compute expm1(z/2**R) = exp(z/2**R) - 1 using the usual Taylor
    # series
    #
    #     expm1(x) = x + x**2/2! + x**3/3! + ...
    #
    # Now use the identity
    #
    #     expm1(2x) = expm1(x)*(expm1(x)+2)
    #
    # R times to compute the sequence expm1(z/2**R),
    # expm1(z/2**(R-1)), ... , exp(z/2), exp(z).
    x = int(x)
    M = int(M)
    # Find R such that x/2**R/M <= 2**-L
    R = bit_length_numba((x<<L)//M)

    # Taylor series.  (2**L)**T > M
    T = -int(-10*len_num(M)//(3*L))
    y = _div_nearest(x, T)
    Mshift = M<<R
    for i in range(T-1, 0, -1):
        y = _div_nearest(x*(Mshift + y), Mshift * i)

    # Expansion
    for k in range(R-1, -1, -1):
        Mshift = M<<(k+2)
        y = _div_nearest(y*(y+Mshift), Mshift)    
    return M+y

@njit
def _dexp(c, e):
    """Compute an approximation to exp(c*10**e), with p decimal places of
    precision.

    Returns integers d, f such that:

      10**(p-1) <= d <= 10**p, and
      (d-1)*10**f < exp(c*10**e) < (d+1)*10**f

    In other words, d*10**f is an approximation to exp(c*10**e) with p
    digits of precision, and with an error in d of at most 1.  This is
    almost, but not quite, the same as the error being < 1ulp: when d
    = 10**(p-1) the error could be up to 10 ulp."""   
    p = 6
    # we'll call iexp with M = 10**(p+2), giving p+3 digits of precision
    p += 2

    # compute log(10) with extra precision = adjusted exponent of c*10**e
    extra = max(0, e + 6)
    q = p + extra

    # compute quotient c*10**e/(log(10)) = c*10**(e+q)/(log(10)*10**q),
    # rounding down
    shift = e+q
    if shift >= 0:
        cshift = c*10**shift
    else:
        cshift = c // 10**-shift
    quot = cshift // math.floor((10**q)*math.log(10))
    rem = cshift % math.floor((10**q)*math.log(10))

    # reduce remainder back to original precision
    rem = _div_nearest(rem, 10**extra)

    # error in result of _iexp < 120;  error after division < 0.62
    return _div_nearest(_iexp(rem, 10**p,8), 1000), quot - p + 3

@njit
def exp(self:Decimal)->Decimal:    
    f = int(self.f * 10000)
    e = self.e - 4   
    coeff, exp = _dexp(f, e)
    exp = int(exp)
    ans = Decimal(coeff/1e5,exp+5)
    return ans 

@njit
def f2D(f,e=0):
    f=float(f)
    ans = Decimal(f,e)    
    return ans

@njit
def D2f(D):
    ans = D.f*10**D.e    
    return ans