It seems that outside a jitted function, operations on numba.typed.List are very slow compared to the same operations on native Python lists; on my machine the difference is about 20x. I had trouble finding a way to work around this in my project, so I was wondering if I could get some advice on how to get the best performance.
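To put a number on that overhead, a minimal timing sketch like the following compares reading elements of a native list and of a typed.List from the interpreter. It assumes numba is installed and falls back to a plain list otherwise (in which case the ratio is meaningless). The helper name time_indexing is hypothetical, and the actual ratio will vary by machine and numba version.

```python
import timeit

try:
    from numba.typed import List as TypedList
except ImportError:  # numba missing: fall back to list so the sketch still runs
    TypedList = list


def time_indexing(container, n=5000, repeat=20):
    """Best-of-`repeat` time (seconds) to read container[0..n-1] from the interpreter."""
    def loop():
        for i in range(n):
            container[i]
    # min() discards the first run, which may include JIT compilation of __getitem__
    return min(timeit.repeat(loop, number=1, repeat=repeat))


native = list(range(100_000))
typed = TypedList(native)

t_native = time_indexing(native)
t_typed = time_indexing(typed)
print(f"list: {t_native:.6f}s  typed.List: {t_typed:.6f}s  ratio: {t_typed / t_native:.1f}x")
```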
The situation is that one function accounted for the vast majority of the computational work, so I decided to optimize it using numba. However, it is a graph algorithm that requires ragged arrays which constantly have new values appended, so I cannot simply use NumPy arrays; instead, I refactored my algorithm to use typed.List. This sped up the crucial function considerably. However, all the code surrounding it became much slower because it operates on typed.List, and it is now the bottleneck of my algorithm. This surrounding code cannot be jitted because it uses other Python objects such as sets.
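If the appends always go to the end of a flat sequence, as they do in the snippet at the bottom of this post, one possible alternative to typed.List is a NumPy buffer that doubles its capacity when full, so appends stay amortized O(1) while the data remains an ordinary array. GrowableArray is a hypothetical helper, not part of NumPy or numba, and it does not cover truly ragged rows, which would need something like one value buffer plus an offsets array.

```python
import numpy as np


class GrowableArray:
    """Hypothetical helper: an int64 buffer with amortized O(1) appends.

    Capacity doubles when full, so appends are cheap, while the live data is
    always a plain NumPy array view that Python and numba can both read.
    """

    def __init__(self, initial=(), capacity=16):
        initial = np.asarray(initial, dtype=np.int64)
        self._buf = np.empty(max(capacity, 2 * len(initial), 1), dtype=np.int64)
        self._buf[:len(initial)] = initial
        self._size = len(initial)

    def _reserve(self, needed):
        # Grow geometrically so repeated extends cost amortized O(1) per element
        if needed > len(self._buf):
            new_buf = np.empty(max(needed, 2 * len(self._buf)), dtype=np.int64)
            new_buf[:self._size] = self._buf[:self._size]
            self._buf = new_buf

    def extend(self, values):
        values = np.asarray(values, dtype=np.int64)
        self._reserve(self._size + len(values))
        self._buf[self._size:self._size + len(values)] = values
        self._size += len(values)

    def truncate(self, size):
        # Equivalent of `del lst[size:]`: just forget the tail, no reallocation
        self._size = min(self._size, size)

    @property
    def data(self):
        # View (no copy) of the valid elements
        return self._buf[:self._size]


g = GrowableArray(range(10))
g.extend([100, 200, 300])
print(g.data[-3:].tolist())  # [100, 200, 300]
g.truncate(10)
print(len(g.data))  # 10
```

With this layout the jitted hot loop would return its new values as a NumPy array to pass to extend, and the surrounding Python code can read something like `g.data[:n].tolist()` in one bulk conversion instead of indexing element by element.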
I've reproduced the dilemma in the snippet at the bottom of the post. The function update_large_graph performs the vast majority of the computational work. However, search_large_graph becomes slow because lst is a typed.List.
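Since the obstacle to jitting the surrounding code is the Python set, one option, assuming node ids are dense non-negative integers as in the snippet, is to encode reached as a boolean NumPy mask so the membership loop itself can be compiled and read the typed.List without leaving nopython mode. set_to_mask and search_large_graph_masked are hypothetical names; the try/except mirrors the question's own noop_decorator toggle so the sketch also runs without numba.

```python
import numpy as np

try:
    import numba
    jit = numba.njit
    TypedList = numba.typed.List
except ImportError:  # same idea as the question's noop_decorator toggle
    def jit(func):
        return func
    TypedList = list


def set_to_mask(reached, size):
    """Hypothetical helper: encode a set of non-negative ints < size as a bool array."""
    mask = np.zeros(size, dtype=np.bool_)
    mask[np.fromiter(reached, dtype=np.int64, count=len(reached))] = True
    return mask


@jit
def search_large_graph_masked(lst, mask, n):
    # Same membership test as search_large_graph, but compiled: the typed.List
    # items are read inside nopython mode instead of one interpreter call each.
    hits = 0
    for i in range(n):
        v = lst[i]
        # Values outside the mask (e.g. large appended ids) count as "not reached"
        if v >= 0 and v < mask.shape[0] and mask[v]:
            hits += 1
    return hits


lst = TypedList(range(10_000))
mask = set_to_mask(set(range(0, 10_000, 2)), 10_000)
print(search_large_graph_masked(lst, mask, 5000))  # 2500
```

The mask costs one byte per possible id, so this only makes sense when ids are reasonably dense; the mask must also be kept in sync wherever the real code adds to reached.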
We can simulate the expected optimal performance by changing search_large_graph(lst, reached) to search_large_graph(lst2, reached), since lst2 is a native Python list. This leaves almost all of the execution time inside the jitted function update_large_graph.
But that does not perform the same computation unless lst is properly duplicated into lst2. However, the commented-out code needed to replicate every update to lst in lst2 is expensive, even with my attempt to speed this up via the function nb_list_to_array. I can't seem to find a way to make the algorithm perform close to the expected optimum. Is there a good solution, or can my algorithm not be improved unless typed.List becomes more efficient outside of jitted functions?
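Instead of mirroring every update into lst2, another option is to convert only the part of lst that the Python code actually reads, in a single compiled call, generalizing nb_list_to_array to a slice. This assumes, as in the snippet, that the search only needs a known range of indices (here the first 5000 elements); typed_slice_to_array and search_large_graph_batched are hypothetical names. The intent is to cross the Python/numba boundary once per search rather than once per element.

```python
import numpy as np

try:
    import numba
    jit = numba.njit
    TypedList = numba.typed.List
except ImportError:  # fall back to plain Python, like the question's toggle
    def jit(func):
        return func
    TypedList = list


@jit
def typed_slice_to_array(lst, start, stop):
    # Copy lst[start:stop] (clipped to the list length) into a NumPy array
    stop = min(stop, len(lst))
    out = np.empty(max(stop - start, 0), dtype=np.int64)
    for k in range(out.shape[0]):
        out[k] = lst[start + k]
    return out


def search_large_graph_batched(lst, reached, n=5000):
    # One compiled call, then .tolist() yields native Python ints for set lookups
    values = typed_slice_to_array(lst, 0, n).tolist()
    return sum(1 for v in values if v in reached)


lst = TypedList(range(10_000))
reached = set(range(0, 10_000, 2))
print(search_large_graph_batched(lst, reached))  # 2500
```

Whether this beats the lst2 mirror depends on how much of the list the Python side needs per iteration; for a 5000-element prefix it copies far fewer values than replicating the 500000 values appended by each update_large_graph call.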
Code that reproduces the dilemma:
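import numpy as np
import numba
import time


def noop_decorator(func):
    return func


# Swap these pairs to run everything as plain Python for comparison
jit = numba.njit
NBList = numba.typed.List
# jit = noop_decorator
# NBList = list


@jit
def update_large_graph(lst):
    # Hot loop: appends 5 * 100000 pseudo-random values to lst
    # and returns them separately so they can be mirrored elsewhere
    updates = NBList()
    val = 0
    for i in range(100000):
        for j in range(5):
            val = (val * 295236 + 2976737) % 395687437
            lst.append(val)
            updates.append(val)
    return updates


@jit
def nb_list_to_array(lst):
    # Copy a typed List into a NumPy array inside compiled code
    arr = np.empty(len(lst), dtype=np.int64)
    for i, v in enumerate(lst):
        arr[i] = v
    return arr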
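

def update_large_graph_v2(lst, updates):
    # Mirror the values appended by update_large_graph into the native list.
    # updates holds 5 * 100000 values, so copy all of them, not just the first 100000.
    updates = nb_list_to_array(updates).tolist()
    lst.extend(updates)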


@jit
def revert_large_graph(lst):
    # Drop everything appended after the initial 1000000 elements
    del lst[1000000:]


def revert_large_graph_v2(lst):
    del lst[1000000:]


# @jit
def search_large_graph(lst, reached):
    # In the real project this cannot be jitted because it uses Python sets
    for i in range(5000):
        lst[i] in reached


def workflow_iter(lst, lst2, reached):
    updates = update_large_graph(lst)
    # update_large_graph_v2(lst2, updates)
    search_large_graph(lst, reached)
    revert_large_graph(lst)
    # revert_large_graph_v2(lst2)


def workflow():
    lst = NBList(list(range(1000000)))
    lst2 = list(range(1000000))
    reached = set(list(range(0, 1000000, 2)))
    # Warm-up call triggers JIT compilation, so it is excluded from the timing
    workflow_iter(lst, lst2, reached)
    start = time.time()
    for i in range(100):
        workflow_iter(lst, lst2, reached)
    print(time.time() - start)


if __name__ == '__main__':
    workflow()