How to make this function compile?

This code snippet works as expected in Python, but fails to compile in @njit mode for options other than A (see comments in affected lines):

import numpy as np
import numba as nb

points = np.float32([[0, -2, -2, 3, 3], [0.1, -2, -2.1, 2.9, 3.1], [0, -1, -1, 3, 3]])

@nb.njit(fastmath=True)
def cluster(points, dMax):
  len = points.shape[0]
  dist = np.full((len, len), np.inf, dtype=points.dtype)

  for i in range(len):
    for j in range(i+1, len):
      dist[i, j] = np.max(np.abs(points[i]-points[j]))

  dist = dist.ravel() # option A
  dist.sort() # option A
  # (dist := dist.ravel()).sort() # option B
  # dist = np.sort(dist, axis=None) # option C
  # (dist := dist.flatten()).sort() # option D

  return dist

print(cluster(points, 0.1))

Do I have to add type hints, or are some NumPy methods just not compatible with Numba?

Hi @pauljurczak,

Some NumPy functions and keyword arguments are not supported in Numba yet. See "Supported NumPy features" in the Numba 0.55.1 documentation:

  • np.sort is only supported without optional arguments, so np.sort(dist, axis=None) (option C) does not compile.
  • I'm not sure whether the walrus operator (:=) is supported.

With that said, options A and D (the latter rewritten without the walrus operator) work:

# option A
dist = dist.ravel()
dist.sort()

# option D - without the walrus operator
dist = dist.flatten()
dist.sort()