Need help converting Numba types to Python types

Hello, I am submitting this draft PR because I need help understanding the error message.
I took on the task of trying to implement the overload for numpy.vecdot inside numba.

Starting at Numpy Supported Features (https://numba.readthedocs.io/en/stable/reference/numpysupported.html#), I figured out that Numba currently supports the NumPy sum and multiplication functions. I also found that the following formula in NumPy behaves identically to the numpy.vecdot function:

sum(moveaxis(arr1, axis1, -1).conjugate() * moveaxis(arr2, axis2, -1), axis=-1)
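For reference, here is a minimal NumPy-only sketch of that equivalence. The helper name vecdot_reference is made up for illustration, and because np.vecdot only exists in NumPy 2.0 and later, the check also compares against an einsum contraction over the last axis:

```python
import numpy as np


def vecdot_reference(x1, x2, axis1=-1, axis2=-1):
    """Hypothetical reference: conjugate x1, multiply, sum over the moved axis."""
    left = np.moveaxis(x1, axis1, -1).conjugate()
    right = np.moveaxis(x2, axis2, -1)
    return np.sum(left * right, axis=-1)


rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4)) + 1j * rng.standard_normal((2, 3, 4))
b = rng.standard_normal((2, 3, 4)) + 1j * rng.standard_normal((2, 3, 4))

result = vecdot_reference(a, b)

# Independent check: contract the last axis with the first operand conjugated.
expected = np.einsum("...i,...i->...", a.conj(), b)
matches_einsum = np.allclose(result, expected)

# Direct comparison, only possible where np.vecdot is available.
matches_vecdot = np.allclose(result, np.vecdot(a, b)) if hasattr(np, "vecdot") else None
```

This only checks the default case where both contracted axes are the last one; other axis combinations would need their own comparison.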

This inspired the first version of my attempted implementation:

def np_vecdot_impl(x1, x2, axis=-1, axes=None):
    if axes is None:
        axis1, axis2 = axis, axis
    else:
        axis1, axis2 = axes[0], axes[1]
    #normalize axis1 and axis2
    left = np.moveaxis(x1, axis1, -1).conjugate()
    right = np.moveaxis(x2, axis2, -1)
    product = np.multiply(left, right)
    result = np.sum(product, axis=-1)
    return result

Unfortunately, Numba’s moveaxis implementation returns a slicing error with the following traceback:

  File "D:\Code Projects\Open Source\numba\numba\np\linalg.py", line 912, in np_vecdot_axis
    return np_vecdot_impl(x1, x2, axis=axis)
  File "D:\Code Projects\Open Source\numba\numba\np\linalg.py", line 1029, in np_vecdot_impl
    left = np.moveaxis(x1, axis1, -1).conjugate()
  File "C:\Users\Oleksiy\miniconda3\envs\numbaenv\lib\site-packages\numpy\_core\numeric.py", line 1503, in moveaxis
    source = normalize_axis_tuple(source, a.ndim, 'source')
  File "C:\Users\Oleksiy\miniconda3\envs\numbaenv\lib\site-packages\numpy\_core\numeric.py", line 1435, in normalize_axis_tuple
    axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
  File "C:\Users\Oleksiy\miniconda3\envs\numbaenv\lib\site-packages\numpy\_core\numeric.py", line 1435, in <listcomp>
    axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
  File "D:\Code Projects\Open Source\numba\numba\core\types\abstract.py", line 189, in __getitem__
    ndim, layout = self._determine_array_spec(args)
  File "D:\Code Projects\Open Source\numba\numba\core\types\abstract.py", line 214, in _determine_array_spec
    raise KeyError(f"Can only index numba types with slices with no start or stop, got {args}.")
KeyError: 'Can only index numba types with slices with no start or stop, got 0.'

Seeing that there was an issue with the moveaxis() function, I decided to replace it with the transpose(permutation) function, adding logic to create the required permutation.

This led to my second attempt at implementation, which is in this Draft PR:

def np_vecdot_impl(x1, x2, axis=-1, axes=None):
    if axes is None:
        axis1, axis2 = axis, axis
    else:
        axis1, axis2 = axes[0], axes[1]

    # does the same as np.moveaxis(x1, axis1, -1)
    # but moveaxis doesn't behave well in nopython mode
    perm1 = np.empty(x1.ndim, dtype=np.intp)
    for i in range(x1.ndim):
        if i < axis1:
            perm1[i] = i
        elif i == axis1:
            perm1[i] = x1.ndim - 1
        else:
            perm1[i] = i - 1

    # does the same as np.moveaxis(x2, axis2, -1)
    # but moveaxis doesn't behave well in nopython mode
    perm2 = np.empty(x2.ndim, dtype=np.intp)
    for i in range(x2.ndim):
        if i < axis2:
            perm2[i] = i
        elif i == axis2:
            perm2[i] = x2.ndim - 1
        else:
            perm2[i] = i - 1

    left = np.transpose(x1, perm1).conjugate()
    right = np.transpose(x2, perm2)
    product = np.multiply(left, right)
    result = np.sum(product, axis=-1)
    return result
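Separately from the typing error, the permutation logic may be worth a sanity check in plain NumPy. np.transpose expects, for each output position, the source axis that lands there, while the loop above appears to store the destination of each source axis. The two agree when the axis moves by at most one position but differ otherwise, and a negative axis (including the default -1) is not normalized. A minimal sketch, with hypothetical helper names and plain lists instead of arrays:

```python
import numpy as np


def perm_from_loop(ndim, axis):
    """Permutation as built by the loop in the draft (axis taken as given)."""
    perm = [0] * ndim
    for i in range(ndim):
        if i < axis:
            perm[i] = i
        elif i == axis:
            perm[i] = ndim - 1
        else:
            perm[i] = i - 1
    return perm


def perm_for_moveaxis(ndim, axis):
    """Hypothetical alternative: the source axis for each output position."""
    axis = axis + ndim if axis < 0 else axis  # normalize a negative axis
    return [i for i in range(ndim) if i != axis] + [axis]


x = np.arange(24).reshape(2, 3, 4)
target = np.moveaxis(x, 0, -1)  # shape (3, 4, 2)

loop_shape = np.transpose(x, perm_from_loop(3, 0)).shape
alt_equal = np.array_equal(np.transpose(x, perm_for_moveaxis(3, 0)), target)
```

In this sketch the loop-built permutation yields shape (4, 2, 3) rather than (3, 4, 2), while the alternative ordering reproduces np.moveaxis.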

Unfortunately, I ran into a mismatch between the Numba and Python types with the following traceback:

  File "D:\Code Projects\Open Source\numba\numba\core\typing\templates.py", line 621, in generic
    disp, new_args = self._get_impl(args, kws)
  File "D:\Code Projects\Open Source\numba\numba\core\typing\templates.py", line 720, in _get_impl
    impl, args = self._build_impl(cache_key, args, kws)
  File "D:\Code Projects\Open Source\numba\numba\core\typing\templates.py", line 793, in _build_impl
    ovf_result = self._overload_func(*args, **kws)
  File "D:\Code Projects\Open Source\numba\numba\np\linalg.py", line 912, in np_vecdot_axis
    return np_vecdot_impl(x1, x2, axis=axis)
  File "D:\Code Projects\Open Source\numba\numba\np\linalg.py", line 1006, in np_vecdot_impl
    if i < axis1:
TypeError: '<' not supported between instances of 'int' and 'Integer'

Things I’ve tried to convert axis1 to int:

  • int(axis1)

  • operator.index(axis1)

  • literally(axis1)

  • axis1 = int(getattr(axis1, "item", lambda: axis1)())

  • if hasattr(x, "item"): return int(x.item())
    if hasattr(x, "value"): return int(x.value)

I would welcome any advice so that I can carry this PR across the finish line. I have already spent a lot of time on it, so it would be a shame if it was all for nothing.

Here is the link to the draft PR:

https://github.com/numba/numba/pull/10213#issue-3381016566

Thank you,
Alex

I haven’t looked too deeply into the PR, but at a quick glance the cause of the

TypeError: '<' not supported between instances of 'int' and 'Integer'

error appears to be in how the overload is used.

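The linked documentation describes how the function decorated with the overload decorator receives the types of the arguments rather than their values. That explains the “instance of int → instance of Integer” promotion we see in the error message: for instance, int64 = Integer("int64") is the argument passed to the overloaded function when it is called with an integer. In the traceback, np_vecdot_axis calls np_vecdot_impl directly with those type objects, so the comparison i < axis1 runs in ordinary Python against a Numba type.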