Hello,
I am trying to overload NumPy functions that need a tuple built from their arguments, e.g. `np.swapaxes`. In my case I want to swap two elements of `range(n)`. My problem is that I cannot use a `types.Integer` to index a list.
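For reference, the tuple needed here is just `range(ndim)` with two positions exchanged. Transposing with that permutation is the same as swapping the two axes. The sketch below shows this in plain Python and NumPy, outside Numba. `swap_permutation` is a helper name made up for illustration.

```python
import numpy as np


def swap_permutation(ndim, axis1, axis2):
    """Return range(ndim) as a tuple with positions axis1 and axis2 exchanged."""
    axes = list(range(ndim))
    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    return tuple(axes)


arr = np.arange(27).reshape(3, 3, 3)
perm = swap_permutation(arr.ndim, 0, 2)  # (2, 1, 0)
# Transposing with this permutation is equivalent to swapping axes 0 and 2.
print(np.array_equal(np.transpose(arr, perm), np.swapaxes(arr, 0, 2)))
```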
Here is a (not so minimal) MWE:
```python
import numpy as np
from numba import njit
from numba.core import errors, types
from numba.core.extending import overload


@overload(np.swapaxes)
def numpy_swapaxes(arr, axis1, axis2):
    if not isinstance(axis1, (int, types.Integer)):
        raise errors.TypingError('The second argument "axis1" must be an '
                                 'integer')
    if not isinstance(axis2, (int, types.Integer)):
        raise errors.TypingError('The third argument "axis2" must be an '
                                 'integer')
    if not isinstance(arr, types.Array):
        raise errors.TypingError('The first argument "arr" must be an array')

    # create tuple list, this is where I fail
    axes_list = list(range(arr.ndim))
    axes_list[axis1] = axis2  # <- types.Integer cannot be used as a list index
    axes_list[axis2] = axis1
    axes_list = tuple(axes_list)

    def impl(arr, axis1, axis2):
        return arr[:]  # some function needing a tuple goes here

    return impl


@njit
def swapaxes(arr, axis1, axis2):
    return np.swapaxes(arr, axis1, axis2)


arr = np.arange(27).reshape(3, 3, 3)
axis1, axis2 = 0, 2
swapaxes(arr, axis1, axis2)
```
which raises the following error:
```text
No implementation of function Function(<function swapaxes at 0x7ff7e90221f0>) found for signature:

 >>> swapaxes(array(int64, 3d, C), int64, int64)

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'numpy_swapaxes': File: ../../../../Documents/Coding/test/swapaxes.py: Line 7.
    With argument(s): '(array(int64, 3d, C), int64, int64)':
   Rejected as the implementation raised a specific error:
     TypeError: list indices must be integers or slices, not Integer
  raised from swapaxes.py:20

During: resolving callee type: Function(<function swapaxes at 0x7ff7e90221f0>)
During: typing of call at swapaxes.py (32)

File "swapaxes.py", line 32:
def swapaxes(arr, axis1, axis2):
    return np.swapaxes(arr, axis1, axis2)
    ^
```
Can anyone tell me how to avoid this?
Thanks in advance
Daniel