How to use types.Integer for indexing list

Hello,

I am trying to overload NumPy functions that need to build a tuple from their arguments, e.g. np.swapaxes. In my case I want to swap two elements of range(n). My problem is that I cannot use a types.Integer to index a list.

Here is a (not-so-minimal) MWE:

import numpy as np
from numba import njit
from numba.core import types
from numba.core.extending import overload


@overload(np.swapaxes)
def numpy_swapaxes(arr, axis1, axis2):
    """Typing-time overload of ``np.swapaxes`` (question MWE; fails on purpose).

    Numba calls this while typing a call to ``np.swapaxes``, so ``arr``,
    ``axis1`` and ``axis2`` are Numba *types* (``types.Array``,
    ``types.Integer``, ...), not runtime values.

    Raises:
        errors.TypingError: if ``axis1``/``axis2`` are not integer types or
            ``arr`` is not an array type.
        TypeError: for ``types.Integer`` axes, because a Numba type is not a
            Python ``int`` and cannot index ``axes_list``. This is the
            failure the question is about.
    """
    # ``errors`` is used below but was never imported by this snippet, so a
    # non-integer argument raised NameError instead of the intended TypingError.
    from numba.core import errors

    if not isinstance(axis1, (int, types.Integer)):
        raise errors.TypingError('The second argument "axis1" must be an '
                                 'integer')
    if not isinstance(axis2, (int, types.Integer)):
        raise errors.TypingError('The third argument "axis2" must be an '
                                 'integer')
    if not isinstance(arr, types.Array):
        raise errors.TypingError('The first argument "arr" must be an array')

    # create tuple list, this is where I fail
    axes_list = list(range(arr.ndim))
    # axis1/axis2 are Numba types here, not ints, so these assignments raise
    # "TypeError: list indices must be integers or slices, not Integer".
    axes_list[axis1] = axis2
    axes_list[axis2] = axis1
    axes_list = tuple(axes_list)

    def impl(arr, axis1, axis2):
        return arr[:]  # some function needing a tuple goes here

    return impl


@njit
def swapaxes(arr, axis1, axis2):
    """Jitted wrapper that dispatches to the ``np.swapaxes`` overload."""
    swapped = np.swapaxes(arr, axis1, axis2)
    return swapped


# 3x3x3 array holding 0..26; swap its first and last axes.
arr = np.arange(3 ** 3).reshape((3, 3, 3))
axis1 = 0
axis2 = 2

swapaxes(arr, axis1, axis2)

which raises the following error:

No implementation of function Function(<function swapaxes at 0x7ff7e90221f0>) found for signature:
 
 >>> swapaxes(array(int64, 3d, C), int64, int64)
 
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'numpy_swapaxes': File: ../../../../Documents/Coding/test/swapaxes.py: Line 7.
    With argument(s): '(array(int64, 3d, C), int64, int64)':
   Rejected as the implementation raised a specific error:
     TypeError: list indices must be integers or slices, not Integer
  raised from swapaxes.py:20

During: resolving callee type: Function(<function swapaxes at 0x7ff7e90221f0>)
During: typing of call at swapaxes.py (32)


File "swapaxes.py", line 32:
def swapaxes(arr, axis1, axis2):
	return np.swapaxes(arr, axis1, axis2)
 ^

Can anyone tell me how to avoid this?

Thanks in advance :slight_smile:
Daniel

Hi @braniii,

The issue in the above is that whilst arr.ndim is a compile-time constant and part of the numba.types.Array type, a types.Integer carries no such information about its value. The “create tuple list” part runs in the typing scope (see the Numba guide to using @overload for a walkthrough). There, the only information available is “it’s an Integer” and “the array has some constant ndim”. That is not enough to create the compile-time constant the impl is trying to close over, because the indexes used to manipulate the range are unknown. There are two ways around this; I’d go with the second, as it’s more general…

  1. Use numba.literally to require literal integer values (the numba.types.IntegerLiteral type) in the typing scope.
import numpy as np
from numba import njit, literally
from numba.core import types, errors
from numba.core.extending import overload


