Hi~ I want to sum a 4D array along its last two axes, but Numba raises a TypingError when it compiles the function. This is my code. How can I modify it to avoid the error?
```python
import numpy as np
from numba import njit, prange

@njit(['float64[:,:](float64[:,:,:,:],int64[:],int64[:])'], parallel=True)
def sum4d(A, ind_row, ind_col):
    ny = A.shape[0]
    nx = A.shape[1]
    im = np.empty((ny, nx))
    for i in prange(ny):
        for j in range(nx):
            im[i, j] = np.sum(A[i, j, ind_row, ind_col])
    return im
```
The error message is:

```
TypingError: Invalid use of Function(<built-in function getitem>) with argument(s) of type(s): (array(float64, 4d, A), Tuple(int64, int64, array(int64, 1d, A), array(int64, 1d, A)))
```
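To see what the failing subscript asks for, here is a plain NumPy check with made-up sizes. This is my understanding of the semantics, not something the error message itself states. Two integer-array indices are paired element-wise, so `A[i, j, ind_row, ind_col]` is a 1-D array with one entry per `(row, col)` pair. As far as I know, Numba's documented indexing support allows only one advanced (array) index per subscript, which would explain why the tuple `(int64, int64, array, array)` fails to type.

```python
import numpy as np

# Hypothetical small 4-D array and index arrays, chosen only for illustration.
A = np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5)
ind_row = np.array([0, 1, 3], dtype=np.int64)
ind_col = np.array([2, 2, 4], dtype=np.int64)

# Two integer arrays are broadcast together and paired element-wise:
# this picks A[1, 2, 0, 2], A[1, 2, 1, 2], A[1, 2, 3, 4], not a 3x3 sub-grid.
picked = A[1, 2, ind_row, ind_col]
manual = np.array([A[1, 2, r, c] for r, c in zip(ind_row, ind_col)])

print(picked.shape)                    # (3,)
print(np.array_equal(picked, manual))  # True
```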
`A` is the 4D data, and `ind_row` and `ind_col` are 1D index arrays holding the row and column indices to be summed, respectively. What I want is equivalent to the following pure NumPy code:

```python
im = np.sum(A[:, :, ind_row, ind_col], axis=2)
```