Hello, I am submitting this draft PR because I need help understanding the error message.
I took on the task of implementing the overload for numpy.vecdot inside Numba.
Starting from the NumPy Supported Features page (https://numba.readthedocs.io/en/stable/reference/numpysupported.html#), I figured out that Numba already supports NumPy's sum and multiplication. I also found that the following formula in NumPy behaves identically to the numpy.vecdot function:
sum(moveaxis(arr1, axis1, -1).conjugate() * moveaxis(arr2, axis2, -1), axis=-1)
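To double-check the formula outside of Numba, here is a minimal plain-NumPy sketch. The name vecdot_formula is hypothetical and only used for this illustration. The comparison against np.vecdot is guarded because that function is only available in NumPy 2.0 and later, so an einsum contraction serves as an independent reference.

```python
import numpy as np


def vecdot_formula(x1, x2, axis1=-1, axis2=-1):
    """Conjugate x1, multiply elementwise, and sum over the contracted axis."""
    left = np.moveaxis(x1, axis1, -1).conjugate()
    right = np.moveaxis(x2, axis2, -1)
    return np.sum(left * right, axis=-1)


rng = np.random.default_rng(0)
shape = (2, 3, 4)
a = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)
b = rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

# Contract axis 1 of both inputs using the formula.
result = vecdot_formula(a, b, axis1=1, axis2=1)

# Independent reference: the same contraction written with einsum.
reference = np.einsum("ijk,ijk->ik", a.conj(), b)
print(np.allclose(result, reference))

# np.vecdot requires NumPy 2.0 or later, so only compare when it exists.
if hasattr(np, "vecdot"):
    print(np.allclose(result, np.vecdot(a, b, axis=1)))
```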
This inspired the first version of my attempted implementation:
def np_vecdot_impl(x1, x2, axis=-1, axes=None):
    if axes is None:
        axis1, axis2 = axis, axis
    else:
        axis1, axis2 = axes[0], axes[1]
    # normalize axis1 and axis2
    left = np.moveaxis(x1, axis1, -1).conjugate()
    right = np.moveaxis(x2, axis2, -1)
    product = np.multiply(left, right)
    result = np.sum(product, axis=-1)
    return result
Unfortunately, the call to moveaxis fails with a slicing error and the following traceback:
  File "D:\Code Projects\Open Source\numba\numba\np\linalg.py", line 912, in np_vecdot_axis
    return np_vecdot_impl(x1, x2, axis=axis)
  File "D:\Code Projects\Open Source\numba\numba\np\linalg.py", line 1029, in np_vecdot_impl
    left = np.moveaxis(x1, axis1, -1).conjugate()
  File "C:\Users\Oleksiy\miniconda3\envs\numbaenv\lib\site-packages\numpy\_core\numeric.py", line 1503, in moveaxis
    source = normalize_axis_tuple(source, a.ndim, 'source')
  File "C:\Users\Oleksiy\miniconda3\envs\numbaenv\lib\site-packages\numpy\_core\numeric.py", line 1435, in normalize_axis_tuple
    axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
  File "C:\Users\Oleksiy\miniconda3\envs\numbaenv\lib\site-packages\numpy\_core\numeric.py", line 1435, in <listcomp>
    axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
  File "D:\Code Projects\Open Source\numba\numba\core\types\abstract.py", line 189, in __getitem__
    ndim, layout = self._determine_array_spec(args)
  File "D:\Code Projects\Open Source\numba\numba\core\types\abstract.py", line 214, in _determine_array_spec
    raise KeyError(f"Can only index numba types with slices with no start or stop, got {args}.")
KeyError: 'Can only index numba types with slices with no start or stop, got 0.'
Seeing that there was an issue with the moveaxis() function, I decided to replace it with transpose(permutation), adding logic to build the required permutation.
This led to my second attempt at implementation, which is in this Draft PR:
def np_vecdot_impl(x1, x2, axis=-1, axes=None):
    if axes is None:
        axis1, axis2 = axis, axis
    else:
        axis1, axis2 = axes[0], axes[1]
    # does the same as np.moveaxis(x1, axis1, -1)
    # but moveaxis doesn't behave well in nopython mode
    perm1 = np.empty(x1.ndim, dtype=np.intp)
    for i in range(x1.ndim):
        if i < axis1:
            perm1[i] = i
        elif i == axis1:
            perm1[i] = x1.ndim - 1
        else:
            perm1[i] = i - 1
    # does the same as np.moveaxis(x2, axis2, -1)
    # but moveaxis doesn't behave well in nopython mode
    perm2 = np.empty(x2.ndim, dtype=np.intp)
    for i in range(x2.ndim):
        if i < axis2:
            perm2[i] = i
        elif i == axis2:
            perm2[i] = x2.ndim - 1
        else:
            perm2[i] = i - 1
    left = np.transpose(x1, perm1).conjugate()
    right = np.transpose(x2, perm2)
    product = np.multiply(left, right)
    result = np.sum(product, axis=-1)
    return result
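Outside of Numba, a hand-built permutation can be checked against np.moveaxis directly. The sketch below is plain NumPy, and the helper name moveaxis_to_last_perm is hypothetical. It relies on the fact that np.transpose expects, for each output position, the index of the input axis that goes there, so the remaining axes keep their order and the moved axis comes last. It also maps negative axes such as the default -1 to their non-negative index first.

```python
import numpy as np


def moveaxis_to_last_perm(ndim, axis):
    """Return the axes tuple for np.transpose that mimics np.moveaxis(x, axis, -1).

    Hypothetical helper for illustration; not part of NumPy or Numba.
    """
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for {ndim} dimensions")
    axis %= ndim  # map negative axes such as -1 to their non-negative index
    # Keep every other axis in its original order, then put the moved axis last.
    return tuple(i for i in range(ndim) if i != axis) + (axis,)


x = np.arange(24).reshape(2, 3, 4)
for axis in range(-x.ndim, x.ndim):
    perm = moveaxis_to_last_perm(x.ndim, axis)
    # The transposed view must match np.moveaxis in both shape and contents.
    assert np.array_equal(np.transpose(x, perm), np.moveaxis(x, axis, -1))
```

If I read my loop above correctly, it records where each input axis should end up, which is the inverse of what np.transpose expects: for ndim=3 and axis=0 the loop yields [2, 0, 1], while this helper yields (1, 2, 0). That difference may be worth ruling out separately from the type error.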
Unfortunately, I ran into a mismatch between the Numba and Python types with the following traceback:
  File "D:\Code Projects\Open Source\numba\numba\core\typing\templates.py", line 621, in generic
    disp, new_args = self._get_impl(args, kws)
  File "D:\Code Projects\Open Source\numba\numba\core\typing\templates.py", line 720, in _get_impl
    impl, args = self._build_impl(cache_key, args, kws)
  File "D:\Code Projects\Open Source\numba\numba\core\typing\templates.py", line 793, in _build_impl
    ovf_result = self._overload_func(*args, **kws)
  File "D:\Code Projects\Open Source\numba\numba\np\linalg.py", line 912, in np_vecdot_axis
    return np_vecdot_impl(x1, x2, axis=axis)
  File "D:\Code Projects\Open Source\numba\numba\np\linalg.py", line 1006, in np_vecdot_impl
    if i < axis1:
TypeError: '<' not supported between instances of 'int' and 'Integer'
Things I’ve tried to convert axis1 to int:
- int(axis1)
- operator.index(axis1)
- literally(axis1)
- axis1 = int(getattr(axis1, "item", lambda: axis1)())
- if hasattr(x, "item"): return int(x.item())
- if hasattr(x, "value"): return int(x.value)
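One possible reading of the traceback: the TypeError is raised while templates.py is calling the overload function itself (np_vecdot_axis), which then calls np_vecdot_impl directly. As far as I understand the @overload API, the overload function receives Numba types rather than values and is expected to return the implementation function instead of running it. If that is what is happening, axis1 is a type object at that point, which would explain why none of the conversions above help. The sketch below reproduces that failure mode in plain Python. The Integer class and both overload_* functions are hypothetical stand-ins, not Numba code.

```python
class Integer:
    """Hypothetical stand-in for a Numba type object; not the real numba.types.Integer."""

    def __init__(self, name):
        self.name = name


def impl(axis):
    # Runtime logic: works when axis is an actual Python int.
    return 0 < axis


def overload_that_calls_impl(axis_type):
    # Runs the runtime logic on a type object, which cannot be compared to an int.
    return impl(axis_type)


def overload_that_returns_impl(axis_type):
    # Inspects only the type, then hands back the function to be used later.
    if not isinstance(axis_type, Integer):
        return None
    return impl


int64 = Integer("int64")

# Calling the implementation with a type object fails in the comparison.
try:
    overload_that_calls_impl(int64)
except TypeError as exc:
    message = str(exc)
    print(message)

# Returning the implementation defers the comparison until real values arrive.
returned = overload_that_returns_impl(int64)
print(returned(1))
```

In this toy version the first pattern fails in the `<` comparison between an int and the stand-in type, while the second only looks at the type and leaves the comparison for when an actual integer is passed in.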
I would welcome any advice so that I can carry this PR across the finish line. I have already spent a lot of time on it, so it would be a shame if it was all for nothing.
Here is the link to the draft PR:
https://github.com/numba/numba/pull/10213#issue-3381016566
Thank you,
Alex