First-class function in structref error when other member exists

I’m running numba 0.57.1 with python 3.9.16 on Windows 10 Pro.
When I create a structref with a single first-class function member I can add it to a typed list without issue, but when I add a second member the object cannot be added to a typed list.
I’m sure I’m missing something, but a day of searching and reading the docs didn’t get me to a solution.
Minimal reproducer below… the code runs without error if the second element of the fields list is commented out.

import numba
from numba import int64
from numba.core import types
from numba.experimental import structref
from numba.experimental.structref import define_boxing

print(numba.__version__)

@numba.njit('int64(int64, int64)')  # explicit signature; the compiled signature is read back below
def sum(a, b):  # NOTE(review): shadows the builtin `sum`
    return a + b

sum_type = sum.nopython_signatures[0].as_type()  # first-class function type for the compiled signature
@structref.register
class TesterTypeTemplate(types.StructRef):  # type-side template; specialised with `fields` below
    pass

class Tester(structref.StructRefProxy):  # Python-side proxy for the structref
    def __new__(cls, *args):
        return structref.StructRefProxy.__new__(cls, *args)  # default constructor: numba infers the argument types

fields = [
    ('sum', sum_type),
    ('arg1', int64),  # error goes away if you comment out this line
]
structref.define_constructor(Tester, TesterTypeTemplate, [field[0] for field in fields])
define_boxing(TesterTypeTemplate, Tester)
TesterType = TesterTypeTemplate(fields)
tester_lst = numba.typed.List.empty_list(TesterType)

arg1 = 1
tester = Tester(sum) if 1 == len(fields) else Tester(sum, arg1)  # branch only exists so one commented line toggles the bug
tester_lst.append(tester)  # raises the TypingError when 'arg1' is present

The (gently edited) output is

python.exe discourse.py 
0.57.1
venv\lib\site-packages\numba\core\utils.py:554: NumbaExperimentalFeatureWarning: First-class function type feature is experimental
  warnings.warn("First-class function type feature is experimental",
Traceback (most recent call last):
  File "discourse.py", line 33, in <module>
    tester_lst.append(tester)
  File "venv\lib\site-packages\numba\typed\typedlist.py", line 344, in append
    _append(self, item)
  File "venv\lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "venv\lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function impl_append at 0x00000297E7058B80>) found for signature:

 >>> impl_append(ListType[numba.TesterTypeTemplate(('sum', FunctionType[int64(int64, int64)]), ('arg1', int64))], numba.TesterTypeTemplate(('sum', type(CPUDispatcher(<function sum_it at 0x00000297C75FEA60>))), ('arg1', int64)))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'impl_append': File: numba\typed\listobject.py: Line 592.
    With argument(s): '(ListType[numba.TesterTypeTemplate(('sum', FunctionType[int64(int64, int64)]), ('arg1', int64))], numba.TesterTypeTemplate(('sum', type(CPUDispatcher(<function sum_it at 0x00000297C75FEA60>))), ('arg1', int64)))':
   Rejected as the implementation raised a specific error:
     NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
   Cannot cast numba.TesterTypeTemplate(('sum', type(CPUDispatcher(<function sum_it at 0x00000297C75FEA60>))), ('arg1', int64)) to numba.TesterTypeTemplate(('sum', FunctionType[int64(int64, int64)]), ('arg1', int64)): %"inserted.meminfo.1" = insertvalue {i8*} undef, i8* %"arg.item.0", 0
   During: lowering "casteditem = call $2load_global.0(item, $6load_deref.2, func=$2load_global.0, args=[Var(item, listobject.py:599), Var($6load_deref.2, listobject.py:600)], kws=(), vararg=None, varkwarg=None, target=None)" at D:\work\git1\mims\venv\lib\site-packages\numba\typed\listobject.py (600)
  raised from venv\lib\site-packages\numba\core\base.py:701

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.containers.ListType'>, 'append') for ListType[numba.TesterTypeTemplate(('sum', FunctionType[int64(int64, int64)]), ('arg1', int64))])
During: typing of call at venv\lib\site-packages\numba\typed\typedlist.py (82)

