Hi @steff
First things first: @Rutger already said this, but I think it doesn't hurt to make it clear once more:
- Numba does not care at all about your type annotations (though many of course consider them good style).
- `numba.typed.List` is an actual container, whereas `typing.List` is a type. They are very different things, but easy to mix up.
Beyond that (not meant to be harsh):
- I don’t think `minmax` is an appropriate name for that function, given what it does.
- The function seems to have a few logical “gaps” (e.g. `data[x] == 0` is never handled), unless that is intentional.
- `len(data)` might not do what you expect once `data` has more than one dimension: it only returns the length of the first axis.
- For cleaner code, that `while` loop could probably be a `for` loop.
- When judging the speed of your function, keep in mind that the first call triggers compilation and therefore takes longer than all subsequent calls.
- Your function both mutates `data` in place and returns it. I personally find that a bit odd, but okay.
With that said, here are a few alternative ways one could approach this (timings listed below):
```python
import numpy as np
from numba import njit  # Missing import in original MWE
from numba import vectorize, prange

# Union[float, float] = float -> Union does nothing here
# There is no need for thresholds to be a List, let's be more general
from typing import Sequence


@njit
def minmax_orig(data: np.array, thresholds: Sequence[float]) -> np.array:
    x = 0
    while x < len(data):
        if data[x] > 0:  # What happens if data[x] == 0 ??
            if data[x] > thresholds[0]:
                data[x] = thresholds[0]
            elif data[x] < 1:  # What happens if a threshold is in (-1,1)?
                data[x] = 1
        elif data[x] < 0:
            if data[x] < thresholds[1]:
                data[x] = thresholds[1]
            elif data[x] > -1:
                data[x] = -1
        x += 1
    return data


# upper_thr caps positive values (thresholds[0]), lower_thr caps negative ones (thresholds[1])
@njit
def minmax_scalar_args(data: np.array, upper_thr: float, lower_thr: float) -> np.array:
    x = 0
    while x < len(data):
        if data[x] > 0:
            if data[x] > upper_thr:
                data[x] = upper_thr
            elif data[x] < 1:
                data[x] = 1
        elif data[x] < 0:
            if data[x] < lower_thr:
                data[x] = lower_thr
            elif data[x] > -1:
                data[x] = -1
        x += 1
    return data


@njit(parallel=True)
def minmax_parfor(data: np.array, upper_thr: float, lower_thr: float) -> np.array:
    # prange lets Numba split the (independent) iterations across threads
    for x in prange(len(data)):
        if data[x] > 0:
            if data[x] > upper_thr:
                data[x] = upper_thr
            elif data[x] < 1:
                data[x] = 1
        elif data[x] < 0:
            if data[x] < lower_thr:
                data[x] = lower_thr
            elif data[x] > -1:
                data[x] = -1
    return data


# Scalar kernel: @vectorize turns it into a ufunc that is applied element-wise
# and returns a new array instead of modifying data in place
@vectorize
def minmax_vec(data: float, upper_thr: float, lower_thr: float) -> float:
    if data > 0:
        if data > upper_thr:
            data = upper_thr
        elif data < 1:
            data = 1
    elif data < 0:
        if data < lower_thr:
            data = lower_thr
        elif data > -1:
            data = -1
    return data


data = np.random.uniform(low=-3, high=3, size=(1_000_000,))
small_data = data[:10]

# Trigger compilation for all testcases
minmax_orig(small_data, [2., -2.])
minmax_orig(small_data, (2., -2.))
minmax_scalar_args(small_data, 2., -2.)
minmax_parfor(small_data, 2., -2.)
minmax_vec(small_data, 2., -2.)

from timeit import timeit

for size in [10, 1_000, 1_000_000]:
    print(30*"+")
    print(f"{size} elements:\n")
    for name, expr in [
        ("Original implementation with list arg", lambda: minmax_orig(data[:size], [2., -2.])),
        ("Original implementation with tuple arg", lambda: minmax_orig(data[:size], (2., -2.))),
        ("Implementation with scalar threshold arg", lambda: minmax_scalar_args(data[:size], 2., -2.)),
        ("Parfor loop with scalar threshold arg", lambda: minmax_parfor(data[:size], 2., -2.)),
        ("Numba vectorised scalar threshold arg", lambda: minmax_vec(data[:size], 2., -2.)),
    ]:
        t = timeit(expr, number=100)
        print(name.ljust(50), f"{t:.4e} s")
```
Output (timings are, of course, specific to my machine):

```
/home/hapahl/anaconda3/lib/python3.8/site-packages/numba/core/ir_utils.py:2031: NumbaPendingDeprecationWarning:
Encountered the use of a type that is scheduled for deprecation: type 'reflected list' found for argument 'thresholds' of function 'minmax_orig'.
For more information visit https://numba.pydata.org/numba-doc/latest/reference/deprecation.html#deprecation-of-reflection-for-list-and-set-types
File "numba_discourse_880.py", line 10:
@njit
def minmax_orig(data: np.array, thresholds: Sequence[float]) -> np.array:
^
  warnings.warn(NumbaPendingDeprecationWarning(msg, loc=loc))
++++++++++++++++++++++++++++++
10 elements:

Original implementation with list arg              5.9060e-04 s
Original implementation with tuple arg             6.4200e-05 s
Implementation with scalar threshold arg           6.2100e-05 s
Parfor loop with scalar threshold arg              2.2654e-03 s
Numba vectorised scalar threshold arg              3.5370e-04 s
++++++++++++++++++++++++++++++
1000 elements:

Original implementation with list arg              1.6321e-03 s
Original implementation with tuple arg             2.3860e-04 s
Implementation with scalar threshold arg           2.4910e-04 s
Parfor loop with scalar threshold arg              2.1563e-03 s
Numba vectorised scalar threshold arg              6.4270e-04 s
++++++++++++++++++++++++++++++
1000000 elements:

Original implementation with list arg              4.5491e-01 s
Original implementation with tuple arg             4.3003e-01 s
Implementation with scalar threshold arg           4.2444e-01 s
Parfor loop with scalar threshold arg              9.4272e-02 s
Numba vectorised scalar threshold arg              4.6147e-01 s
```
Hope that helps a little 