Typed list of jitted functions in jitclass

I am trying to create a jitted class that contains a typed list of jitted functions. I have tried using the following in the spec:

  • numba.types.List(numba.types.Callable)
  • numba.typeof(mylist) where mylist is a List that contains a jitted function.

but I cannot make it work. Any suggestions?

Hi @hgrecco, the elements of the list must be of the same type, which means that all functions must have an identical signature. Is that the case?

Your second bullet point should have worked; could you paste the full example? EDIT: it would work if the list had more than one function, all of them with the same signature.

Luk

Thanks @luk-f-a for the reply. I wrote a simple example:

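import numba

@numba.njit()
def f1(x, y):
    return x + y

@numba.njit()
def f2(x, y):
    return x + y

# compile both functions for (float64, float64)
f1(1., 2.)
f2(1., 2.)

# item type is inferred from the first appended element
f_list = numba.typed.List()
f_list.append(f1)
f_list.append(f2)  # raises the TypingError shown below

@numba.experimental.jitclass([('funcs', numba.typeof(f_list))])
class Handler:

    def __init__(self, funcs):
        self.funcs = funcs


Handler(f_list)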

which fails at the second append:

Traceback (most recent call last):
 ...
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function impl_append at 0x7fe3c8da6550>) found for signature:

 >>> impl_append(ListType[type(CPUDispatcher(<function f1 at 0x7fe3c8c143a0>))], type(CPUDispatcher(<function f2 at 0x7fe3c8de1430>)))

There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'impl_append': File: numba/typed/listobject.py: Line 589.
    With argument(s): '(ListType[type(CPUDispatcher(<function f1 at 0x7fe3c8c143a0>))], type(CPUDispatcher(<function f2 at 0x7fe3c8de1430>)))':
   Rejected as the implementation raised a specific error:
     LoweringError: Failed in nopython mode pipeline (step: nopython mode backend)
   Cannot cast type(CPUDispatcher(<function f2 at 0x7fe3c8de1430>)) to type(CPUDispatcher(<function f1 at 0x7fe3c8c143a0>)): %".21" = load i8*, i8** %"item"
   
   File "../../../../../anaconda3/envs/sci38/lib/python3.8/site-packages/numba/typed/listobject.py", line 597:
       def impl(l, item):
           casteditem = _cast(item, itemty)
           ^
   
   During: lowering "$8call_function.3 = call $2load_global.0(item, $6load_deref.2, func=$2load_global.0, args=[Var(item, listobject.py:597), Var($6load_deref.2, listobject.py:597)], kws=(), vararg=None)" at /Users/grecco/anaconda3/envs/sci38/lib/python3.8/site-packages/numba/typed/listobject.py (597)
  raised from /Users/grecco/anaconda3/envs/sci38/lib/python3.8/site-packages/numba/core/utils.py:81

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.containers.ListType'>, 'append') for ListType[type(CPUDispatcher(<function f1 at 0x7fe3c8c143a0>))])
During: typing of call at /Users/grecco/anaconda3/envs/sci38/lib/python3.8/site-packages/numba/typed/typedlist.py (66)


File "../../../../../anaconda3/envs/sci38/lib/python3.8/site-packages/numba/typed/typedlist.py", line 66:
def _append(l, item):
    l.append(item)
    ^

Python: 3.8.2
Numba: 0.51.2

The code below is working for me on 0.53dev, and it probably works on 0.52 too.

The key is to generate a valid FunctionType instance. One way to do it is to start with the output type float64 and call it with the input arguments, float64(float64, float64). That produces a Signature object. The signature can be converted into a function type using .as_type(). Alternatively, the function type can be built more explicitly as FunctionType(float64(float64, float64)).

Once you have the function type, the empty list must be created with an explicit item type that is a FunctionType instance. This first-class function type is what allows two different functions with the same signature to be appended to the same list. By default, with typed.List's automatic type inference, f1 and f2 are each typed as CPUDispatcher(<function name>), which is a kind of singleton type. That’s why f2 cannot be appended to a list that contains f1: only objects of type CPUDispatcher(<function f1 at ...>) are allowed in that list.

import numba
from numba.core import types

@numba.njit()
def f1(x, y):
    return x + y

@numba.njit()
def f2(x, y):
    return x + y

# compile both functions for (float64, float64)
f1(1., 2.)
f2(1., 2.)

# first-class function type for float64(float64, float64)
fn_type = types.float64(types.float64, types.float64).as_type()

# explicit item type: any function with this signature can be appended
f_list = numba.typed.List.empty_list(fn_type)
f_list.append(f1)
f_list.append(f2)

@numba.experimental.jitclass([('funcs', types.ListType(fn_type))])
class Handler:

    def __init__(self, funcs):
        self.funcs = funcs

That worked for me, thanks a lot!

I am now having trouble creating a jitclass that contains a list of jitclass instances. Briefly, in the spec I use types.ListType(Event), but I get the following error message:

ListType[instance.jitclass.Event <and more things here>]
to 
ListType[<class 'numba.experimental.jitclass.base.Event'>]

I managed to solve it by creating an instance of Event and then asking for its type, but it feels clunky.

Try ListType(Event.class_type) or ListType(Event.class_type.instance_type); I think one of those is the right one.

This worked. Thanks.

I do get a bit of a scary warning when doing this…
:3: NumbaTypeSafetyWarning: unsafe cast from type(CPUDispatcher(<function test_model_debug_virtual_table..debug_module_callback at 0x0000027CC0BDD378>)) to FunctionType[none(void*)]. Precision may be lost.

Is there any way to silence that, or mark it ‘safe’?

Bumping this old thread: for functions that take NumPy structured arrays with many (hundreds of) members, these logged messages are truly voluminous.

@nelson2005

The warning inherits from Warning, so a standard warnings filter should work.

That works, thanks! My python skill level exposed once again… :frowning:

Great, glad you got something working :slight_smile:

Is there a way to make this work for functions that take a StructRef as an argument? I tried replicating the code in the solution, but with functions that take a StructRef of type MyStructType (defined according to the documentation), so f1 looks like this:

@njit
def f1(struct: MyStructType) -> float64:
    return struct.a + struct.b

However, if I try to define the list of functions as something like

f_list = nb.typed.List.empty_list(types.float64(MyStructType).as_type())

I get an error

TypeError: float() argument must be a string or a real number, not '_TypeMetaclass'

I’m fairly new to using these structures, so maybe there is a simple way to generalize to this case, but it is not obvious to me…

What do you get if you call f1 with your struct and print the nopython signatures?

Also, a minimal complete reproducer program that demonstrates the issue would be helpful.

Thank you for your reply! I can post a complete program.

First, here is my definition of MyStruct and MyStructType. It is adapted so that the type of each field of MyStruct can be specified explicitly (I couldn’t find a simpler way to do so in the documentation):

import numpy as np
from numba import njit, float64, int32

from numba.experimental.structref import new, register, StructRefProxy, define_boxing
from numba.core.extending import overload
from numba.core.types import StructRef, Array



def my_struct(name, fields):

### step 1: register the new class of name MyStructType.

    source = f'''
@register
class {name}Type(StructRef):
    pass
'''
#    print(source)
    glbs = globals()
    exec(source, glbs)


### step 2: declare the corresponding class MyStruct, in the same way as done at
### https://numba.pydata.org/numba-doc/dev/extending/high-level.html : we need to redefine
### __new__ to better match the required fields, as well as manually define getters for the fields.
 
    pars = ', '.join([f"{f[0]}" for f in fields])

    def proper(par):
        return f'''
    @property
    def {par}(self):
        return get_{par}(self)
'''

    def njiter(par):
        return f'''
@njit
def get_{par}(self):
    return self.{par}
'''
    
    propers = '\n'.join([proper(f[0]) for f in fields])
    njiters = '\n'.join([njiter(f[0]) for f in fields])
    
    source = f'''
class {name}(StructRefProxy):
    def __new__(cls, {pars}):
        return StructRefProxy.__new__(cls, {pars})

    {propers}

    {njiters}
'''
#    print(source)
    exec(source, glbs)



### step 3: we write the new constructor function for MyStruct, which calls structref.new to
### instantiate a new struct of type {name}Type defined above. This essentially replaces structref.define_constructor.
### We use extending.overload to set the new function as the constructor for MyStruct.
### Finally, we call define_boxing to make sure that the new struct can be used in both njit and
### normal python code. This is important because StructRef requires the types to be numba types,
### while we might want to instantiate a new struct in python code with numpy types.

    def ster(par):
        return f'''
        st.{par} = {par}
        '''
    sters = '\n'.join([ster(f[0]) for f in fields])

    source = f"""
def ctor({pars}):
    struct_type = {name}Type({fields})
    def impl({pars}):
        st = new(struct_type)
        {sters}
        return st
    return impl
"""
#    print(source)
    exec(source, glbs)

    new = glbs[name]
    newtype = glbs[name+"Type"]

    overload(new)(glbs['ctor'])

    define_boxing(newtype, new)
    
    return new, newtype 

Then, by calling

import numba.typed
from numba.core import types

MyStruct, MyStructType = my_struct("MyStruct", [
    ("a", float64),
    ("b", float64),
])

@njit
def f1(struct):
    return struct.a + struct.b

struct = MyStruct(1,2)

# f1 works as expected
print(f1(struct))


f_list = numba.typed.List.empty_list(types.float64(MyStructType).as_type())

f_list.append(f1)

we see that MyStruct and f1 work as expected, but when trying to create the list I get

/usr/local/lib/python3.10/dist-packages/numba/core/types/scalars.py in cast_python_value(self, value)
    125 
    126     def cast_python_value(self, value):
--> 127         return getattr(np, self.name)(value)
    128 
    129     def __lt__(self, other):

TypeError: float() argument must be a string or a real number, not '_TypeMetaclass'

After your comment, I tried putting the signature in @njit explicitly, and that didn’t work either, so I guess it is just a matter of knowing the right way to express the StructRef in the signature. I have no idea how to do that, though. I’m also not sure how to print the nopython signatures of the function, as you suggested.

class {name}Type(StructRef):

defines the class of the type, not the type itself. To get the actual type, you need to pass a list of (attribute, type) pairs to its constructor. To keep things straight in my own projects, I name things like this:

class {name}TypeClass(StructRef):

Then

{name}Type = {name}TypeClass(…)

Thank you for pointing this out! I had put the code together following both the documentation and your replies to other issues, and I must have got confused by the different naming conventions. Returning the type properly in my_struct does indeed fix the issue.

I guess this makes my comment off-topic with respect to the original question, but maybe it can still be useful for someone.