@goykhman discovered that constructing the structref inside a jitted function causes a similar error even when the ‘other member’ does not exist.

import numba
from numba.core import types
from numba.experimental import structref
from numba.experimental.structref import define_boxing

print(numba.__version__)

@numba.njit('int64(int64, int64)')  # explicit signature; the compiled signature is read back below
def sum_it(a, b):
    return a + b

sum_type = sum_it.nopython_signatures[0].as_type()  # first-class function type for the compiled signature
@structref.register
class TesterTypeTemplate(types.StructRef):  # type-side template; specialised with `fields` below
    pass

class Tester(structref.StructRefProxy):  # Python-side proxy for the structref
    def __new__(cls, *args):
        return structref.StructRefProxy.__new__(cls, *args)  # default constructor: numba infers the argument types

fields = [
    ('sum', sum_type),  # the only member: a first-class function
]
structref.define_constructor(Tester, TesterTypeTemplate, [field[0] for field in fields])
define_boxing(TesterTypeTemplate, Tester)
TesterType = TesterTypeTemplate(fields)
tester_lst = numba.typed.List.empty_list(TesterType)

tester = Tester(sum_it)
tester_lst.append(tester)  # works: single member, structref built outside jit code

@numba.njit
def append(lst, val):  # works fine if structref is constructed externally
    lst.append(val)

append(tester_lst, tester)

@numba.njit
def create_append(lst):  # fails if structref is constructed in jit-code
    val = Tester(sum_it)  # field is typed as the dispatcher of sum_it rather than as FunctionType
    lst.append(val)

create_append(tester_lst)

produces

venv\Scripts\python.exe discourse1.py 
0.57.1
venv\lib\site-packages\numba\core\utils.py:554: NumbaExperimentalFeatureWarning: First-class function type feature is experimental
  warnings.warn("First-class function type feature is experimental",
Traceback (most recent call last):
  File "discourse1.py", line 43, in <module>
    create_append(tester_lst)
  File "venv\lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "venv\lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function impl_append at 0x00000233BF778B80>) found for signature:

 >>> impl_append(ListType[numba.TesterTypeTemplate(('sum', FunctionType[int64(int64, int64)]),)], numba.TesterTypeTemplate(('sum', type(CPUDispatcher(<function sum_it at 0x00000233A0CEFA60>))),))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'impl_append': File: numba\typed\listobject.py: Line 592.
    With argument(s): '(ListType[numba.TesterTypeTemplate(('sum', FunctionType[int64(int64, int64)]),)], numba.TesterTypeTemplate(('sum', type(CPUDispatcher(<function sum_it at 0x00000233A0CEFA60>))),))':
   Rejected as the implementation raised a specific error:
     NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
   Cannot cast numba.TesterTypeTemplate(('sum', type(CPUDispatcher(<function sum_it at 0x00000233A0CEFA60>))),) to numba.TesterTypeTemplate(('sum', FunctionType[int64(int64, int64)]),): %"inserted.meminfo.1" = insertvalue {i8*} undef, i8* %"arg.item.0", 0
   During: lowering "casteditem = call $2load_global.0(item, $6load_deref.2, func=$2load_global.0, args=[Var(item, listobject.py:599), Var($6load_deref.2, listobject.py:600)], kws=(), vararg=None, varkwarg=None, target=None)" at venv\lib\site-packages\numba\typed\listobject.py (600)
  raised from venv\lib\site-packages\numba\core\base.py:701

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.containers.ListType'>, 'append') for ListType[numba.TesterTypeTemplate(('sum', FunctionType[int64(int64, int64)]),)])
During: typing of call at discourse1.py (41)


File "discourse1.py", line 41:
def create_append(lst):  # fails if structref is constructed in jit-code
    <source elided>
    val = Tester(sum_it)
    lst.append(val)
    ^


