Virtual functions and alternatives

Hi, I am running into a wall trying to get the following to work. I would like to have a heterogeneous set of classes that all implement a `value` method; think of classes that all derive from some abstract base class.

A concrete example of what I mean is:

```python
def no_numba_test():
    class A:
        def __init__(self):
            pass

        def value(self, a: int, b: int) -> float:
            return a + b


    class B:
        def __init__(self):
            pass

        def value(self, a: int, b: int) -> float:
            return a - b


    class Dispatch:

        def __init__(self, objs):
            self.objs = objs

        def value(self, a: int, b: int) -> list[float]:
            return [f.value(a, b) for f in self.objs]


    objs = [A(), B()]
    dispatch = Dispatch(objs)
    assert dispatch.value(4, 3) == [7, 1]
```
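For reference, the "abstract base class" version of the plain-Python example above might look like this (pure Python, no Numba involved; the name `ValueProvider` is a hypothetical stand-in for the common base class). This is the structure that has no direct jitclass equivalent:

```python
from abc import ABC, abstractmethod


class ValueProvider(ABC):
    """Hypothetical abstract base class shared by A and B."""

    @abstractmethod
    def value(self, a: int, b: int) -> float:
        ...


class A(ValueProvider):
    def value(self, a: int, b: int) -> float:
        return a + b


class B(ValueProvider):
    def value(self, a: int, b: int) -> float:
        return a - b


class Dispatch:
    def __init__(self, objs: list[ValueProvider]):
        self.objs = objs

    def value(self, a: int, b: int) -> list[float]:
        # Each call resolves to the subclass override at run time (virtual dispatch).
        return [f.value(a, b) for f in self.objs]


dispatch = Dispatch([A(), B()])
assert dispatch.value(4, 3) == [7, 1]
```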

Trying to make this work in Numba, one immediately hits the problem that `self.objs` in `Dispatch` must have a homogeneous type. The natural solution would be to type it as a list over the abstract base class of `A` and `B`, but I understand that jitclass inheritance isn't currently supported?

As an alternative, I considered making `self.objs` hold only the `value` methods themselves, which do share a uniform type. For example, `self.objs` would have type `List(types.FunctionType(types.float64(types.int64, types.int64)))`.

This works for explicitly defined `@njit` functions, but fails as soon as I add a function that wraps the `value` method of a jitclass instance.

Example:

```python
import numba as nb
from numba import int64, float64, typed, types


def numba_test():
    # First-class function type: (int64, int64) -> float64
    function_sig = types.FunctionType(types.float64(types.int64, types.int64))

    @nb.njit(float64(int64, int64))
    def imp0(a: int, b: int) -> float:
        return a + b

    @nb.njit(float64(int64, int64))
    def imp1(a: int, b: int) -> float:
        return a - b

    def wrap(T):  # Here T is a jitclass instance
        """
        Wrapper for turning a class's value method into a standalone function.
        """
        @nb.njit(float64(int64, int64))
        def imp(a: int, b: int) -> float:
            return T.value(a, b)
        return imp

    @nb.experimental.jitclass
    class A:
        def __init__(self):
            pass

        def value(self, a: int, b: int) -> float:
            return a + b

    # typed.List of first-class functions with the signature above
    spec = [("objs", types.ListType(function_sig))]

    @nb.experimental.jitclass(spec)
    class Dispatch:
        def __init__(self, objs):
            self.objs = objs

        def value(self, a: int, b: int) -> list[float]:
            return [f(a, b) for f in self.objs]

    objs = typed.List.empty_list(function_sig)
    objs.append(imp0)
    objs.append(imp1)
    objs.append(wrap(A()))  # !!! this will cause compilation to fail !!!
    dispatch = Dispatch(objs)


numba_test()
```

Compilation fails for the above with:

```
  File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/lowering.py", line 463, in lower_inst
    val = self.lower_assign(ty, inst)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/lowering.py", line 669, in lower_assign
    res = self.context.get_constant_generic(self.builder, ty,
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/base.py", line 505, in get_constant_generic
    impl = self._get_constants.find((ty,))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/base.py", line 45, in find
    out = self._find(sig)
          ^^^^^^^^^^^^^^^
  File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/base.py", line 54, in _find
    raise errors.NumbaNotImplementedError(f'{self}, {sig}')
numba.core.errors.NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x7fa3599c6550>, (instance.jitclass.A#7fa343402610<>,)
During: lowering "$6load_deref.0 = freevar(T: <numba.experimental.jitclass.boxing.A object at 0x7fa3433f5150>)" at /mnt/c/Users/maxat/src/monument/local/example.py (55)
```


Any help or suggestions for alternative approaches would be really appreciated.