Any numba equivalent for casting a raw pointer to a StructRef, Dict, List etc?

Sorry to you both for the delayed reply. I’ve just returned from traveling.

@nelson2005 thanks for the pointer. I’ll take a look. I’ve been avoiding SQL/database stuff for various reasons, a big one being the need for lots of custom functionality, but it might be worth taking a deeper look into that space.

@luk-f-a Suppose you have something like BaseClass, with SubClass1 and SubClass2 as subclasses of BaseClass. To keep both SubClass1 and SubClass2 objects in the same container (for instance a typed List), you would need to cast them both to BaseClass and then add them to the container.
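
As a rough analogy only (plain ctypes, not Numba or CRE), here is the "common prefix" idea that makes this upcast safe. The sketch assumes each subclass layout starts with exactly the base's fields, so a pointer to a subclass can be reinterpreted as a pointer to the base and stored in one homogeneous array. The names Base/SubA/SubB are illustrative.

```python
import ctypes

# Illustrative layouts: each subclass repeats the base's fields first, so a
# pointer to a subclass can be reinterpreted as a pointer to the base.
class Base(ctypes.Structure):
    _fields_ = [("type_id", ctypes.c_int64)]

class SubA(ctypes.Structure):
    _fields_ = [("type_id", ctypes.c_int64), ("A", ctypes.c_int64)]

class SubB(ctypes.Structure):
    _fields_ = [("type_id", ctypes.c_int64), ("B", ctypes.c_int64)]

a = SubA(0, 10)
b = SubB(1, 20)

# Upcast both to POINTER(Base) so they can share one homogeneous array.
BasePtr = ctypes.POINTER(Base)
container = (BasePtr * 2)(ctypes.cast(ctypes.pointer(a), BasePtr),
                          ctypes.cast(ctypes.pointer(b), BasePtr))

# Through the base view only type_id is visible; A and B need a downcast.
print([p.contents.type_id for p in container])                      # [0, 1]
print(ctypes.cast(container[0], ctypes.POINTER(SubA)).contents.A)   # 10
```

The downcast is only valid because we know (from type_id) what the object really is; nothing checks it for you.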

Now let’s say that we want to write an njitted function that loops over the elements in our container and does something with each of them. Our container will have type ListType(BaseClass), so when we iterate through it each item will have the static type BaseClass. At this point Numba’s multiple dispatch machinery cannot help us. That machinery lives at the interface between Python and the Numba runtime (i.e. it decides which specialization to run based on the Python types coming in), whereas the decision we need has to happen inside code that has already been compiled down to LLVM.

So, as I described in 3), we have a few options, which are along the lines of how we would approach the issue in a compiled language like C++. One option is to keep an attribute in BaseClass (like ‘type_id’ or something similar) that identifies the true type of an object (i.e. the type it was instantiated as), so we can run the correct implementation of our target function on it. If you have a small, finite number of types, you can just use an if-else chain to choose the implementation (which might cast the object back to SubClass1 or SubClass2 to use attributes not in BaseClass). In this case all possible implementations are compiled into the function that holds our loop.
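
The shape of the if-else option, stripped of all Numba/CRE machinery, is sketched here in plain Python. The type_id values and class names are hypothetical, and in Numba each branch would first cast the object back to its subtype before reading A or B.

```python
from dataclasses import dataclass

A_ID, B_ID = 0, 1  # hypothetical type_id values

@dataclass
class Base:
    type_id: int

@dataclass
class SubA(Base):
    A: int

@dataclass
class SubB(Base):
    B: int

def get_thing_fixed(x: Base) -> int:
    # Every implementation is written out here, so all of them end up in
    # this one function; supporting a new type means editing it.
    if x.type_id == A_ID:
        return x.A   # in Numba, downcast to SubA before this access
    elif x.type_id == B_ID:
        return x.B   # in Numba, downcast to SubB before this access
    else:
        return -1

items = [SubA(A_ID, 3), SubB(B_ID, 4), Base(99)]
print([get_thing_fixed(x) for x in items])  # [3, 4, -1]
```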

We can also choose the target implementation dynamically. One way to do this is to build a method table, i.e. a typed Dict of type_id → FunctionType(out_type(BaseClass)), fill it at program startup, and pass it in as an argument with each call to your function. Alternatively, you can assign the target implementation function to an attribute of BaseClass.
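
Both dynamic options can be sketched in plain Python to show the control flow (this is not the typed Dict/FunctionType version, and all names are illustrative). Option 1 looks the implementation up in a table keyed by type_id; option 2 reads it from an attribute stored on the object itself.

```python
from dataclasses import dataclass
from typing import Callable

A_ID, B_ID = 0, 1  # hypothetical type_id values

@dataclass
class Base:
    type_id: int
    # Option 2: the implementation travels with the object as an attribute.
    get_thing: Callable[["Base"], int]

@dataclass
class SubA(Base):
    A: int

@dataclass
class SubB(Base):
    B: int

def get_thing_A(x):
    return x.A

def get_thing_B(x):
    return x.B

# Option 1: a method table keyed by type_id, built once at startup and
# passed in to the function that needs it.
method_table = {A_ID: get_thing_A, B_ID: get_thing_B}

def sum_dynamic_table(items, table):
    total = 0
    for x in items:
        if x.type_id not in table:
            raise KeyError(x.type_id)
        total += table[x.type_id](x)
    return total

def sum_dynamic_attribute(items):
    # x.get_thing is a plain function stored on the instance, so pass x explicitly.
    return sum(x.get_thing(x) for x in items)

items = ([SubA(A_ID, get_thing_A, i) for i in range(10)]
         + [SubB(B_ID, get_thing_B, i) for i in range(10)])
print(sum_dynamic_table(items, method_table))  # 90
print(sum_dynamic_attribute(items))            # 90
```

The table keeps objects small and lets you swap implementations globally; the attribute avoids passing a table around but costs one function slot per object.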

Here is some code showing some of these ideas in action. Forgive the abuse of utilities from CRE (my project); you can poke around the previous link to see their implementation.

from numba import njit, i8
from numba.types import FunctionType
from numba.typed import List, Dict
from cre.structref import define_structref
from cre.utils import cast_structref,_obj_cast_codegen
from numba.experimental.structref import new
from numba.core.imputils import lower_cast

base_members = {"type_id" : i8}
BaseClass, BaseClassType = define_structref("BaseClass", 
    base_members, define_constructor=False)

base_exec_members = {**base_members, "get_thing" : FunctionType(i8(BaseClassType))}
BaseExecutable, BaseExecutableType = define_structref("BaseExecutable", 
    base_exec_members, define_constructor=False)

SubClassA, SubClassAType = define_structref("SubClassA", 
    {**base_exec_members, 'A' : i8}, define_constructor=False)
SubClassB, SubClassBType = define_structref("SubClassB", 
    {**base_exec_members, 'B' : i8}, define_constructor=False)

# Allow automatic upcasting from SubClassA to BaseClassType
@lower_cast(SubClassAType, BaseClassType)
def upcast_A(context, builder, fromty, toty, val):
    return _obj_cast_codegen(context, builder, val, fromty, toty)

# Allow automatic upcasting from SubClassB to BaseClassType
@lower_cast(SubClassBType, BaseClassType)
def upcast_B(context, builder, fromty, toty, val):
    return _obj_cast_codegen(context, builder, val, fromty, toty)


# get_thing() implementations for A and B
@njit(i8(BaseClassType), cache=True)
def get_thing_A(st):
    return cast_structref(SubClassAType, st).A

@njit(i8(BaseClassType), cache=True)
def get_thing_B(st):
    return cast_structref(SubClassBType, st).B

ATYPE_ENUM = 0
BTYPE_ENUM = 1

# Constructor for A
@njit(cache=True)
def SubClassA_ctor(A,get_thing_func=None):
    st = new(SubClassAType)
    st.type_id = ATYPE_ENUM
    if(get_thing_func is not None):
        st.get_thing = get_thing_func
    st.A = A
    return st

# Constructor for B
@njit(cache=True)
def SubClassB_ctor(B,get_thing_func=None):
    st = new(SubClassBType)
    st.type_id = BTYPE_ENUM
    if(get_thing_func is not None):
        st.get_thing = get_thing_func
    st.B = B
    return st

# Init 10 of each type 
@njit(cache=True)
def setup(gt_A=None,gt_B=None):
    L = List.empty_list(BaseClassType)
    for i in range(10):
        # At this point we don't need to explicitly cast to BaseClassType because
        #  we used lower_cast() to register A/B -> Base
        L.append(SubClassA_ctor(i, gt_A)) 
    for i in range(10):
        L.append(SubClassB_ctor(i, gt_B))
    return L


@njit(cache=True)
def get_thing_fixed(x):
    '''Example of hard-coding all method implementations with an if-elif chain'''
    if x.type_id == ATYPE_ENUM:
        return cast_structref(SubClassAType,x).A
    elif x.type_id == BTYPE_ENUM:
        return cast_structref(SubClassBType,x).B
    else:
        return -1

# Fill the method table at program startup: the function addresses can change between runs, so it can't be cached
method_table = Dict.empty(i8, FunctionType(i8(BaseClassType)))
method_table[ATYPE_ENUM] = get_thing_A
method_table[BTYPE_ENUM] = get_thing_B


@njit(cache=True)
def get_thing_dynamic_table(x, method_table):
    '''Example of using a method table for dynamic method implementations'''
    if(x.type_id in method_table):
        return method_table[x.type_id](x)
    else:
        raise KeyError()

@njit(cache=True)
def get_thing_dynamic_attribute(x):
    '''Example of using dynamic method implementations via a first-class function attribute'''
    f = cast_structref(BaseExecutableType,x).get_thing
    return f(x)

@njit(cache=True)
def sum_of_stuff_fixed(lst):
    return sum([get_thing_fixed(x) for x in lst])

@njit(cache=True)
def sum_of_stuff_dynamic_table(lst, method_table):
    return sum([get_thing_dynamic_table(x, method_table) for x in lst])

@njit(cache=True)
def sum_of_stuff_dynamic_attribute(lst):
    return sum([get_thing_dynamic_attribute(x) for x in lst])
    

container = setup(get_thing_A,get_thing_B)

print(sum_of_stuff_fixed(container))
print(sum_of_stuff_dynamic_table(container,method_table))
print(sum_of_stuff_dynamic_attribute(container))

Note that, to make the sum_of_stuff_dynamic_attribute case above work, I resorted to passing the target implementations (i.e. get_thing_A/get_thing_B) to the constructors of the subtypes (via setup). This is the most elegant solution I’ve found that keeps the code cache=True/AOT friendly. Ideally the address of the target implementation would get built into the constructor for your specialized object automatically, but I haven’t figured out how to do this yet (in principle it would require some kind of cross-linking). In any case, if you only instantiate things on the Python side, you can usually make this cleaner, for example by setting up __init__, __new__, or __call__ in your StructRefProxy so that it fills in the implementation automatically.
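
The "constructor fills in the implementation" idea is sketched here in plain Python, not as a StructRefProxy. default_impl is a hypothetical class-level hook, and in the Numba/CRE setting the equivalent logic would presumably live in the proxy's __new__ or __init__ before the struct is built.

```python
A_ID, B_ID = 0, 1  # hypothetical type_id values

def get_thing_A(x):
    return x.A

def get_thing_B(x):
    return x.B

class Base:
    type_id = -1
    default_impl = None  # hypothetical hook that each subclass overrides

    def __init__(self, get_thing=None):
        # Use an explicitly passed implementation if there is one,
        # otherwise fall back to the subclass's registered default.
        if get_thing is None:
            get_thing = type(self).default_impl
        self.get_thing = get_thing

class SubA(Base):
    type_id = A_ID
    default_impl = staticmethod(get_thing_A)

    def __init__(self, A, get_thing=None):
        super().__init__(get_thing)
        self.A = A

class SubB(Base):
    type_id = B_ID
    default_impl = staticmethod(get_thing_B)

    def __init__(self, B, get_thing=None):
        super().__init__(get_thing)
        self.B = B

a, b = SubA(5), SubB(7)
print(a.get_thing(a), b.get_thing(b))  # 5 7
```

Callers no longer need to know which implementation goes with which subtype, but they can still override it by passing get_thing explicitly.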

Another trick to keep in mind, if moving around or storing first-class functions is giving you trouble: you can get the address of a function as an integer (via numba.experimental.function_type._get_wrapper_address), pass that around as you like, and reconstruct the function with cre.utils._func_from_address.
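
The same address round trip can be illustrated with the standard library's ctypes. This is a stand-in for the idea, not the Numba/CRE helpers named above: wrap a function as a C function pointer, take its address as a plain int, then rebuild a callable from that int.

```python
import ctypes

# Signature of the function we want to pass around: int64(int64).
Int64Fn = ctypes.CFUNCTYPE(ctypes.c_int64, ctypes.c_int64)

def double(x):
    return 2 * x

# Wrap a Python callable as a C function pointer. Keep `cb` alive for as long
# as the address is in use, otherwise the underlying thunk can be freed.
cb = Int64Fn(double)

# The address is a plain Python int that can be stored or passed anywhere.
addr = ctypes.cast(cb, ctypes.c_void_p).value

# Rebuild a callable from the integer address and call through it.
rebuilt = Int64Fn(addr)
print(rebuilt(21))  # 42
```

As with any raw address, the integer carries no type or lifetime information, so the signature used to rebuild it must match and the original object must outlive every use.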

Hope this helps. Let me know if you have any questions.
