Sending a jitclass to another jitclass

So i am running into this weird error that when i run this code it works perfectly fine … below are my classes … i am using numba 0.58.1 and on windows 10

spec = {
    "helper_class_self_num": types.float_,
    "helper_class_self_var": types.float_,
}


@jitclass(spec)
class HelperClass(object):
    def __init__(
        self,
        helper_class_self_num,
        helper_class_self_var,
    ):
        self.helper_class_self_num = helper_class_self_num
        self.helper_class_self_var = helper_class_self_var

    def calc_num_one(self, helper_num):
        pass


@jitclass
class SubClassOne(object):
    helper_class_inst: HelperClass

    def __init__(
        self,
        helper_class_self_num,
        helper_class_self_var,
    ):
        self.helper_class_inst = HelperClass(
            helper_class_self_num,
            helper_class_self_var,
        )

    def calc_num_one(self, helper_num):
        return helper_num / self.helper_class_inst.helper_class_self_num

so when i run it in a numba function, like below, I get no problems everything is fine



@njit(cache=True)
def tester_func():
    calculator = SubClassOne(10, 20)
    return calculator.calc_num_one(40)


tester_func()
>>> 6.0

but if i run it with param names in a numba function like below … i get this weird error

@njit(cache=True)
def tester_func():
    calculator = SubClassOne(
        helper_class_self_num=10,
        helper_class_self_var=20,
    )
    return calculator.calc_num_one(40)


tester_func()
LoweringError: Failed in nopython mode pipeline (step: native lowering)
unsupported keyword arguments when calling jitclass.SubClassOne#19f60e22830<helper_class_inst:instance.jitclass.HelperClass#19f60e21ed0<helper_class_self_num:float32,helper_class_self_var:float32>>

File "C:\Users\User\AppData\Local\Temp\ipykernel_35612\2742518833.py", line 46:
def tester_func():
    calculator = SubClassOne(helper_class_self_num=10, helper_class_self_var=20)
    ^

During: lowering "calculator = call $2load_global.0(func=$2load_global.0, args=[], kws=[('helper_class_self_num', Var($const4.1, 2742518833.py:46)), ('helper_class_self_var', Var($const6.2, 2742518833.py:46))], vararg=None, varkwarg=None, target=None)" at C:\Users\User\AppData\Local\Temp\ipykernel_35612\2742518833.py (46)

but if i run this outside of a numba function it works fine

SubClassOne(
    helper_class_self_num=10,
    helper_class_self_var=20,
).calc_num_one(40)
>>> 4.0

I am going to be passing a lot of params so i want to be able to label what i am doing … because without it … it could turn into a real mess real fast

so does anyone know what is going on here?

i have seen a couple people talking about nesting jitclasses and you guys helped me figure out how to do it … so thank you … you guys all saved my life … absolutely saved my life … so if you guys can help me figure this out too … then I will owe all of you more than you can think of!

@epifanio @alanlujan91 - How do I create a jitclass that takes a list of jitclass objects?

@justinblaber @Hannes - Jitclass with input of list of jitclass

@hgrecco @luk-f-a Typed list of jitted functions in jitclass

Also can anyone explain to me why this wouldn’t work?

import numba
import numpy as np
from numba import int64, types, typed, optional
from numba.experimental import jitclass

spec = {"data": int64}


class EmptyClass:
    def __init__(self) -> None:
        pass

    def tester(self):
        print("hello")

    def multi_data(self):
        pass

    def print_data(self):
        pass


@jitclass(spec)
class Foo(EmptyClass):
    def __init__(self, init_data):
        self.data = init_data

    def print_data(self):
        print(self.data)


@jitclass(spec)
class Boo(EmptyClass):
    def __init__(self, init_data):
        self.data = init_data

    def multi_data(self):
        print(self.data * 100)


spec_end = {"logger": numba.types.UnionType((Foo.class_type.instance_type, Boo.class_type.instance_type))}


@jitclass(spec_end)
class Bar:
    def __init__(self, data):
        if data == 0:
            self.logger = Foo(1)
        elif data == 1:
            self.logger = Boo(2)
Bar(0)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: native lowering)
Cannot cast instance.jitclass.Foo#1d61f31b6d0<data:int64> to Union[instance.jitclass.Boo#1d61f318f70<data:int64>,instance.jitclass.Foo#1d61f31b6d0<data:int64>]: %".63" = load {i8*, {i64}*}, {i8*, {i64}*}* %".31"
During: lowering "(self).logger = $14call_function.2" at C:\Users\User\AppData\Local\Temp\ipykernel_19292\1499562187.py (56)
During: resolving callee type: jitclass.Bar#1d61f1cddb0<logger:Union[instance.jitclass.Boo#1d61f318f70<data:int64>,instance.jitclass.Foo#1d61f31b6d0<data:int64>]>
During: typing of call at <string> (3)

During: resolving callee type: jitclass.Bar#1d61f1cddb0<logger:Union[instance.jitclass.Boo#1d61f318f70<data:int64>,instance.jitclass.Foo#1d61f31b6d0<data:int64>]>
During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>

Just a general comment- when using Numba much of the object oriented fanciness doesn’t work.
Numba shines with looping numeric code, not so much with OO.
This goes in spades for jitclass- each time the program runs the jitclass will be a different type (you can see the address as part of the type in the callee type error)

