I would like to build a typed dictionary inside a Numba `@njit`-compiled function. The dictionary would map a user-supplied string key (`str`) to the NumPy function (the value) that should be run for that key.
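A dictionary would likely run into the same issue: in Numba, the types of the keys and values must be declared explicitly.
Instead, you could dispatch on an enum, or pass the compiled function itself as a first-class function (a function pointer).
The function-pointer version appears to carry extra call overhead: in the simple benchmark below it is more than twice as slow as the enum version.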
# =============================================================================
# Imports
# =============================================================================
import timeit
from enum import IntEnum, auto

import numba.types as nbt
import numpy as np
from numba import njit
# =============================================================================
# Implementation using enum
# =============================================================================
class Operation(IntEnum):
    """Selector for the reduction that ``foo`` applies to its array argument."""

    # IntEnum's auto() numbers members from 1, so MIN == 1 and MAX == 2.
    MIN = auto()  # dispatches to np.min
    MAX = auto()  # dispatches to np.max
@njit
def foo(arr, op):
    """Apply the reduction selected by ``op`` to ``arr`` and return the result.

    Parameters
    ----------
    arr : array
        Array passed straight to ``np.min`` / ``np.max``.
    op : Operation
        ``Operation.MIN`` returns ``np.min(arr)``; ``Operation.MAX`` returns
        ``np.max(arr)``.

    Raises
    ------
    NotImplementedError
        If ``op`` equals neither ``Operation.MIN`` nor ``Operation.MAX``.
    """
    # Plain equality tests on the IntEnum members, checked in declaration order.
    if op == Operation.MIN:
        return np.min(arr)
    if op == Operation.MAX:
        return np.max(arr)
    # Fall-through: no handled member matched.
    raise NotImplementedError("Operation not implemented")
# Example usage: select the reduction with an enum member (prints 1.0).
arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(foo(arr, Operation.MIN))
# =============================================================================
# Implementation using function pointer
# =============================================================================
@njit(nbt.f8(nbt.f8[:]))
def nb_min(arr):
    """Return the smallest element of the 1-D float64 array ``arr``.

    Eagerly compiled with the signature ``float64(float64[:])`` so it can be
    passed to ``baz`` as a first-class function of that type.
    """
    return arr.min()
# Type of every reduction baz can call: takes a 1-D float64 array, returns float64.
_REDUCTION_SIG = nbt.f8(nbt.f8[:])


@njit(nbt.f8(nbt.FunctionType(_REDUCTION_SIG), nbt.f8[:]))
def baz(func, arr):
    """Call the first-class function ``func`` on ``arr`` and return its result.

    Parameters
    ----------
    func : compiled function of type ``float64(float64[:])``
        The reduction to apply, e.g. ``nb_min``.
    arr : 1-D float64 array
        Input forwarded unchanged to ``func``.
    """
    result = func(arr)
    return result
# Example usage: hand the compiled nb_min to baz as a function value (prints 1.0).
arr = np.arange(1.0, 6.0)
print(baz(nb_min, arr))
# =============================================================================
# Speed check Enum vs Pointer
# =============================================================================
# `%timeit` is IPython-only syntax and a SyntaxError in a plain .py module, so
# this section uses the standard-library `timeit` module to produce the same
# style of "mean ± std. dev. of 7 runs" report.


def _timeit_report(stmt, runs=7):
    """Time ``stmt`` in the manner of IPython's ``%timeit`` and print a summary.

    Parameters
    ----------
    stmt : callable
        Zero-argument callable to benchmark.
    runs : int, optional
        Number of timed repetitions; ``%timeit`` uses 7 by default.
    """
    timer = timeit.Timer(stmt)
    # autorange() picks a loop count so that one run takes at least 0.2 s
    # (it calls stmt while doing so, which also acts as a warm-up).
    loops, _ = timer.autorange()
    per_loop = np.asarray(timer.repeat(repeat=runs, number=loops)) / loops
    print(
        f"{per_loop.mean():.3g} s ± {per_loop.std():.3g} s per loop "
        f"(mean ± std. dev. of {runs} runs, {loops} loops each)"
    )


arr = np.random.normal(0, 1, 1000)

# The lambdas read the global N when called, i.e. inside _timeit_report below.
N = 100_000
_timeit_report(lambda: [foo(arr, Operation.MIN) for _ in range(N)])
_timeit_report(lambda: [baz(nb_min, arr) for _ in range(N)])
# Recorded results (IPython %timeit):
# 1.33 s ± 5.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 3.11 s ± 53.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

N = 1000
_timeit_report(lambda: [foo(arr, Operation.MIN) for _ in range(N)])
_timeit_report(lambda: [baz(nb_min, arr) for _ in range(N)])
# Recorded results (IPython %timeit):
# 13.5 ms ± 83.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# 30.9 ms ± 1.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)