Hi, I am running into a wall trying to get the following to work. I would like to have a heterogeneous set of classes which all implement a “value” method. Think of classes all derived from some abstract base class.
A concrete example of what I mean is:
def no_numba_test():
    """Pure-Python reference for the heterogeneous dispatch pattern.

    Defines two unrelated classes, ``A`` and ``B``, that each expose a
    ``value(a, b)`` method, plus a ``Dispatch`` container that calls
    ``value`` on every object it holds. The final assertion checks the
    collected results for one pair of inputs.
    """

    class A:
        """Implementation whose ``value`` adds its two arguments."""

        def __init__(self):
            pass

        def value(self, a: int, b: int) -> float:
            return a + b

    class B:
        """Implementation whose ``value`` subtracts ``b`` from ``a``."""

        def __init__(self):
            pass

        def value(self, a: int, b: int) -> float:
            return a - b

    class Dispatch:
        """Holds a list of objects and fans ``value`` out to each of them."""

        def __init__(self, objs):
            self.objs = objs

        def value(self, a: int, b: int) -> list[float]:
            # One result per held object, in the order the objects were given.
            return [f.value(a, b) for f in self.objs]

    objs = [A(), B()]
    dispatch = Dispatch(objs)
    assert dispatch.value(4, 3) == [7, 1]
Trying to make this happen in numba, one immediately hits the problem that `self.objs` in Dispatch must have a homogeneous type. The natural solution is to have its type be a List over the abstract base class for A and B, but I understand this isn’t currently supported?
As an alternative I considered making `self.objs` hold only the `value` methods themselves, which do have a uniform type. So, for example, `self.objs` would have type `List(types.FunctionType(types.float64(types.int64, types.int64)))`.
This actually works for explicitly defined functions, but fails if you then try to add a function based on the value method of a class.
Example:
def numba_test():
    """Numba version of the dispatch pattern; reproduces the failure.

    Builds a typed list of first-class functions with the signature
    ``float64(int64, int64)`` and stores it in a jitclass ``Dispatch``.
    The two plain jitted functions are appended without trouble; the
    function produced by ``wrap(A())`` is the one marked as making
    compilation fail.

    NOTE(review): ``nb`` is not imported inside this function; it is
    presumably ``import numba as nb`` at module level -- confirm.
    """
    from numba import int64, float64, typed, types

    # Common function type for every entry of the dispatch list.
    function_sig = types.FunctionType(types.float64(types.int64, types.int64))

    @nb.njit(float64(int64, int64))
    def imp0(a: int, b: int) -> float:
        return a + b

    @nb.njit(float64(int64, int64))
    def imp1(a: int, b: int) -> float:
        return a - b

    def wrap(T):  # Here T is a class instance
        """
        Wrapper for turning class value method into standalone function
        """

        @nb.njit(float64(int64, int64))
        def imp(a: int, b: int) -> float:
            # T is captured from the enclosing scope as a free variable.
            return T.value(a, b)

        return imp

    @nb.experimental.jitclass
    class A:
        def __init__(self):
            pass

        def value(self, a: int, b: int) -> float:
            return a + b

    # The jitclass field type is taken from an empty typed list of function_sig.
    spec = [("objs", nb.typeof(typed.List().empty_list(function_sig)))]

    @nb.experimental.jitclass(spec)
    class Dispatch:
        def __init__(self, objs):
            self.objs = objs

        def value(self, a: int, b: int) -> list[float]:
            return [f(a, b) for f in self.objs]

    objs = typed.List().empty_list(function_sig)
    objs.append(imp0)
    objs.append(imp1)
    objs.append(wrap(A()))  # !!! this will cause compilation to fail !!!
    dispatch = Dispatch(objs)
Compilation fails for the above with:
File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/lowering.py", line 463, in lower_inst
val = self.lower_assign(ty, inst)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/lowering.py", line 669, in lower_assign
res = self.context.get_constant_generic(self.builder, ty,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/base.py", line 505, in get_constant_generic
impl = self._get_constants.find((ty,))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/base.py", line 45, in find
out = self._find(sig)
^^^^^^^^^^^^^^^
File "/home/bludot/miniconda3/envs/monumentgcp2/lib/python3.11/site-packages/numba/core/base.py", line 54, in _find
raise errors.NumbaNotImplementedError(f'{self}, {sig}')
numba.core.errors.NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x7fa3599c6550>, (instance.jitclass.A#7fa343402610<>,)
During: lowering "$6load_deref.0 = freevar(T: <numba.experimental.jitclass.boxing.A object at 0x7fa3433f5150>)" at /mnt/c/Users/maxat/src/monument/local/example.py (55)
Any help or suggestions for alternative approaches would be really appreciated.