Hey Numba team,
I want to write a function to squeeze the dimensions of a 2D array into 2D, 1D, or 0D array using Numba overload.
Unfortunately I have some issues with typing Literals. I get a “TypingError: Cannot request literal type.”
Here’s my code:
import numpy as np
from numba import njit
from numba import types
from numba import literally
from numba.extending import overload
from numba.core.errors import TypingError
def squeeze_2d(arr, axis):
pass
@overload(squeeze_2d, prefer_literal=True)
def squeeze_2d_imp(arr, axis):
if not isinstance(arr, types.Array):
raise TypingError("'arr' must be of type Array.")
if arr.ndim != 2:
raise ValueError('Array dimensions not supported.')
# force axis to be Literal type
if not isinstance(axis, types.Literal):
def imp_to_literal(arr, axis):
return squeeze_2d(arr, literally(axis))
return imp_to_literal
if isinstance(axis, types.Literal):
SQUEEZE_NONE = -1
SQUEEZE_COL = 0
SQUEEZE_ROW = 1
SQUEEZE_ALL = 2
axis_val = axis.literal_value
if axis_val == SQUEEZE_NONE: # -1
return lambda arr, axis: arr
elif axis_val == SQUEEZE_COL: # 0
return lambda arr, axis: arr[:, 0]
elif axis_val == SQUEEZE_ROW: # 1
return lambda arr, axis: arr[0, :]
elif axis_val == SQUEEZE_ALL: # 2
return lambda arr, axis: np.asarray(arr[0, 0])
else:
raise ValueError("Array dimensions not supported.")
@njit
def squeeze_2d_axis(shape):
"""
Find axis to squeeze a 2D array by its shape.
Example:
shapes = [(3,4), (5,1), (1,5), (1, 1)]
for shape in shapes:
axis = squeeze_2d_axis(shape)
print(f'shape: {shape}, squeeze axis: {axis}')
"""
if len(shape) != 2:
raise ValueError("Array dimensions not supported.")
SQUEEZE_NONE = -1
SQUEEZE_COL = 0
SQUEEZE_ROW = 1
SQUEEZE_ALL = 2
nrows, ncols = shape
if (nrows > 1) and (ncols > 1):
return SQUEEZE_NONE
elif (nrows > 1) and (ncols == 1):
return SQUEEZE_COL
elif (nrows == 1) and (ncols > 1):
return SQUEEZE_ROW
elif (nrows == 1) and (ncols == 1):
return SQUEEZE_ALL
else:
raise ValueError("Array shape not supported.")
When I try to execute do_squeeze_2d(arr) where arr is a 2D array, I’m encountering the following error:
@njit
def do_squeeze_2d(arr):
axis = squeeze_2d_axis(arr.shape)
return squeeze_2d(arr, axis)
arr = np.ones((1, 4))
print(arr, '\n', do_squeeze_2d(arr))
# expected result: array([1.,1.,1.,1.])
Traceback (most recent call last):
File ~/miniconda3/envs/dev/lib/python3.12/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
exec(code, globals, locals)
File ~/Dokumente/python/snippets/numba/overload_squeeze.py:102
print(arr, '\n', do_squeeze_2d(arr))
File ~/miniconda3/envs/dev/lib/python3.12/site-packages/numba/core/dispatcher.py:468 in _compile_for_args
error_rewrite(e, 'typing')
File ~/miniconda3/envs/dev/lib/python3.12/site-packages/numba/core/dispatcher.py:409 in error_rewrite
raise e.with_traceback(None)
TypingError: Cannot request literal type.
File ".../overload_squeeze.py", line 98:
def do_squeeze_2d(arr):
<source elided>
return squeeze_2d(arr, axis)
^
During: resolving callee type: Function(<function squeeze_2d at 0x7f9deee53420>)
During: typing of call at .../overload_squeeze.py (98)
The error seems to be related to the use of types.Literal in the overload function.
Could someone please help me fix this typing error?
Thank you in advance for your help!
Regards, Oyibo