Numba quadtree jitclass issue

When trying to implement a quadtree with Numba I ran into an issue. I based my implementation on the pseudo-code from Wikipedia. The complicated part is attaching an instance of the same type as an attribute. Creating this child instance works, but as soon as I call its method from within the parent, I get an error message. Interestingly, calling the same method directly (outside the parent class) works fine.

My quadtree attempt can be found at the link below, including the error message. The manual insertion is shown at the bottom of this notebook:

I modified the “Bag” example to display the same/similar behavior with slightly less complex code. I also got the overload_method part from that example, found at:

from numba import typed, typeof, njit, int64, types, optional, deferred_type
from numba.experimental import jitclass
from numba.extending import overload_method, overload

bag_type = deferred_type()  # placeholder for the Bag instance type; defined once the class exists
# jitclass field specification for Bag.
spec = [
    ('max_length', int64),  # capacity of this bag's own list
    ('data', types.ListType(int64)),  # typed list holding the stored values
    ('child', optional(bag_type))  # overflow Bag, or None until one is created
]


@jitclass(spec)
class Bag(object):  # capped list of int64 values that hands overflow to a child Bag
    def __init__(self, max_length):
        # Start empty: nothing stored yet and no overflow child attached.
        self.max_length = max_length
        self.data = typed.List.empty_list(int64)
        self.child = None
        
    def insert(self, value):
        # Store value locally while there is room; otherwise involve the child.
        if len(self.data) < self.max_length:
            self.data.append(value)
        else:
            # First overflow only creates the child (add_child comes from overload_method); value is not stored.
            if self.child is None:
                self.child = self.add_child()            
            else:
                self.child.insert(value)  # self-recursive call; the TypingError is reported on this line
        
@overload_method(types.misc.ClassInstanceType, 'add_child')
def ol_bag_add_child(inst):
    """Typing hook that gives Bag instances an ``add_child`` method.

    An implementation is returned only when ``inst`` is the Bag instance
    type; for every other jitclass instance type the hook returns None.
    """
    if inst is not Bag.class_type.instance_type:
        return None

    def make_child(inst):
        # The new bag gets the same capacity as the bag that creates it.
        return Bag(inst.max_length)

    return make_child
        
# Resolve the deferred placeholder now that the Bag jitclass exists.
bag_type.define(Bag.class_type.instance_type)

def test():
    """Exercise Bag from plain Python: overflow the root, then fill the child directly."""
    root = Bag(5)

    print(root.data)
    print(root.child)

    # Six inserts into a capacity-5 bag, so the last one overflows.
    for n in range(6):
        root.insert(n)

    print(root.data)
    print(root.child)
    print(root.child.data)

    # Insert five more values straight into the child, starting from the
    # last value used by the loop above.
    first = n
    for n in range(first, first + 5):
        root.child.insert(n)

    print(root.child.data)

test()

And the error it produces:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
Failed in nopython mode pipeline (step: nopython frontend)
compiler re-entrant to the same function signature
- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.misc.ClassInstanceType'>, 'insert') for instance.jitclass.Bag#25035e61508<max_length:int64,data:ListType[int64],child:OptionalType(DeferredType#2543522858312) i.e. the type 'DeferredType#2543522858312 or None'>)
During: typing of call at <ipython-input-2-ff04cf4531ee> (31)


File "<ipython-input-2-ff04cf4531ee>", line 31:
    def insert(self, value):
        <source elided>
            else:
                self.child.insert(value)
                ^

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.misc.ClassInstanceType'>, 'insert') for instance.jitclass.Bag#25035e61508<max_length:int64,data:ListType[int64],child:OptionalType(DeferredType#2543522858312) i.e. the type 'DeferredType#2543522858312 or None'>)
During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>

If the line causing the error is removed, the example works. This at least shows that the creation of the child object and its insert method both work. Calling that method from within the parent seems to be the issue.

Any suggestions on how to circumvent this issue are appreciated.

@Rutger list of jitclass gotchas I found so far:

  • Recursion is not supported. You have one case of recursion when you call self.child.insert; avoid recursion.
  • Returning typed lists of jitclass objects from methods is not supported; you have to pass the return list as a parameter.
  • Pickling is triggered when you call njit functions from methods. There is no workaround here.
  • Calling the constructor from methods triggers a pickling error. The workaround is to use overload_method, but this has some other issues with deferred_types :slight_smile:

This will work:

from numba import typed, typeof, njit, int64, types, optional, deferred_type
from numba.experimental import jitclass
from numba.extending import overload_method, overload

# Placeholder for the Bag instance type, so that the spec can refer to Bag
# before the class exists; it is resolved with BagType.define(...) afterwards.
BagType = deferred_type()

# jitclass field specification for Bag.
spec = [
    ('max_length', int64),  # capacity of this bag's own list
    ('data', types.ListType(int64)),  # typed list holding the stored values
    ('child', optional(BagType))  # overflow Bag, or None until one is created
]

@jitclass(spec)
class Bag:
    """Capped list of int64 values plus an optional child Bag.

    Only the constructor lives on the class itself; the methods are
    attached separately through ``overload_method``.
    """

    def __init__(self, max_length):
        # A fresh bag has no child and an empty typed list.
        self.child = None
        self.data = typed.List.empty_list(int64)
        self.max_length = max_length