yeah thank you i understand that … but i am asking this question to get confirmation that this is a numba problem and not a programming problem on my end … this way i can know if i can move on or if i need to dig more … so if you know for sure there is no way of doing some python ninja work … that what i am trying to do can work then that would be great … but thanks for your response

the problem is it looks like you can’t try to assign a variable as two different classes

@jitclass()
class ReturnZero:
    def __init__(self):
        pass

    def return_value(self):
        return 0


@jitclass()
class ReturnTen:
    def __init__(self):
        pass

    def return_value(self):
        return 10


@njit(cache=True)
def CachedBagListHolder(num):
    if num < 5:
        calculator = ReturnZero()
    else:
        calculator = ReturnTen()
    x = calculator.return_value() + 9
    print(x)


CachedBagListHolder(10)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Cannot unify instance.jitclass.ReturnZero#20d3b0e3a90<> and instance.jitclass.ReturnTen#20d0a21b8b0<> for 'calculator.2', defined at C:\Users\User\AppData\Local\Temp\ipykernel_26344\1925693220.py (25)

File "C:\Users\User\AppData\Local\Temp\ipykernel_26344\1925693220.py", line 25:
def CachedBagListHolder(num):
    <source elided>
        calculator = ReturnTen()
    x = calculator.return_value() + 9
    ^

During: typing of assignment at C:\Users\User\AppData\Local\Temp\ipykernel_26344\1925693220.py (25)

File "C:\Users\User\AppData\Local\Temp\ipykernel_26344\1925693220.py", line 25:
def CachedBagListHolder(num):
    <source elided>
        calculator = ReturnTen()
    x = calculator.return_value() + 9
    ^

there has to be some type of way of doing some type of union … where a variable can be either this class or that class

spec_end = {"logger": numba.types.UnionType((Foo.class_type.instance_type, Boo.class_type.instance_type))}

something like this has to work

Hey @QuantFreedom1022 ,

In the example above Numba has a problem to determine the specific type of the local variable “calculator”. It can either be of type “ReturnTen” or “ReturnZero”. Numba does not know these types. You have to introduce them.
I am not sure but there might be a fusion type available in Numba, too. If that is the case you could declare “calculator” as both “ReturnTen” and “ReturnZero”. If that is not the case divide calculator into two local variables of type “ReturnTen” and “ReturnZero”.
If Numba is not able to infer local variables you can do that manually in the signature of the function.
Here is an example:
I hope this helps…

P.S. if you find a fusion type, please give me feedback.

import numba as nb
import numba.types as nbt
from numba.experimental import jitclass

@jitclass
class ReturnZero:
    def __init__(self):
        pass
    def return_value(self):
        return 0

@jitclass
class ReturnTen:
    def __init__(self):
        pass
    def return_value(self):
        return 10

typeZero = ReturnZero.class_type.instance_type
typeTen = ReturnTen.class_type.instance_type

@nb.njit(locals={'calculator1': typeZero, 'calculator2': typeTen, 'x': nbt.i8})
def CachedBagListHolder(num):
    if num < 5:
        calculator1 = ReturnZero()
        x = calculator1.return_value()
    else:
        calculator2 = ReturnTen()
        x = calculator2.return_value()
    print(x)

CachedBagListHolder(4)
CachedBagListHolder(6)
# 0
# 10

hey oyibo … thanks for your reply … but i am trying to avoid if statements because i am going to be looping a ton of times and don’t want to have to check 100,000 times when it will always be calling the same functions

so i am working something like this … where you basically build the class outside of the class then pass in all the functions you would need and then inside the njit function at the end … you can send the list of functions and you can also send whatever number you want of the list of functions to call … you can see how this is coming along

from numba import njit
from numba.core import types

call_str_type = types.void(types.unicode_type)
callee_func_type = types.FunctionType(call_str_type)
sig_str = types.int64(types.int64, types.int64, callee_func_type)


# @numba.njit([sig_str], cache=True)
@njit(cache=True)
def f1(x, y, callee):
    callee("adding x=" + str(x) + " and y=" + str(y))
    return x + y


# @numba.njit([sig_str], cache=True)
@njit(cache=True)
def f2(x, y, callee):
    callee("multiplying x=" + str(x) + " and y=" + str(y))
    return x * y


f_list = typed.List.empty_list(sig_str.as_type())
f_list.append(f1)
f_list.append(f2)


@njit(cache=True)
def callee_str(m):
    print(m)
    pass


@njit(cache=True)
def callee_pass(m):
    pass


p_list = typed.List.empty_list(call_str_type.as_type())
p_list.append(callee_str)
p_list.append(callee_pass)


@jitclass(
    [
        ("funcs", types.ListType(sig_str.as_type())),
        ("printers", types.ListType(call_str_type.as_type())),
    ]
)
class Handler:
    def __init__(self, funcs, printers):
        self.funcs = funcs
        self.printers = printers


@njit(cache=True)
def tester(f_list, p_list):
    x = Handler(f_list, p_list)

    for ind in range(10):
        time: types.int64 = 0
        hello: types.int64 = 0
        for dos in range(200):
            for bar in range(100):
                time += x.funcs[0](ind, dos, x.printers[1]) + bar
                hello += x.funcs[1](ind, 4, x.printers[1])
        result = time - hello
    return result


tester(f_list=f_list, p_list=p_list)

I am going to basically be building everything like this … i will update this later once i figure this out … but i think this is as close as i will get to an answer of how to remove if statements and also call classes inside of a numba njit function