Here’s something to get you started if you want to go down this route:
from numba import njit, literal_unroll, types
from numba.typed import Dict
import numpy as np
from numba.experimental import structref
from numba.extending import overload
import operator
# The idea here is to wrap a typed.Dict in another type, the "TupleKeyDictType".
# The purpose of this is so that operations like __getitem__ and __setitem__
# can be proxied through functions that call `hash` on the key. This makes it
# possible to have something that behaves like a dictionary, but supports
# heterogeneous keys (tuples of varying size/type).
# Define the new type and register it
@structref.register
class TupleKeyDictType(types.StructRef):
    """Numba-side struct type for the dictionary wrapper."""

    def preprocess_fields(self, fields):
        """Return ``fields`` with each field type run through ``types.unliteral``.

        ``fields`` is iterated as ``(name, type)`` pairs. Names are kept as
        given and the pairs are returned, in the same order, as a tuple.
        """
        cleaned = []
        for field_name, field_type in fields:
            # Drop literal-ness so the struct holds the general type.
            cleaned.append((field_name, types.unliteral(field_type)))
        return tuple(cleaned)
# Define the Python side proxy class
class TupleKeyDict(structref.StructRefProxy):
    """Python-side proxy for instances of ``TupleKeyDictType``."""

    @property
    def wrapped_dict(self):
        """Return the wrapped dictionary held in the ``wrapped_dict`` field.

        The field is read through the jitted accessor defined below.
        """
        return TupleKeyDict_get_wrapped_dict(self)


@njit
def TupleKeyDict_get_wrapped_dict(self):
    """Jitted accessor backing the ``TupleKeyDict.wrapped_dict`` property.

    Without this definition the property above raises ``NameError`` when it
    is read from Python. It is compiled on first call, by which point the
    struct's attributes have been registered via ``structref.define_proxy``.
    """
    return self.wrapped_dict
# Set up the wiring between the Python proxy class (TupleKeyDict) and the
# struct type (TupleKeyDictType). "wrapped_dict" is the only member in the
# "struct" and it refers to the typed.Dict instance in use.
structref.define_proxy(TupleKeyDict, TupleKeyDictType, ["wrapped_dict"])
# Overload operator.getitem for the TupleKeyDictType; note how it defers the
# lookup to the wrapped_dict member and hashes the key
@overload(operator.getitem)
def ol_tkd_getitem(inst, key):
    """Provide ``inst[key]`` for ``TupleKeyDictType`` instances.

    The returned implementation hashes ``key`` and looks the hash up in
    ``inst.wrapped_dict``. For any other ``inst`` type no implementation is
    returned.

    NOTE: only ``hash(key)`` is used for the lookup, never the key itself,
    so two distinct keys with equal hashes resolve to the same entry.
    """
    if not isinstance(inst, TupleKeyDictType):
        return None

    def tkd_getitem_impl(inst, key):
        hashed_key = hash(key)
        return inst.wrapped_dict[hashed_key]

    return tkd_getitem_impl
# Overload operator.setitem for the TupleKeyDictType, again, it's hashing the
# key before use.
@overload(operator.setitem)
def ol_tkd_setitem(inst, key, value):
    """Provide ``inst[key] = value`` for ``TupleKeyDictType`` instances.

    The returned implementation hashes ``key`` and stores ``value`` under
    that hash in ``inst.wrapped_dict``. For any other ``inst`` type no
    implementation is returned.

    NOTE: only ``hash(key)`` is stored, never the key itself, so a later
    write with a different key of equal hash overwrites the entry.
    """
    if not isinstance(inst, TupleKeyDictType):
        return None

    def tkd_setitem_impl(inst, key, value):
        hashed_key = hash(key)
        inst.wrapped_dict[hashed_key] = value

    return tkd_setitem_impl
# quick demonstration
@njit
def foo(keys, values):
    """Demonstrate storing and fetching values under heterogeneous tuple keys.

    ``keys`` is a tuple of tuple keys and ``values`` is indexable by
    position; ``values[i]`` is stored under ``keys[i]``. The wrapped
    dictionary is printed, followed by the value found under ``(1, 2)``.
    Nothing is returned.
    """
    # Backing store: integer (hash) keys mapping to complex values.
    backing = Dict.empty(types.intp, types.complex128)
    # Put the wrapper around it.
    tkd = TupleKeyDict(backing)
    # Populate it. The keys differ in type (tuples of different sizes), so
    # the loop body has to be versioned per key type, hence literal_unroll.
    # A manual counter tracks the matching position in `values`.
    position = 0
    for key in literal_unroll(keys):
        tkd[key] = values[position]
        position = position + 1
    # Show what ended up in the wrapped dictionary.
    print(tkd.wrapped_dict)
    # Show a lookup through the wrapper.
    print("getitem", (1, 2), "gives", tkd[(1, 2)])
# Demo inputs: three tuple keys of different lengths (2, 3 and 1 elements)...
keyz = ((1, 2), (3, 4, 5), (6,))
# ...and one complex value per key, matched by position.
valuez = (1j, 2j, 3j)
# Run the demonstration; foo prints the wrapped dict and the (1, 2) lookup.
foo(keyz, valuez)
Hope this helps.