# Resolve the deferred placeholder now that the Bag jitclass exists.
BagType.define(Bag.class_type.instance_type)


@overload_method(types.misc.ClassInstanceType, 'child_insert')
def ol_bag_add_child_insert(inst, value):
    """Typing hook that gives Bag instances a ``child_insert`` method.

    The generated method appends ``value`` to the bag's own list when the
    list is below ``max_length`` and does nothing otherwise. For instance
    types other than Bag the hook returns None.
    """
    if inst is not Bag.class_type.instance_type:
        return None

    def append_if_room(inst, value):
        has_room = len(inst.data) < inst.max_length
        if has_room:
            inst.data.append(value)

    return append_if_room


@overload_method(types.misc.ClassInstanceType, 'insert')
def ol_bag_insert(self, value):
    """Typing hook that gives Bag instances a non-recursive ``insert`` method.

    The generated method walks from the bag down through its children using
    an explicit work list instead of calling itself. For instance types
    other than Bag the hook returns None.
    """
    if self is not Bag.class_type.instance_type:
        return None

    def insert_impl(self, value):
        print('insert ' + str(value))

        # A work list stands in for the call stack, so insert never has to
        # call itself.
        pending = typed.List()
        pending.append((self, value))

        while len(pending) > 0:
            node, value = pending.pop(0)

            if len(node.data) < node.max_length:
                node.data.append(value)
            elif node.child is None:
                # Only the child is created here; the value that triggered
                # the overflow is not stored anywhere.
                # (self.add_child() is reported to work here too.)
                node.child = Bag(node.max_length - 1)
            else:
                print('child insert')
                pending.append((node.child, value))
                # Reported alternative, possible because the recursion is
                # gone: self.child.child_insert(value)

    return insert_impl


@overload_method(types.misc.ClassInstanceType, 'add_child')
def ol_bag_add_child(inst):
    """Typing hook that gives Bag instances an ``add_child`` method.

    The generated method prints a trace line and returns a new Bag whose
    capacity is one less than that of the bag it is called on. For instance
    types other than Bag the hook returns None.
    """
    if inst is not Bag.class_type.instance_type:
        return None

    def make_child(inst):
        print('add_child')
        child_capacity = inst.max_length - 1
        return Bag(child_capacity)

    return make_child

@njit
def test():
    """Insert the values 0..9 into a capacity-5 Bag, printing its state before and after."""
    root = Bag(5)

    # State of a freshly constructed bag.
    print(root.data)
    print(root.child)

    for value in range(10):
        root.insert(value)

    # State after the inserts, including the overflow child.
    print(root.data)
    print(root.child)
    print(root.child.data)

test()

This works too if we eliminate the recursion:


from numba import typed, typeof, njit, int64, types, optional, deferred_type
from numba.experimental import jitclass
from numba.extending import overload_method, overload

# Placeholder for the Bag instance type, so that the spec can refer to Bag
# before the class exists; it is resolved with BagType.define(...) afterwards.
BagType = deferred_type()

# jitclass field specification for Bag.
spec = [
    ('max_length', int64),  # capacity of this bag's own list
    ('data', types.ListType(int64)),  # typed list holding the stored values
    ('child', optional(BagType))  # overflow Bag, or None until one is created
]

@jitclass(spec)
class Bag(object):
    """Capped list of int64 values with a single overflow child.

    Values are kept in ``data`` until ``max_length`` of them are held.
    Further values are handed to ``child``, which is created on the first
    overflow.
    """

    def __init__(self, max_length):
        self.max_length = max_length
        self.data = typed.List.empty_list(int64)
        self.child = None

    def child_insert(self, value):
        """Append ``value`` to this bag if it has room; otherwise do nothing.

        Non-recursive counterpart of ``insert``: ``insert`` calls this on
        the child so that it never has to call itself.
        """
        if len(self.data) < self.max_length:
            self.data.append(value)

    def insert(self, value):
        """Store ``value`` in this bag, or in its child once this bag is full.

        The child is created on the first overflow and the overflowing
        value is forwarded to it as well, so no inserted value is lost
        while the child still has room.
        """
        if len(self.data) < self.max_length:
            self.data.append(value)
        else:
            if self.child is None:
                # self.add_child() is an equivalent way to build the child.
                self.child = Bag(self.max_length)
            # Use child_insert rather than insert to avoid recursion.
            self.child.child_insert(value)

    def add_child(self):
        """Return a new, empty Bag with the same capacity as this one."""
        return Bag(self.max_length)

# Resolve the deferred placeholder now that the Bag jitclass exists.
BagType.define(Bag.class_type.instance_type)


@njit
def test():
    """Insert the values 0..9 into a capacity-5 Bag, printing its state before and after."""
    root = Bag(5)

    # State of a freshly constructed bag.
    print(root.data)
    print(root.child)

    for value in range(10):
        root.insert(value)

    # State after the inserts, including the overflow child.
    print(root.data)
    print(root.child)
    print(root.child.data)

test()