I’m new to numba and struggling with the basics. To optimize my code, I’d like to speed up the following operation, which sums a tensor (psi) along its first axis, weighted by a vector (qweights):
import numpy as np

def naive_scalarflux(psi, qweights):
    # Weighted sum of the 2-D slices psi[i, :, :] over the first axis
    nq, nx, ny = psi.shape
    phi = np.zeros((nx, ny))
    for i, qi in enumerate(qweights):
        phi += qi * psi[i, :, :]
    return phi
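In index notation, the loop above computes the following contraction over the first axis, where q stands for qweights and ψ for psi (this notation is introduced here for illustration):

```latex
\phi_{jk} = \sum_{i=0}^{n_q - 1} q_i \, \psi_{ijk},
\qquad j = 0, \dots, n_x - 1, \quad k = 0, \dots, n_y - 1
```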
For a minimal working example (MWE), the output looks as follows:
np.random.seed(123)
qweights = np.random.rand(10)
psi = np.random.rand(10,3,3)
naive_scalarflux(psi,qweights)
output>
array([[3.40503587, 2.65773903, 3.40809969],
       [2.76027256, 1.65875846, 2.63184221],
       [2.09977753, 2.80326648, 2.76517312]])
This can be simplified with np.tensordot:
def numpy_scalarflux(psi, qweights):
    # Contract axis 0 of qweights with axis 0 of psi
    return np.tensordot(qweights, psi, axes=((0), (0)))
and numpy_scalarflux(psi, qweights) yields the same output as before.
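For reference, the same contraction can be written as a vector-matrix product by flattening the trailing (nx, ny) axes of psi. This is a plain NumPy sketch; the name dot_scalarflux is hypothetical. Numba's documentation lists np.dot on contiguous 1-D and 2-D floating-point arrays among its supported features (BLAS-backed, which needs SciPy), so this form may be a candidate for @njit, but whether it compiles on numba 0.48 is an assumption to verify.

```python
import numpy as np


def dot_scalarflux(psi, qweights):
    # Hypothetical helper: flatten the trailing (nx, ny) axes so psi becomes
    # an (nq, nx*ny) matrix; contracting axis 0 is then a vector-matrix product.
    nq, nx, ny = psi.shape
    flat = np.dot(qweights, psi.reshape(nq, nx * ny))  # shape (nx*ny,)
    return flat.reshape(nx, ny)


np.random.seed(123)
qweights = np.random.rand(10)
psi = np.random.rand(10, 3, 3)
print(np.allclose(dot_scalarflux(psi, qweights),
                  np.tensordot(qweights, psi, axes=(0, 0))))  # True
```

The reshape is free (a view) because np.random.rand returns a C-contiguous array.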
Decorating the function with @njit and calling it raises an error:
from numba import njit

@njit
def njit_scalarflux(psi, qweights):
    return np.tensordot(qweights, psi, axes=((0), (0)))

njit_scalarflux(psi, qweights)
output>
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
<ipython-input-6-f4e880cfce8a> in <module>()
----> 1 njit_scalarflux(psi,qweights)
/usr/local/lib/python3.6/dist-packages/numba/six.py in reraise(tp, value, tb)
666 value = tp()
667 if value.__traceback__ is not tb:
--> 668 raise value.with_traceback(tb)
669 raise value
670
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Use of unsupported NumPy function 'numpy.tensordot' or unsupported use of the function.
File "<ipython-input-3-71f9ca8ddd78>", line 13:
def njit_scalarflux(psi,qweights):
return np.tensordot(qweights,psi,axes =((0),(0)))
^
[1] During: typing of get attribute at <ipython-input-3-71f9ca8ddd78> (13)
File "<ipython-input-3-71f9ca8ddd78>", line 13:
def njit_scalarflux(psi,qweights):
return np.tensordot(qweights,psi,axes =((0),(0)))
^
Is np.tensordot simply not supported, or am I doing something wrong? Thanks in advance; any help is much appreciated. Here are the version numbers (I’m on Google Colab):
Python 3.6.9
numba: 0.48.0
numpy: 1.18.5
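The traceback reports "Use of unsupported NumPy function 'numpy.tensordot'", which suggests np.tensordot is simply outside numba's supported NumPy subset in nopython mode for this version. A common workaround is to write the contraction as explicit loops, which numba compiles well. The sketch below assumes float64 inputs shaped as in the MWE; the name loop_scalarflux is hypothetical, and the try/except fallback decorator is an assumption added only so the snippet also runs (uncompiled) where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: no-op fallback so the sketch still runs without numba
    def njit(func):
        return func


@njit
def loop_scalarflux(psi, qweights):
    # Hypothetical helper: weighted sum over axis 0 using only loops and
    # indexing, which numba supports in nopython mode.
    nq, nx, ny = psi.shape
    phi = np.zeros((nx, ny))
    for i in range(nq):
        qi = qweights[i]
        # Loop order i, j, k walks psi in C (row-major) memory order
        for j in range(nx):
            for k in range(ny):
                phi[j, k] += qi * psi[i, j, k]
    return phi


np.random.seed(123)
qweights = np.random.rand(10)
psi = np.random.rand(10, 3, 3)
print(loop_scalarflux(psi, qweights))
```

The result should match naive_scalarflux up to floating-point rounding. Note that the first call includes numba's compilation time, so any timing comparison should use subsequent calls.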