Catch overflow in integer multiplication

I am using numba and I would like to know if an overflow has occurred when I multiply two integers. Say the integers are positive for simplicity.

I have written the following function to try and achieve this:

from numba import njit
import numpy as np

@njit
def safe_mul(a, b):
    c = a * b

    print('a: ', a)
    print('b: ', b)
    print('c: ', c)
    
    print('c//a: ', c//a)
    print('0//a: ', 0//a)

    if c // a != b or c // b != a:
        # do something else or raise error
        raise ValueError()
    return c


@njit
def safe_mul_2(a, b):
    if (np.log2(np.abs(a)) + np.log2(np.abs(b))) >= 63:
        # do something else or raise error
        raise ValueError()
        
    return a * b

print(safe_mul(2**21, 2**51))

The code prints:

a:  2097152
b:  2251799813685248
c:  0
c//a:  2251799813685248
0//a:  0
0

safe_mul does not catch the overflow when 2**21 and 2**51 are passed in. Perhaps numba is compiling out the integer division, since it knows c has just been multiplied by the value it is being divided by? I am not sure about this, because when the arguments are not both powers of 2 the error is caught.

safe_mul_2 does catch the error, and is surprisingly not much slower than safe_mul (when the prints are removed). I would like to know what is happening in safe_mul and if something that is faster than safe_mul_2 can be written.
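
For reference, the c: 0 in the output above is consistent with plain 64-bit wraparound: the exact product is 2**72, whose low 64 bits are all zero. The following is a minimal plain-Python sketch of that arithmetic (no numba involved); wrap_int64 is a hypothetical helper name introduced only for this illustration.

```python
def wrap_int64(value):
    """Reduce an arbitrary Python int to its signed 64-bit two's-complement value."""
    value &= (1 << 64) - 1          # keep only the low 64 bits
    return value - (1 << 64) if value >= (1 << 63) else value

a, b = 2**21, 2**51
exact = a * b                       # Python ints do not overflow
wrapped = wrap_int64(exact)         # what a 64-bit register would hold

print(exact == 2**72)               # True
print(wrapped)                      # 0
print(wrapped // a)                 # 0, so an actual division would not give b back
```

Since 0 // a is 0 and not b, a division that was really carried out on the wrapped value would have triggered the check, which fits the printed c//a not being the result of dividing the wrapped c.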

Dear @patrick

Numba does not remove the operation. You can check this by inspecting the generated intermediate representation:

safe_mul.inspect_llvm()

However, an optimization is performed that results in instructions that do not show the overflow. You can get your example to work by forcing Numba to generate the intended intermediate representation:

import numba as nb
from numba.extending import intrinsic

@intrinsic
def compare(typingctx, a, b):
    def codegen(context, builder, signature, args):
        a, b = args
        prod = builder.mul(a, b)
        quot = builder.sdiv(prod, a)
        ret = builder.icmp_signed("!=", quot, b)
        return ret
    return nb.boolean(a, b), codegen

Call it as follows:

@njit
def safe_mul(a, b):
    if compare(a, b) or compare(b, a):
        raise ValueError()
    ...

It’s probably no longer faster to do it this way.
Regarding your second function, are you sure it’s correct? E.g. safe_mul_2(1, np.iinfo(np.int64).max) should not raise the error but does.
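
To make the false positive concrete, here is a small plain-Python sketch. It uses math.log2 as a stand-in for the np.log2 call in safe_mul_2, and compares it with an exact division-based pre-check that only assumes both inputs are positive, as in the original question. The names INT64_MAX, log2_check_rejects and would_overflow_positive are hypothetical and exist only for this illustration.

```python
import math

INT64_MAX = 2**63 - 1

def log2_check_rejects(a, b):
    """Mirror of the condition in safe_mul_2, evaluated in floating point."""
    return math.log2(abs(a)) + math.log2(abs(b)) >= 63

def would_overflow_positive(a, b):
    """Exact check for positive a and b: a * b > INT64_MAX iff a > INT64_MAX // b."""
    return a > INT64_MAX // b

# 1 * INT64_MAX fits in int64, but INT64_MAX rounds up to 2**63 as a float,
# so the logarithm sum reaches 63 and the log2-based check rejects it.
print(log2_check_rejects(1, INT64_MAX))        # True
print(would_overflow_positive(1, INT64_MAX))   # False
print(would_overflow_positive(2**21, 2**51))   # True
```

The division-based version is exact for positive inputs but pays for an integer division, so it says nothing about which variant is faster once compiled.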

I hope this clears things up a bit. If you’re not familiar with some of the things shown, let me know.

Edit:
I had another look at your problem and it turns out LLVM has special support for your needs. For example, it provides signed integer multiplication with overflow checking: llvm.smul.with.overflow.i64. In a function this would look like this:

from llvmlite import ir
from numba import njit, types
from numba.extending import intrinsic

@intrinsic
def mul_with_overflow_check(typingctx, a, b):
    def codegen(context, builder, signature, args):
        int_t = ir.IntType(64)
        bool_t = ir.IntType(1)
        result_t = ir.LiteralStructType([int_t, bool_t])
        fnty = ir.FunctionType(result_t, [int_t, int_t])
        fn = ir.Function(builder.module, fnty, name='llvm.smul.with.overflow.i64')
        
        prod_and_overflowed = builder.call(fn, args)
        prod = builder.extract_value(prod_and_overflowed, 0)
        overflowed = builder.extract_value(prod_and_overflowed, 1)
        with builder.if_then(overflowed):
            context.call_conv.return_user_exc(builder, OverflowError, ("Multiplication overflowed",))   
        return prod

    return types.int64(a, b), codegen

@njit
def foo(a, b):
    return mul_with_overflow_check(a, b)

Note that this applies to signed 64-bit integers. Other cases can be implemented equivalently. Also, it is probably the fastest you can get.

Thanks so much @sschaer, this is very helpful!

Something I am hoping to do is follow another branch if an overflow occurs, rather than simply raising an error as in your example. For example, mul_with_overflow_check could return (prod, overflowed) so that I could easily condition on the overflowed flag downstream. I have tried to modify your code to achieve this, but to no avail.

Can your code be modified to achieve this?

With respect to your comment about safe_mul_2: indeed, I did not write it very carefully. It was only meant as an example of something that raises an error every time an overflow occurs (so it is safe in some sense), not necessarily something that raises an error only when an overflow occurs. I should perhaps have mentioned this in the original post.

Hi @patrick

Of course, this is possible:

from llvmlite import ir
from numba import njit, types
from numba.extending import intrinsic

@intrinsic
def mul_with_overflow_check(typingctx, a, b):
    def codegen(context, builder, signature, args):
        int_t = ir.IntType(64)
        bool_t = ir.IntType(1)
        result_t = ir.LiteralStructType([int_t, bool_t])
        fnty = ir.FunctionType(result_t, [int_t, int_t])
        fn = ir.Function(builder.module, fnty, name='llvm.smul.with.overflow.i64')
        
        prod_and_overflowed = builder.call(fn, args)
        return prod_and_overflowed

    return types.Tuple([types.int64, types.boolean])(a, b), codegen

@njit
def foo(a, b):
    prod, flag = mul_with_overflow_check(a, b)
    return prod, flag

This returns a tuple of the product of a and b and a flag that is True if the multiplication results in an overflow or False otherwise.
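
For anyone who wants to sanity-check the (prod, flag) pair without compiling anything, the following is a rough plain-Python model of what the tuple should contain. It assumes the product element is the wrapped two's-complement result and the flag is set exactly when the true product falls outside the signed range. The names smul_with_overflow_model and mul_or_float are hypothetical, and the float fallback is just one possible "other branch".

```python
def smul_with_overflow_model(a, b, bits=64):
    """Pure-Python model of a signed multiply that reports overflow."""
    exact = a * b                                   # Python ints are unbounded
    lo, hi = -(1 << (bits - 1)), (1 << (bits - 1)) - 1
    overflowed = not (lo <= exact <= hi)
    wrapped = (exact - lo) % (1 << bits) + lo       # two's-complement wraparound
    return wrapped, overflowed

def mul_or_float(a, b):
    """Branch on the flag instead of raising: fall back to float on overflow."""
    prod, overflowed = smul_with_overflow_model(a, b)
    if overflowed:
        return float(a) * float(b)
    return prod

print(smul_with_overflow_model(3, 4))               # (12, False)
print(smul_with_overflow_model(2**21, 2**51))       # (0, True)
print(mul_or_float(2**21, 2**51) == float(2**72))   # True
```

The same branching pattern should carry over to an njit function that unpacks the tuple returned by the intrinsic, as long as both branches return the same type.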

Thanks so much @sschaer!

A couple of final things:

  1. Where did you find the function name "llvm.smul.with.overflow.i64" and others like it? (I can infer the names which act on other types).
  2. I asked the same question on Stack Overflow (here). If you answer it there, I can accept it, and I think others are more likely to find it.

No worries if you don’t get back to this, it’s already been very helpful and illuminating.

Hi @patrick

  1. You can find all LLVM intrinsics here. I should have checked the llvmlite documentation more carefully first, though, because it already supports llvm.smul.with.overflow.i64, which simplifies the implementation.
  2. I am not active on Stack Overflow, but feel free to copy the solution and link to this thread. You can post the implementation below. It works with any integer types and directly uses the method from llvmlite.

import numpy as np
from llvmlite import ir
from numba import njit, types
from numba import TypingError
from numba.extending import intrinsic

@intrinsic
def mul_with_overflow(typingctx, a, b):
    if not (isinstance(a, types.Integer) and isinstance(b, types.Integer)):
        raise TypingError("both arguments must be integers")
    if a.signed != b.signed:
        raise TypingError("can only multiply integers of equal signedness")
    
    if a.signed:
        ext = lambda builder, a, b: builder.sext(a, b)
        mul = lambda builder, a, b: builder.smul_with_overflow(a, b)
    else:
        ext = lambda builder, a, b: builder.zext(a, b)
        mul = lambda builder, a, b: builder.umul_with_overflow(a, b)

    retint_ty = max(a, b, key=lambda ty: ty.bitwidth)
    sig = types.Tuple([retint_ty, types.boolean])(a, b)
    
    def codegen(context, builder, signature, args):
        int_ty = context.get_value_type(retint_ty)
        a = ext(builder, args[0], int_ty)
        b = ext(builder, args[1], int_ty)
        prod_and_flag = mul(builder, a, b)
        return prod_and_flag
    
    return sig, codegen

@njit
def foo(a, b):
    return mul_with_overflow(a, b)

foo(np.int8(1), np.int32(-2)) # returns int32
foo(np.uint8(1), np.uint16(2)) # returns uint16
foo(np.int32(2), np.int64(np.iinfo(np.int64).max)) # overflow
foo(np.uint64(1), np.int64(2)) # error
foo(np.uint64(1), np.float32(2)) # error