TypingError: Cannot request literal type in overload

Hey Numba team,

I want to write a function to squeeze the dimensions of a 2D array into 2D, 1D, or 0D array using Numba overload.
Unfortunately I have some issues with typing Literals. I get a “TypingError: Cannot request literal type.”
Here’s my code:

import numpy as np
from numba import njit
from numba import types
from numba import literally
from numba.extending import overload
from numba.core.errors import TypingError

def squeeze_2d(arr, axis):
    pass

@overload(squeeze_2d, prefer_literal=True)
def squeeze_2d_imp(arr, axis):

    if not isinstance(arr, types.Array):
        raise TypingError("'arr' must be of type Array.")

    if arr.ndim != 2:
        raise ValueError('Array dimensions not supported.')

    # force axis to be Literal type
    if not isinstance(axis, types.Literal):
        def imp_to_literal(arr, axis):
            return squeeze_2d(arr, literally(axis))
        return imp_to_literal

    if isinstance(axis, types.Literal):

        SQUEEZE_NONE = -1
        SQUEEZE_COL = 0
        SQUEEZE_ROW = 1
        SQUEEZE_ALL = 2

        axis_val = axis.literal_value

        if axis_val == SQUEEZE_NONE:  # -1
            return lambda arr, axis: arr
        elif axis_val == SQUEEZE_COL:  # 0
            return lambda arr, axis: arr[:, 0]
        elif axis_val == SQUEEZE_ROW:  # 1
            return lambda arr, axis: arr[0, :]
        elif axis_val == SQUEEZE_ALL:  # 2
            return lambda arr, axis: np.asarray(arr[0, 0])
        else:
            raise ValueError("Array dimensions not supported.")

@njit
def squeeze_2d_axis(shape):
    """
    Find axis to squeeze a 2D array by its shape.

    Example:
        shapes = [(3,4), (5,1), (1,5), (1, 1)]
        for shape in shapes:
            axis = squeeze_2d_axis(shape)
            print(f'shape: {shape}, squeeze axis: {axis}')
    """
    if len(shape) != 2:
        raise ValueError("Array dimensions not supported.")

    SQUEEZE_NONE = -1
    SQUEEZE_COL = 0
    SQUEEZE_ROW = 1
    SQUEEZE_ALL = 2
    nrows, ncols = shape

    if (nrows > 1) and (ncols > 1):
        return SQUEEZE_NONE
    elif (nrows > 1) and (ncols == 1):
        return SQUEEZE_COL
    elif (nrows == 1) and (ncols > 1):
        return SQUEEZE_ROW
    elif (nrows == 1) and (ncols == 1):
        return SQUEEZE_ALL
    else:
        raise ValueError("Array shape not supported.")

When I try to execute do_squeeze_2d(arr) where arr is a 2D array, I’m encountering the following error:

@njit
def do_squeeze_2d(arr):
    axis = squeeze_2d_axis(arr.shape)
    return squeeze_2d(arr, axis)

arr = np.ones((1, 4))
print(arr, '\n', do_squeeze_2d(arr))
# expected result: array([1.,1.,1.,1.])

Traceback (most recent call last):

  File ~/miniconda3/envs/dev/lib/python3.12/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File ~/Dokumente/python/snippets/numba/overload_squeeze.py:102
    print(arr, '\n', do_squeeze_2d(arr))

  File ~/miniconda3/envs/dev/lib/python3.12/site-packages/numba/core/dispatcher.py:468 in _compile_for_args
    error_rewrite(e, 'typing')

  File ~/miniconda3/envs/dev/lib/python3.12/site-packages/numba/core/dispatcher.py:409 in error_rewrite
    raise e.with_traceback(None)

TypingError: Cannot request literal type.

File ".../overload_squeeze.py", line 98:
def do_squeeze_2d(arr):
    <source elided>
    return squeeze_2d(arr, axis)

 ^

During: resolving callee type: Function(<function squeeze_2d at 0x7f9deee53420>)
During: typing of call at .../overload_squeeze.py (98)

The error seems to be related to the use of types.Literal in the overload function.
Could someone please help me fix this typing error?
Thank you in advance for your help!

Regards, Oyibo

1 Like

The problem seems to be that Numba is not able to infer a Literal type from squeeze_2d_axis.
Here is a simplified example:

import numpy as np
from numba import njit
from numba.extending import overload, SentryLiteralArgs

def do_this(arr, val):
    return arr + val

@overload(do_this)
def ov_do_this(arr, val):
    SentryLiteralArgs(['val']).for_function(ov_do_this).bind(arr, val)
    print(val)
    print(val.literal_value)
    return lambda arr, val: arr + val

@njit
def foo(arr, val):
    return do_this(arr, val)

@njit
def compute(arr):
    return len(arr)

@njit
def bar(a):
    b = compute(a)
    return do_this(a, b)

This works:

arr = np.zeros((2, 2))
val = 5
print(foo(arr, val))    # <= val is static of type Literal[int]

# Literal[int](5)
# 5
# [[5. 5.]
#  [5. 5.]]

This doesn’t work:

print(bar(arr))         # <= val is computed/dynamic of type int64

# TypingError: Cannot request literal type.

# File "../../tmp/ipykernel_723744/3630181384.py", line 22:
# <source missing, REPL/exec in use?>

# During: resolving callee type: Function(<function do_this at 0x7ff325804ea0>)
# During: typing of call at /tmp/ipykernel_723744/3630181384.py (22)