Process finished with exit code 1

From the error, it looks like the jit compilation has a conflict: it treats sum() as a literal function reference instead of an opaque first-class function. One way to fix this is to take control of how your structref’s constructor is defined instead of letting numba build it (and infer the types) for you. There are several advantages to this:

  1. you can explicitly control its signature, which is helpful in this case for forcing numba to treat certain arguments/fields as generic opaque functions.
  2. it also allows you to control default argument values, which perhaps you are also after here, since I see both Tester(sum) and Tester(sum, arg1)

As a side note, I see from your if 1 == len(fields) that you are considering some kind of branching situation with different versions of TesterType. I would advise against creating more than one type specialization per type template if you can avoid it. Most of the practical cases where this might be useful lead to a sort of hyper-specialization nightmare situation where you end up with lots of variants of the same function, from different variants of the same structref, and any speedup benefits that might have been gained by applying those specializations are small enough that they don’t justify the added compile time, and worse, might incur more runtime overhead from needing to import more code. Plus it will cause annoying unification issues—for instance you’ll have a hard time adding two structrefs with different specialization types to the same list.

Anyway here is your code with a custom constructor:

import numba
from numba import int64
from numba.core import types
from numba.experimental import structref
from numba.experimental.structref import define_boxing, new

print(numba.__version__)

@numba.njit('int64(int64, int64)')  # explicit signature; the compiled signature is read back below
def sum(a, b):  # NOTE(review): shadows the builtin `sum`
    return a + b

# First-class function type for the compiled signature of `sum`.
sum_type = sum.nopython_signatures[0].as_type()
print(sum_type)

@structref.register
class TesterTypeTemplate(types.StructRef):
    # Short fixed name instead of the default field-by-field type string.
    def __str__(self):
        return "TesterType"

class Tester(structref.StructRefProxy):
    def __new__(cls, func, arg1=1):
        """Build a Tester through the explicitly typed jitted constructor."""
        # Delegates to tester_ctor instead of StructRefProxy.__new__ so that
        # `func` is typed by the explicit signature rather than inferred by numba.
        return tester_ctor(func, arg1)

fields = [
    ('sum', sum_type),
    ('arg1', int64),  # second member; assigned explicitly in tester_ctor
]
structref.define_constructor(Tester, TesterTypeTemplate, [field[0] for field in fields])
define_boxing(TesterTypeTemplate, Tester)
TesterType = TesterTypeTemplate(fields)


# NOTE(review): import sits mid-script; conventionally it belongs with the
# imports at the top of the file.
from numba import i8, njit
# The explicit signature makes numba treat `func` as a generic first-class
# function and fixes the return type to TesterType.
@njit(TesterType(types.FunctionType(i8(i8,i8)), i8))
def tester_ctor(func, arg1):
    """Allocate a TesterType instance and fill in its two fields."""
    st = new(TesterType)
    st.sum = func
    st.arg1 = arg1
    return st


tester_lst = numba.typed.List.empty_list(TesterType)

arg1 = 1
# Conditional carried over from the original reproducer; both branches go
# through Tester.__new__, which defaults arg1 to 1.
tester = Tester(sum) if 1 == len(fields) else Tester(sum, arg1)
tester_lst.append(tester)
print(tester)
1 Like

Thanks, this is perfect! And the commentary explaining how it works is super helpful.

This also helps with something I hadn’t understood well… how to make sure that the structref is constructed only with the types specified in fields. I had been using a constructor function like tester_ctor but without attaching it to __new__ it was possible for users to create instances of Tester that didn’t conform to the types defined in TesterType. Your solution solves that quite nicely.

Regarding the if 1 == len(fields), I hadn’t contemplated any trickiness… that was simply to make the reproducer work/not work by commenting out a single line. Your comments on that are also very useful. It’s great to get thoughts from someone who has been there and done that.

1 Like