Overloading problem

I was having a weird problem that I was able to simplify to this example:

import numba as nb
import numpy as np
from numpy import testing

@nb.njit
def in_range(arr, mn, mx):
    if mn is None and mx is None:
        return np.full(arr.shape, True)

    if mx is None:
        return mn <= arr

    if mn is None:
        return arr < mx

    return np.logical_and(mn <= arr, arr < mx)

print(nb.version_info)
arr = np.asarray((1, 4, 5, 8))

# 1
testing.assert_array_equal(in_range(arr, None, None),
                           np.asarray([True, True, True, True]))

# 2
testing.assert_array_equal(in_range(arr, 4, None),
                           np.asarray([False, True, True, True]))

# 3
testing.assert_array_equal(in_range(arr, None, 4),
                           np.asarray([True, False, False, False]))

# 4
testing.assert_array_equal(in_range(arr, 4, 7),
                           np.asarray([False, True, True, False]))

and I am getting:

No implementation of function Function(<built-in function le>) found for signature:
 
 >>> le(none, array(int64, 1d, C))

This example works fine if I comment out @nb.njit. It seems that case #3 is trying to reuse an implementation compiled for the wrong types. I am using Numba 0.53.

I think you may be hitting a limitation of the branch pruner. I suspect the pruner doesn’t remove the if mx is None branch, which would be required for this to compile successfully: Numba type-checks every branch that survives pruning, and mn <= arr has no implementation when mn is None. You may need separate implementations of the function (e.g. using @overload or @generated_jit) to keep Numba from attempting to compile code in “invalid” branches.

@stuartarchibald or @sklam may have some more insight here.

This is a bit of a roundabout way of getting what you want, but it does work. In this case I used overload to create an implementation of in_range that dispatches on type and has four different implementations, one for each combination of argument types. This can probably be cleaned up and shortened a bit, but I will post it anyway as a step in the right direction.

import numba as nb
import numpy as np
from numpy import testing
from numba.extending import overload


def njit_in_range(arr, mn, mx):
    # Stub only: jitted callers use the @overload implementations below.
    pass


@overload(njit_in_range)
def ol_in_range(arr, mn, mx):
    if isinstance(mn, nb.types.NoneType) and isinstance(mx, nb.types.NoneType):
        def imp(arr, mn, mx):
            return np.full(arr.shape, True)

    elif isinstance(mx, nb.types.NoneType):
        def imp(arr, mn, mx):
            return mn <= arr

    elif isinstance(mn, nb.types.NoneType):
        def imp(arr, mn, mx):
            return arr < mx

    else:
        def imp(arr, mn, mx):
            return np.logical_and(mn <= arr, arr < mx)

    return imp


@nb.njit
def in_range(arr, mn, mx):
    return njit_in_range(arr, mn, mx)


print(nb.version_info)
arr = np.asarray((1, 4, 5, 8))

testing.assert_array_equal(in_range(arr, None, None),
                           np.asarray([True, True, True, True]))

testing.assert_array_equal(in_range(arr, 4, None),
                           np.asarray([False, True, True, True]))

testing.assert_array_equal(in_range(arr, None, 4),
                           np.asarray([True, False, False, False]))

testing.assert_array_equal(in_range(arr, 4, 7),
                           np.asarray([False, True, True, False]))

Hope it helps!

Note: I followed the instructions in "A guide to using @overload" in the Numba documentation (0.52.0.dev0).

Thanks for the suggestion, this solves the problem!

Great, happy it worked out!