Using @njit with numpy.tensordot

I’m new to Numba and struggle with the basics. To optimize my code, I’d like to speed up the following operation, which sums a tensor (psi) along its first axis, weighted by a vector (qweights).
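
Written out in index notation, the operation is the following (j and k are introduced here only as names for the two trailing axes; q_i are the entries of qweights):

$$
\phi_{jk} = \sum_{i=0}^{n_q-1} q_i \, \psi_{ijk}, \qquad j = 0,\dots,n_x-1, \quad k = 0,\dots,n_y-1
$$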

import numpy as np

def naive_scalarflux(psi,qweights):
  nq,nx,ny = psi.shape
  phi = np.zeros((nx,ny))
  for i,qi in enumerate(qweights):
    phi += qi* psi[i,:,:]
  return phi

For this minimal working example (MWE), the output looks as follows:

np.random.seed(123)
qweights = np.random.rand(10)
psi = np.random.rand(10,3,3)
naive_scalarflux(psi,qweights)

output> 
array([[3.40503587, 2.65773903, 3.40809969],
       [2.76027256, 1.65875846, 2.63184221],
       [2.09977753, 2.80326648, 2.76517312]]) 

This can be simplified with np.tensordot to

def numpy_scalarflux(psi,qweights): 
  return np.tensordot(qweights,psi,axes =((0),(0))) 

and numpy_scalarflux(psi,qweights) yields the same output as before.
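To see that these really are the same contraction, here is a small plain-NumPy check (no Numba involved) comparing the loop against np.tensordot, np.einsum and a flatten-then-multiply formulation. Note that `axes=((0),(0))` in the question is the same as `axes=(0, 0)`, because `(0)` is just the integer 0, not a tuple.

```python
import numpy as np

def naive_scalarflux(psi, qweights):
    # Reference loop: accumulate qweights[i] * psi[i] into a 2-D array
    nq, nx, ny = psi.shape
    phi = np.zeros((nx, ny))
    for i, qi in enumerate(qweights):
        phi += qi * psi[i, :, :]
    return phi

np.random.seed(123)
qweights = np.random.rand(10)
psi = np.random.rand(10, 3, 3)
ref = naive_scalarflux(psi, qweights)

# Contract axis 0 of qweights with axis 0 of psi
via_tensordot = np.tensordot(qweights, psi, axes=(0, 0))

# Same contraction with einsum: phi[j, k] = sum_i q[i] * psi[i, j, k]
via_einsum = np.einsum("i,ijk->jk", qweights, psi)

# Flatten the trailing axes, do a vector @ matrix product, reshape back
nq, nx, ny = psi.shape
via_dot = (qweights @ psi.reshape(nq, nx * ny)).reshape(nx, ny)

print(np.allclose(ref, via_tensordot),
      np.allclose(ref, via_einsum),
      np.allclose(ref, via_dot))
# prints: True True True
```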

Decorating this function with @njit and running it raises an error:

from numba import njit

@njit
def njit_scalarflux(psi,qweights): 
  return np.tensordot(qweights,psi,axes =((0),(0)))

njit_scalarflux(psi,qweights)
output>

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
<ipython-input-6-f4e880cfce8a> in <module>()
----> 1 njit_scalarflux(psi,qweights)

/usr/local/lib/python3.6/dist-packages/numba/six.py in reraise(tp, value, tb)
    666             value = tp()
    667         if value.__traceback__ is not tb:
--> 668             raise value.with_traceback(tb)
    669         raise value
    670 

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Use of unsupported NumPy function 'numpy.tensordot' or unsupported use of the function.

File "<ipython-input-3-71f9ca8ddd78>", line 13:
def njit_scalarflux(psi,qweights): 
  return np.tensordot(qweights,psi,axes =((0),(0)))
  ^

[1] During: typing of get attribute at <ipython-input-3-71f9ca8ddd78> (13)

File "<ipython-input-3-71f9ca8ddd78>", line 13:
def njit_scalarflux(psi,qweights): 
  return np.tensordot(qweights,psi,axes =((0),(0)))
  ^

Is np.tensordot simply not supported, or am I doing something wrong? Thanks in advance; any help is very much appreciated. Here are the version numbers (I’m on Google Colab):

Python 3.6.9
numba: 0.48.0
numpy: 1.18.5

Hi @camminady,

It seems that tensordot is not supported. You can find the full list of supported NumPy functions here: http://numba.pydata.org/numba-doc/latest/reference/numpysupported.html
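np.dot does appear on that list (for contiguous floating-point or complex arrays; the Numba docs note that its linear-algebra functions need SciPy installed). So one workaround that stays close to the tensordot idea is to flatten the two trailing axes and compute a vector-matrix product. This is a sketch under those assumptions, not checked against numba 0.48 specifically; the function name `njit_dot_scalarflux` is hypothetical, and the `try/except` fallback is only there so the snippet runs without Numba installed (it is not part of Numba).

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Hypothetical fallback: a no-op decorator so the sketch runs without Numba
    def njit(func):
        return func

@njit
def njit_dot_scalarflux(psi, qweights):
    # Collapse (nx, ny) into one axis so the contraction becomes vector @ matrix
    nq, nx, ny = psi.shape
    flat = psi.reshape((nq, nx * ny))  # reshape assumes psi is contiguous
    return np.dot(qweights, flat).reshape((nx, ny))

np.random.seed(123)
qweights = np.random.rand(10)
psi = np.random.rand(10, 3, 3)
phi = njit_dot_scalarflux(psi, qweights)
```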

Have you tried decorating your first (naive) version with @njit? In Numba-compiled functions there’s no performance penalty for writing explicit loops, unlike in interpreted Python.
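A minimal sketch of that suggestion: the naive loop, unchanged apart from iterating with `range` and carrying the `@njit` decorator. The name `njit_naive_scalarflux` is hypothetical, and the `try/except` block is an assumption added only so the example also runs (uncompiled) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Hypothetical fallback: a no-op decorator so the sketch runs without Numba
    def njit(func):
        return func

@njit
def njit_naive_scalarflux(psi, qweights):
    # Same weighted sum as naive_scalarflux; under Numba this loop is compiled
    nq, nx, ny = psi.shape
    phi = np.zeros((nx, ny))
    for i in range(nq):
        phi += qweights[i] * psi[i, :, :]
    return phi
```

Only supported operations are used here (np.zeros, slicing, in-place array addition), which is why this version compiles where the tensordot one does not.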

Thanks to your reproducer I could test it myself, and I found that the naive_scalarflux with njit is 10x faster than numpy_scalarflux.
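A timing comparison like that can be sketched with the standard `timeit` module. The important detail is the warm-up call: the first call to an @njit function triggers compilation, which should not be included in the measurement. The sketch below does not reproduce any specific figure; the speedup will depend on the array sizes (for the tiny 10×3×3 example, per-call overhead dominates), so treat the numbers it prints as indicative only. The fallback decorator is again a hypothetical stand-in for when Numba is absent.

```python
import timeit
import numpy as np

try:
    from numba import njit
except ImportError:
    # Hypothetical fallback: a no-op decorator so the sketch runs without Numba
    def njit(func):
        return func

def numpy_scalarflux(psi, qweights):
    return np.tensordot(qweights, psi, axes=(0, 0))

@njit
def njit_naive_scalarflux(psi, qweights):
    nq, nx, ny = psi.shape
    phi = np.zeros((nx, ny))
    for i in range(nq):
        phi += qweights[i] * psi[i, :, :]
    return phi

np.random.seed(123)
qweights = np.random.rand(10)
psi = np.random.rand(10, 3, 3)

# Warm-up: the first call compiles the function; keep that out of the timing
njit_naive_scalarflux(psi, qweights)

n = 10_000
t_numpy = timeit.timeit(lambda: numpy_scalarflux(psi, qweights), number=n)
t_njit = timeit.timeit(lambda: njit_naive_scalarflux(psi, qweights), number=n)
print(f"tensordot: {t_numpy / n * 1e6:.2f} us/call")
print(f"njit loop: {t_njit / n * 1e6:.2f} us/call")
```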

Luk


Thank you @luk-f-a for your performance analysis. If the naive version + @njit is so much faster, I will simply use that!

(Since I’m also new to Discourse, I’m not sure whether I need to accept an answer to close the thread. As the problem is solved, I’d like to do so.)

Yeah, Numba does very well on pure-loop code.
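Taking the pure-loop idea one step further, the slice expression `qi * psi[i, :, :]` can also be unrolled into scalar loops, which avoids creating a temporary array on every iteration. This is a sketch (the name `njit_loop_scalarflux` and the fallback decorator are hypothetical), and whether it beats the slice version has not been measured here. The innermost loop runs over the last axis, which matches the memory layout of a C-contiguous psi.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Hypothetical fallback: a no-op decorator so the sketch runs without Numba
    def njit(func):
        return func

@njit
def njit_loop_scalarflux(psi, qweights):
    # Fully explicit loops: no temporary array for qweights[i] * psi[i]
    nq, nx, ny = psi.shape
    phi = np.zeros((nx, ny))
    for i in range(nq):
        qi = qweights[i]
        for j in range(nx):
            for k in range(ny):  # last axis innermost: contiguous memory access
                phi[j, k] += qi * psi[i, j, k]
    return phi
```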

There’s no need to accept an answer as on Stack Overflow; just saying it worked for you is enough, in case someone has the same problem and finds the thread.