Why does the numba decorator make my code run 10 times slower?

I don’t know why nb.njit has an adverse effect on the performance of my code:

import numpy as np
import numba as nb

lst = [np.array([[1, 2],
                 [3, 4]]),
       np.array([[1, 2, 3],
                 [4, 5, 6]]),
       np.array([[1, 2],
                 [3, 4],
                 [5, 6]])]

lst = lst * 10000

def new_(lst):
    maxx = 0
    maxy = 0
    for x in lst:
        maxx = max(x.shape[0], maxx)
        maxy = max(x.shape[1], maxy)

    arr = np.zeros((len(lst), maxx, maxy))
    for i in range(len(lst)):
        arr[i, :lst[i].shape[0], :lst[i].shape[1]] = lst[i]
    return arr

@nb.njit(nb.float64[:, :, ::1](nb.types.List(nb.int_[:, ::1], reflected=True)))
def numba_s2(lst):
    maxx = 0
    maxy = 0
    for x in lst:
        maxx = max(x.shape[0], maxx)
        maxy = max(x.shape[1], maxy)

    arr = np.zeros((len(lst), maxx, maxy))
    for i in range(len(lst)):
        arr[i, :lst[i].shape[0], :lst[i].shape[1]] = lst[i]
    return arr

It also shows a warning, both on my machine and on Google Colab:

/usr/local/lib/python3.7/dist-packages/numba/core/ir_utils.py:2147: NumbaPendingDeprecationWarning: 
Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'lst' of function 'numba_s2'.

For more information visit https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-reflection-for-list-and-set-types

File "<ipython-input-2-22beb8072bc3>", line 33:
@nb.njit(nb.float64[:, :, ::1](nb.types.List(nb.int_[:, ::1], reflected=True)))
def numba_s2(lst):
^

  warnings.warn(NumbaPendingDeprecationWarning(msg, loc=loc))

Is the slowdown related to this warning?

Can numba be applied efficiently to this code to get better performance?

My machine:

  • python 3.10
  • numpy 1.22.3
  • numba 0.55.2

Hi @Ali_Sh, how are you measuring the time? Are you including the compilation time in the comparison?

@luk-f-a I used %timeit (%timeit -n10 numba_s2(lst)) on Colab. Did you check the performance, and is it reasonable in your test?

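The warning you’re seeing is definitely a clue. As far as I understand it, a reflected list is converted to Numba’s internal representation on every call, and with 30,000 arrays in the list that conversion can easily cost more than the loop itself. You’re better off using a typed List.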

You can create a Numba typed List from an existing Python list using:

lst_nb = nb.typed.List(lst)

You can then drop the reflected=True, and I also think you should use nb.types.ListType in the signature.

@nb.njit(nb.float64[:, :, :](nb.types.ListType(nb.int_[:, ::1])))
def numba_s2(lst):
    maxx = 0
    maxy = 0
    for x in lst:
        maxx = max(x.shape[0], maxx)
        maxy = max(x.shape[1], maxy)

    arr = np.zeros((len(lst), maxx, maxy))
    for i in range(len(lst)):
        arr[i, :lst[i].shape[0], :lst[i].shape[1]] = lst[i]
    return arr

numba_s2(lst_nb)

For me that makes this function about 15x faster compared to the pure Python one, using the same size input as in your OP.