@overload(np.swapaxes)
def numpy_swapaxes(arr, axis1, axis2):
    """Typing-time overload of ``np.swapaxes`` that requires literal axes.

    Numba calls this while typing ``np.swapaxes(arr, axis1, axis2)``, so the
    arguments are Numba types. The swapped axis order is computed here, at
    compile time, which needs the axes as ``types.IntegerLiteral``.

    Returns:
        A function for Numba to compile. If either axis is not yet a literal,
        the returned stub calls ``literally`` on both axes to ask Numba for
        literal values.

    Raises:
        errors.TypingError: if an axis is not an integer type, ``arr`` is not
            an array type, or a literal axis is out of range for ``arr.ndim``.
    """
    if not isinstance(axis1, types.Integer):
        raise errors.TypingError('The second argument "axis1" must be an '
                                 'integer')
    if not isinstance(axis2, types.Integer):
        raise errors.TypingError('The third argument "axis2" must be an '
                                 'integer')
    if not isinstance(arr, types.Array):
        raise errors.TypingError('The first argument "arr" must be an array')

    if not (isinstance(axis1, types.IntegerLiteral)
            and isinstance(axis2, types.IntegerLiteral)):
        # The parentheses matter. ``return lambda ...: a, b`` would return the
        # tuple ``(lambda, b)`` instead of a function, and would request only
        # axis1 as a literal.
        return lambda arr, axis1, axis2: (literally(axis1), literally(axis2))

    ndim = arr.ndim

    def normalize(name, axis):
        # NumPy accepts axes in [-ndim, ndim); map negatives onto [0, ndim).
        value = axis.literal_value
        if not -ndim <= value < ndim:
            raise errors.TypingError(
                'Argument "{}" ({}) is out of bounds for an array of '
                'dimension {}'.format(name, value, ndim))
        return value % ndim

    ax1 = normalize("axis1", axis1)
    ax2 = normalize("axis2", axis2)

    # Identity order with the two requested axes exchanged.
    axes_list = list(range(ndim))
    axes_list[ax1], axes_list[ax2] = axes_list[ax2], axes_list[ax1]
    axes_list = tuple(axes_list)

    def impl(arr, axis1, axis2):
        return arr[:]  # some function needing a tuple goes here

    return impl


@njit
def swapaxes_ok(arr):
    """Swap axes 1 and 2 of ``arr``, passing the axes as literal constants."""
    # Keep the constants inline so the overload sees them as literals.
    swapped = np.swapaxes(arr, 1, 2)
    return swapped

# 3x3x3 array holding 0..26.
arr = np.arange(3 ** 3).reshape((3, 3, 3))

swapaxes_ok(arr)


@njit
def swapaxes_fail(arr, ax1, ax2):
    """Call ``np.swapaxes`` with axes supplied as runtime arguments.

    ``ax1``/``ax2`` are plain (non-literal) arguments here. The answer uses
    this to show typing failing for the literal-only overload.
    NOTE(review): confirm this still fails with the overload in use.
    """
    swapped = np.swapaxes(arr, ax1, ax2)
    return swapped

# Swap the first and last axes, passing them as runtime values.
axis1 = 0
axis2 = 2
swapaxes_fail(arr, axis1, axis2)
  2. Use “unsafe” tuple-manipulation functions to mutate the tuple at runtime; this doesn’t require compile-time constants/literal values etc.
import numpy as np
from numba import njit
from numba.core import types
from numba.core.extending import overload
from numba.cpython.unsafe.tuple import tuple_setitem

@overload(np.swapaxes)
def numpy_swapaxes(arr, axis1, axis2):
    """Typing-time overload of ``np.swapaxes`` that swaps axes at runtime.

    The arguments are Numba types. Only ``arr.ndim`` is needed at compile
    time. The axis values are applied inside ``impl`` with ``tuple_setitem``,
    so they do not have to be literals.

    Raises:
        errors.TypingError: if an axis is not an integer type or ``arr`` is
            not an array type.
        ValueError: at runtime (from ``impl``) if an axis is outside
            ``[-ndim, ndim)``.
    """
    # ``errors`` is used below but was never imported by this snippet, so a
    # bad argument raised NameError instead of TypingError.
    from numba.core import errors

    if not isinstance(axis1, (int, types.Integer)):
        raise errors.TypingError('The second argument "axis1" must be an '
                                 'integer')
    if not isinstance(axis2, (int, types.Integer)):
        raise errors.TypingError('The third argument "axis2" must be an '
                                 'integer')
    if not isinstance(arr, types.Array):
        raise errors.TypingError('The first argument "arr" must be an array')

    ndim = arr.ndim
    # Identity axis order (0, 1, ..., ndim - 1), closed over by impl.
    axes_list = tuple(range(ndim))

    def impl(arr, axis1, axis2):
        # tuple_setitem is an "unsafe" helper that does not bounds-check its
        # index. So map negative axes NumPy-style and reject out-of-range
        # axes before calling it.
        ax1 = axis1 + ndim if axis1 < 0 else axis1
        ax2 = axis2 + ndim if axis2 < 0 else axis2
        if ax1 < 0 or ax1 >= ndim or ax2 < 0 or ax2 >= ndim:
            raise ValueError("swapaxes: axis is out of bounds for array")
        # chain the manipulation
        print("start axes_list", axes_list)
        mutated_axes = tuple_setitem(axes_list, ax1, ax2)
        mutated_axes = tuple_setitem(mutated_axes, ax2, ax1)
        print("mutated axes", mutated_axes)
        return arr[:]  # some function needing a tuple goes here
    return impl


@njit
def swapaxes(arr, axis1, axis2):
    """Jitted wrapper that dispatches to the ``np.swapaxes`` overload."""
    swapped = np.swapaxes(arr, axis1, axis2)
    return swapped


# 3x3x3 array holding 0..26; swap its first and last axes at runtime.
arr = np.arange(3 ** 3).reshape((3, 3, 3))
axis1 = 0
axis2 = 2

swapaxes(arr, axis1, axis2)

Hope this helps?

1 Like

The second one was what I was looking for. Thank you so much :slight_smile:

No problem, glad it is useful!