RecursionError when overloading np.atleast_1d

Hi everyone,

I need to use the NumPy function atleast_1d inside a Numba-jitted function.
However, it fails when the argument passed to np.atleast_1d inside the function is a scalar, as reported in https://github.com/numba/numba/issues/4202.
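For reference, plain NumPy handles the three kinds of input that come up in this thread: scalars and 0D-arrays are promoted to shape (1,), and arrays that are already at least 1D are returned unchanged (the very same object). A quick check with NumPy alone, no Numba involved (the variable names are only illustrative):

```python
import numpy as np

# The three input kinds an np.atleast_1d replacement has to cover.
scalar = 1.0
zero_d = np.array(1.0)
three = np.array([1.0, 1.0, 1.0])

for value in (scalar, zero_d, three):
    out = np.atleast_1d(value)
    print(type(value).__name__, "->", out.shape)

# Arrays with ndim >= 1 are passed through as-is, not copied.
print(np.atleast_1d(three) is three)
```

This is the behaviour a Numba overload should reproduce for scalars, 0D-arrays and higher-dimensional arrays.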

I tried to use the @overload decorator as in this MWE:

import numba as nb
import numpy as np

# need the following lines to avoid NumbaPendingDeprecationWarning 
import os
os.environ["NUMBA_CAPTURED_ERRORS"]='new_style'

# Overload np.atleast_1d
from numba import types
from numba.extending import overload

@overload(np.atleast_1d)
def ol_atleast_1d(x):
    if x in types.number_domain:
        return lambda x: np.array([x])
    return lambda x: np.atleast_1d(x)

# Write function that uses np.atleast_1d:
@nb.njit
def func(arg):
    _arg = np.atleast_1d(arg)
    return np.sum(_arg)

The function works fine when arg is a scalar, but fails when the input is an array:


a = 1.0
print(func(a))  # prints 1.0: ok

b = np.array([1.0,1.0,1.0])
func(b)  # fails

The traceback is quite long, but it ends with:
RecursionError: maximum recursion depth exceeded in comparison

Any idea where the problem might be?

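Hey @mtagliazucchi,

it seems that your code is stuck in a recursion: for non-scalar inputs your overload returns lambda x: np.atleast_1d(x), and the inner np.atleast_1d call gets resolved to that same overload again, which returns the same lambda again, and so on until Python's recursion limit is reached.
Can you try this? It handles scalars and arrays explicitly, without calling np.atleast_1d from inside its own overload: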

import numpy as np
import numba as nb
from numba import types
from numba.extending import overload

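@overload(np.atleast_1d)
def ol_atleast_1d(x):
    if x in types.number_domain:
        # scalar -> 1-element array
        return lambda x: np.array([x])
    elif isinstance(x, types.Array):
        if x.ndim == 0:
            # 0D-array -> 1-element array
            return lambda x: np.array([x])
        else:
            # already at least 1D: return the input unchanged
            return lambda x: x
    # any other type: no implementation is returned (None)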

@nb.njit
def func(arg):
    return np.atleast_1d(arg).ndim

print(func(1))  # scalar
print(func(np.array(1)))   # 0D-array
print(func(np.array([1])))  # 1D-array
print(func(np.array([[1],[1]])))  # 2D-array
#1
#1
#1
#2

This works! Thank you so much.