How to use a dictionary with function values inside an njitted function

Hi!

I would like to create a typed dictionary inside a Numba njitted function. The dictionary should map a user-supplied string (the key) to a NumPy function (the value), so that the user's input selects which function runs.

Minimal case:

import numpy as np

def foo(arr, op_str='min'):
    # Works in plain Python; the goal is the same lookup inside an @njit function
    ops = {'min': np.nanmin,
           'max': np.nanmax}
    return ops[op_str.lower()](arr)

I've tried building a numba.typed.Dict with types.FunctionType as the value type, but it didn't work.

Thanks a lot!

Hey @joueswant,

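With a typed dict you will likely run into the same issue: you have to declare one type for all keys and one type for all values, so the values cannot simply be arbitrary NumPy functions.
You could instead use an enum or a function pointer (a first-class function type).
The pointer seems to have some call overhead: in the simple benchmark below it is roughly 2.3x slower than the enum version.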

# =============================================================================
# Imports
# =============================================================================
from enum import IntEnum, auto
import numpy as np
from numba import njit
import numba.types as nbt

# =============================================================================
# Implementation using enum
# =============================================================================
class Operation(IntEnum):
    MIN = auto()
    MAX = auto()

@njit
def foo(arr, op):
    match op:  # match statements require Python 3.10+
        case Operation.MIN:
            return np.min(arr)
        case Operation.MAX:
            return np.max(arr)
        case _:
            raise NotImplementedError("Operation not implemented")

# Example usage:
arr = np.array([1, 2, 3, 4, 5], dtype=np.float64)
print(foo(arr, Operation.MIN))

# =============================================================================
# Implementation using function pointer
# =============================================================================
@njit(nbt.f8(nbt.f8[:]))
def nb_min(arr):
    return np.min(arr)

@njit(nbt.f8(nbt.FunctionType(nbt.f8(nbt.f8[:])), nbt.f8[:]))
def baz(func, arr):
    return func(arr)

# Example usage:
arr = np.array([1, 2, 3, 4, 5], dtype=np.float64)
print(baz(nb_min, arr))

# =============================================================================
# Speed check Enum vs Pointer
# =============================================================================
arr = np.random.normal(0, 1, 1000)
# %timeit is an IPython/Jupyter magic; run these lines in IPython or a notebook
N = 100_000
%timeit [foo(arr, Operation.MIN) for i in range(N)]
%timeit [baz(nb_min, arr) for i in range(N)]
# enum:    1.33 s ± 5.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# pointer: 3.11 s ± 53.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
N = 1000
%timeit [foo(arr, Operation.MIN) for i in range(N)]
%timeit [baz(nb_min, arr) for i in range(N)]
# enum:    13.5 ms ± 83.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# pointer: 30.9 ms ± 1.54 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Sounds like the solution! Thanks a lot for your prompt response, @Oyibo!