Adding multiple types to structref.define_boxing

Hi, following Danny Weitekamp’s structref example (here), I was trying to create a numba StructRefProxy myself.

Basically, I predefine the types as in that example, e.g.:

from numba.core import types
from numba.experimental import structref


@structref.register
class TestTypeTemplate(types.StructRef):
    def preprocess_fields(self, fields):
        return tuple((name, types.unliteral(type)) for name, type in fields)


kb_fields = [
    ("value", types.int64),
    ("array", types.UnionType(types=(types.float64[::1], types.int64[::1]))),
]

# Test is my StructRefProxy subclass (definition omitted here)
structref.define_boxing(TestTypeTemplate, Test)

The idea is that array can hold either float64 or int64 values.

Unfortunately, this can't be used to initialize structref instances. For example, when I pass an int64 array to my Test class, I get:

numba.core.errors.NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
Cannot cast array(int64, 1d, C) to Union[array(float64, 1d, C),array(int64, 1d, C)]: %"inserted.strides" = insertvalue {i8*, i8*, i64, i64, i64*, [1 x i64], [1 x i64]} %"inserted.shape", [1 x i64] %".13", 6

I was wondering whether it is possible at all, in this scenario, to let users provide either an int64 or a float64 array. I have found a workaround, but I don't like it: I check whether the input is an int64 or a float64 array and then initialize the class with one of two different constructors (one for floats, one for integers). Unfortunately, that is not something I can use.

Yeah it’s possible, for instance like so:

import numpy
from numba import njit, typeof
from numba.core import types
from numba.experimental import structref


class Atest(structref.StructRefProxy):
    def __new__(cls, value, array):
        return atest_constructor(value, array)

    @property
    @njit
    def value(self):
        return self.value

    @property
    @njit
    def array(self):
        return self.array


@structref.register
class AtestTypeTemplate(types.StructRef):
    def preprocess_fields(self, fields):
        return tuple((name, types.unliteral(type)) for name, type in fields)


structref.define_proxy(Atest, AtestTypeTemplate, ["value", "array"])


def make_Atest_constructor_sig(array_type):
    return AtestTypeTemplate([("value", types.int64), ("array", array_type)])(types.int64, array_type)


@njit([
    make_Atest_constructor_sig(types.float64[::1]),
    make_Atest_constructor_sig(types.int64[::1])
])
def atest_constructor(value, array):
    return Atest(value, array)


if __name__ == '__main__':
    t1 = Atest(137, numpy.array([314], dtype=numpy.int64))
    t2 = Atest(-141, numpy.array([2.17], dtype=numpy.float64))

    print(t1.value, t1.array)  # 137 [314]
    print(t2.value, t2.array)  # -141 [2.17]

    print(typeof(t1)._fields)  # (('value', int64), ('array', Array(int64, 1, 'C', False, aligned=True)))
    print(typeof(t2)._fields)  # (('value', int64), ('array', Array(float64, 1, 'C', False, aligned=True)))


P.S. Here, typing atest_constructor explicitly is optional: you can just leave it as a bare njit and let numba infer the signatures for you. Also, behind the scenes, structref.define_constructor (invoked here by define_proxy) registers an overload that creates a separate struct type for each combination of member types it is called with.


@dovoxi One consideration with @milton's approach is that the constructor now outputs objects of two different structref types. While they share the same TypeTemplate, the f8 and i8 variants of the structref cannot be unified by numba's type inference. So you could certainly use this to instantiate these two kinds of objects on the Python side, and even pass them individually into jitted functions as single arguments (the jit compiler will build a separate specialization for each type). However, you could not, for instance, put objects of both types in the same list and pass that list into one jitted function. So this solution isn't quite the same as having a proper union type.

I haven't used numba much in the last 6 months or so (I moved to C++ with nanobind, in part because of these kinds of issues), but I'm fairly certain that this cannot be achieved in numba without some annoying, complex, and not-quite-worth-it tricks: treating the arrays as opaque pointers, keeping some kind of type indicator within each structref object, and casting the opaque pointers back to the appropriate array types as indicated by that type indicator. Alternatively, you could have a parent class that defines the type indicator, plus subclasses with extra typed fields that a parent-class instance can be cast to. The machinery for that approach is also not directly supported by numba and requires a lot of specialized code to get working. I talked a bit about that here: Any numba equivalent for casting a raw pointer to a StructRef, Dict, List etc? - #12 by DannyWeitekamp.

If you're brave enough to explore a solution along these lines, I did something like this in this project. But it will have you doing so much non-standard stuff that it becomes very difficult to maintain, so hopefully @milton's approach is enough for your purposes. Good luck!


Thanks for the comments. BTW, to make things even more pedagogical for the OP: the behind-the-scenes machinery responsible for creating these two different struct types lives here. One can replace the caller (define_proxy) with the following code, which makes things more explicit (i.e., it keeps the boxing for the Python proxy but rolls out its own low-level constructor overload):

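from numba.extending import overload


structref.define_boxing(AtestTypeTemplate, Atest)


# strict=False skips numba's check that the typing function and the
# implementation have matching signatures (the argument names differ here).
@overload(Atest, strict=False)
def ol_Atest(value_type, array_type):
    # One concrete struct type per combination of argument types.
    Atest_type = AtestTypeTemplate([("value", value_type), ("array", array_type)])

    def _(value, array):
        Atest_obj = structref.new(Atest_type)
        Atest_obj.value = value
        Atest_obj.array = array
        return Atest_obj
    return _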

Thanks for the detailed answer. So I guess I won’t be able to define a numba type, such as:

AtestType = AtestTypeTemplate(kb_fields)

Thanks a lot for the answer. It does indeed work.

You're welcome. You can also consider wrapping your arrays in AnyType, introduced here. It leverages a variant of the type-erasure technique. Applied to your case, it goes like so:

import numpy
from numba import njit, typeof
from numba.core import types
from numba.experimental import structref
from numbox.core.any_type import make_any, AnyType


class Atest(structref.StructRefProxy):
    def __new__(cls, value, array):
        return atest_constructor(value, array)

    @property
    @njit
    def value(self):
        return self.value

    @property
    @njit
    def array(self):
        return self.array

    @array.setter
    @njit
    def array(self, val):
        self.array = val


@structref.register
class AtestTypeTemplate(types.StructRef):
    def preprocess_fields(self, fields):
        return tuple((name, types.unliteral(type)) for name, type in fields)


structref.define_proxy(Atest, AtestTypeTemplate, ["value", "array"])


AtestType = AtestTypeTemplate([
    ("value", types.int64),
    ("array", AnyType)  # `AnyType` from numbox package, its objects are created using `make_any`
])


@njit
def atest_constructor(value, array):
    return Atest(value, make_any(array))


if __name__ == '__main__':
    a1 = numpy.array([314], dtype=numpy.int64)
    a2 = numpy.array([2.17], dtype=numpy.float64)
    a1_ty = typeof(a1)
    a2_ty = typeof(a2)
    t1 = Atest(137, a1)
    t2 = Atest(-141, a2)

    sigs = atest_constructor.nopython_signatures
    assert len(sigs) == 2, "for `a1_ty` and `a2_ty` arguments"
    ret_type = set([sig.return_type for sig in sigs])
    assert len(ret_type) == 1, "`ret_type` is `numba.AtestTypeTemplate(('value', int64), ('array', numba.AnyTypeClass(('p', numba.ErasedTypeClass()),)))`"

    print(t1.value, t1.array.get_as(a1_ty))  # 137 [314]
    print(t2.value, t2.array.get_as(a2_ty))  # -141 [2.17]

    assert typeof(t1) == typeof(t2)

    a3 = numpy.array([1.41], dtype=numpy.float64)
    a3_ty = typeof(a3)
    t1.array.reset(a3)
    print(t1.value, t1.array.get_as(a3_ty))  # 137 [1.41]
