Hi everyone,
I need to use the NumPy function `atleast_1d` inside a Numba-jitted function. I ran into problems when the argument passed to `np.atleast_1d` inside the function is a scalar, as pointed out here: https://github.com/numba/numba/issues/4202.
I tried to use the `@overload` decorator, as in this MWE:
```python
# need the following lines to avoid NumbaPendingDeprecationWarning
# (set before importing numba, since numba reads its config at import time)
import os
os.environ["NUMBA_CAPTURED_ERRORS"] = "new_style"

import numba as nb
import numpy as np

# Overload np.atleast_1d
from numba import types
from numba.extending import overload

@overload(np.atleast_1d)
def ol_atleast_1d(x):
    if x in types.number_domain:
        return lambda x: np.array([x])
    return lambda x: np.atleast_1d(x)

# Write function that uses np.atleast_1d:
@nb.njit
def func(arg):
    _arg = np.atleast_1d(arg)
    return np.sum(_arg)
```
The function works fine when `arg` is a scalar, but fails when the input is an array:

```python
a = 1.0
print(func(a))  # 1.0: ok
b = np.array([1.0, 1.0, 1.0])
func(b)  # fails
```
The error is quite long, but it ends with:

```
RecursionError: maximum recursion depth exceeded in comparison
```
Any idea where the problem could be?