IntEnum as optional argument

Hey Numba team,

I’m encountering an issue when using an IntEnum as an optional argument in a jitted function. Whenever I pass an IntEnum member as the optional argument, I get `TypeError: '>' not supported between instances of 'Conversion' and 'NoneType'`. Explicitly casting the member to an integer doesn’t seem to be supported.

# Minimal reproducer: passing an IntEnum member as the optional argument of a
# function compiled with explicit (eager) signatures fails on the third call.
from enum import IntEnum
import numba as nb

class Number(IntEnum):
    ONE = 1
    TWO = 2
    THREE = 3

# Two eager signatures: b given as an int64, or b omitted (default None).
# None of the listed signatures names the IntEnum type itself.
@nb.njit(['i8(i8, i8)', 'i8(i8, Omitted(None))'])
def add(a, b=None):
    # NOTE(review): `or` treats any falsy b (e.g. 0) as missing, so add(1, 0)
    # would give 2 rather than 1; an explicit `b is None` test avoids that.
    b = b or Number.ONE
    return a+b

print('add(1, 1) => ', add(1, 1))
print('add(1) => ', add(1))
print('add(1, Number.ONE) => ', add(1, Number.ONE))
# add(1, 1) =>  2
# add(1) =>  2
# TypeError: '>' not supported between instances of 'Conversion' and 'NoneType'

Surprisingly, if a jitted function that takes the IntEnum as an ordinary argument has already been called, passing the IntEnum member as the optional argument works as well.

# Same reproducer as before, but first calls an unrelated jitted function with
# the enum as a plain int64 argument; afterwards the optional-argument call
# add(1, Number.ONE) succeeds instead of raising.
from enum import IntEnum
import numba as nb

class Number(IntEnum):
    ONE = 1
    TWO = 2
    THREE = 3

# Helper compiled only for (int64, int64). It is called below solely for the
# observed effect on the later add(1, Number.ONE) call.
@nb.njit(['i8(i8, i8)'])
def add_dummy(a, b):
    return a+b

# The order of these calls matters for the demonstration: the second one passes
# a Number member to a plain i8 parameter.
# NOTE(review): presumably this primes Numba's dispatcher with an
# IntEnum -> int64 conversion; the mechanism is not visible here.
add_dummy(1, 1)
add_dummy(1, Number.ONE)

@nb.njit(['i8(i8, i8)', 'i8(i8, Omitted(None))'])
def add(a, b=None):
    # NOTE(review): `or` also replaces a falsy b such as 0 with Number.ONE.
    b = b or Number.ONE
    return a+b

print('add(1, 1) => ', add(1, 1))
print('add(1) => ', add(1))
print('add(1, Number.ONE) => ', add(1, Number.ONE))
# add(1, 1) =>  2
# add(1) =>  2
# add(1, Number.ONE) =>  2

Is there a way to provide sufficient type hints and prevent the TypeError?

Here is a version with type inference:

# Lazily compiled version: no signatures are given, so Numba compiles a
# specialization for each distinct argument-type combination it is called with.
from enum import IntEnum
import numba as nb

class Number(IntEnum):
    ONE = 1
    TWO = 2
    THREE = 3

@nb.njit
def add(a, b=None):
    # Explicit None test (rather than `b or ...`) so a passed-in 0 is kept.
    if b is None:
        bint = Number.ONE
    else:
        bint = b
    return a + bint

print('add(1, 1) => ', add(1, 1))
print('add(1) => ', add(1))
print('add(1, Number.ONE) => ', add(1, Number.ONE))
# Lists the specializations compiled by the three calls above.
print(add.signatures)
# add(1, 1) =>  2
# add(1) =>  2
# add(1, Number.ONE) =>  2
# [(int64, int64), (int64, omitted(default=None)), (int64, IntEnum<int64>(Number))]
# The third entry is a separate specialization for the IntEnum argument type.

Here is a version with explicit signatures:

# Eagerly compiled version: the IntEnum member type is listed explicitly next
# to the int64 and omitted-argument signatures.
from enum import IntEnum
import numba as nb
import numba.types as nbt

class Number(IntEnum):
    ONE = 1
    TWO = 2
    THREE = 3

# Look up the Numba type of a Number member so it can be used in an explicit
# signature. nb.typeof only computes the type; it does not register anything.
NumberType = nb.typeof(Number.ONE)

# String specs and type objects can be mixed in the signature list; the third
# signature covers calls that pass a Number member directly.
@nb.njit(['i8(i8, i8)', 'i8(i8, Omitted(None))', nbt.i8(nbt.i8, NumberType)])
def add(a, b=None):
    # Explicit None test (rather than `b or ...`) so a passed-in 0 is kept.
    if b is None:
        bint = Number.ONE
    else:
        bint = b
    return a + bint

print('add(1, 1) => ', add(1, 1))
print('add(1) => ', add(1))
print('add(1, Number.ONE) => ', add(1, Number.ONE))

Unfortunately, both of these implementations compile an additional signature for the IntEnum argument type (`(int64, IntEnum<int64>(Number))`), which might be avoidable.

Registering a custom `typeof` implementation that maps `Number` to `int64` seems to work as intended:

# Workaround: have Numba type Number members as plain int64 values, so passing
# one reuses the existing (int64, int64) specialization instead of compiling an
# extra IntEnum one.
from enum import IntEnum
import numba as nb

class Number(IntEnum):
    ONE = 1
    TWO = 2
    THREE = 3

# Register a Numba typeof implementation for the Number class.
# Note that this registration is global: it changes how Number values are typed
# for every jitted function, not just `add`.
@nb.extending.typeof_impl.register(Number)
def typeof_number(obj, c):
    # obj: the Number value being typed; c: Numba's typeof context (unused).
    return nb.types.int64

@nb.njit
def add(a, b=None):
    # Use an explicit None test rather than `b = b or Number.ONE`.
    # `or` also replaces a legitimate b == 0 with Number.ONE, so add(1, 0)
    # returned 2 instead of 1.
    # A separate name (bint) keeps the omitted/None branch and the int branch
    # apart for Numba's type inference.
    if b is None:
        bint = Number.ONE
    else:
        bint = b
    return a + bint

print('add(1, 1) => ', add(1, 1))
print('add(1) => ', add(1))
print('add(1, Number.ONE) => ', add(1, Number.ONE))
print(add.signatures)
# add(1, 1) =>  2
# add(1) =>  2
# add(1, Number.ONE) =>  2
# [(int64, int64), (int64, omitted(default=None))]