Why does this first version of the code compile much faster than the second version shown further below?
from numba import typed, typeof, njit, int64, types, optional, deferred_type
from numba.experimental import jitclass
from numba.extending import overload_method, overload
# Forward reference to Bag's own instance type so that a Bag can hold
# another Bag; it is resolved with BagType.define(...) once Bag exists.
BagType = deferred_type()

# numba jitclass member spec: (attribute name, numba type) pairs.
spec = [
    ("max_length", int64),            # capacity of ``data``
    ("data", types.ListType(int64)),  # values stored in this bag
    ("child", optional(BagType)),     # overflow bag, or None
]
@jitclass(spec)
class Bag(object):
    """Fixed-capacity list of int64 values with a single overflow child Bag.

    Values are appended to ``data`` until it holds ``max_length`` items.
    After that they spill into ``child``, which is created on the first
    overflow with the same ``max_length``. Overflow is only one level deep:
    once ``child`` is also full, further values are silently discarded.
    """

    def __init__(self, max_length):
        self.max_length = max_length
        self.data = typed.List.empty_list(int64)
        self.child = None  # created lazily by insert() on the first overflow

    def child_insert(self, value):
        """Append ``value`` to this bag only; drop it if this bag is full."""
        if len(self.data) < self.max_length:
            self.data.append(value)

    def insert(self, value):
        """Store ``value`` in this bag, or in the child bag once this one is full.

        The value that triggers creation of the child is stored in that new
        child. Values arriving when both this bag and its child are full are
        dropped (see ``child_insert``).
        """
        if len(self.data) < self.max_length:
            self.data.append(value)
        elif self.child is not None:
            # Deliberately the non-recursive child_insert rather than
            # self.child.insert, so that insert does not call itself.
            self.child.child_insert(value)
        else:
            # Build the child in a local (concrete Bag type, not optional) and
            # store the triggering value in it before attaching it; previously
            # this value was lost.
            child = Bag(self.max_length)
            child.child_insert(value)
            self.child = child

    def add_child(self):
        """Return a new, empty Bag with the same ``max_length``."""
        return Bag(self.max_length)


# Resolve the forward reference used by ``spec`` for the ``child`` member.
BagType.define(Bag.class_type.instance_type)
@njit
def test():
    """Fill a Bag(5) with the integers 0..9 and print its state before and after."""
    bag = Bag(5)

    # Initial state: empty data list and no child yet.
    print(bag.data)
    print(bag.child)

    for item in range(10):
        bag.insert(item)

    # Final state: this bag's values, its child, and the child's values.
    print(bag.data)
    print(bag.child)
    print(bag.child.data)


test()
Second version (this one compiles much more slowly):
from numba import typed, typeof, njit, int64, types, optional, deferred_type
from numba.experimental import jitclass
from numba.extending import overload_method, overload
# Forward reference to Bag's own instance type so that a Bag can hold
# another Bag; it is resolved with BagType.define(...) once Bag exists.
BagType = deferred_type()

# numba jitclass member spec: (attribute name, numba type) pairs.
spec = [
    ("max_length", int64),            # capacity of ``data``
    ("data", types.ListType(int64)),  # values stored in this bag
    ("child", optional(BagType)),     # overflow bag, or None
]
@jitclass(spec)
class Bag:
    """Fixed-capacity list of int64 values that can link to one child Bag.

    Only the constructor is defined on the class itself; ``insert``,
    ``child_insert`` and ``add_child`` are attached with
    ``overload_method`` in this file.
    """

    def __init__(self, max_length):
        # No overflow child until the first insert into a full bag.
        self.child = None
        self.data = typed.List.empty_list(int64)
        self.max_length = max_length


# Resolve the forward reference used by ``spec`` for the ``child`` member.
BagType.define(Bag.class_type.instance_type)
@overload_method(types.misc.ClassInstanceType, 'child_insert')
def ol_bag_add_child_insert(inst, value):
    """Provide ``Bag.child_insert``: append ``value`` only if the bag has room.

    Returns None for any jitclass instance type other than Bag, signalling
    to numba that this overload does not apply.
    """
    if inst is not Bag.class_type.instance_type:
        return None

    def impl(inst, value):
        if inst.max_length > len(inst.data):
            inst.data.append(value)

    return impl
@overload_method(types.misc.ClassInstanceType, 'insert')
def ol_bag_insert(self, value):
    """Provide ``Bag.insert`` without recursion.

    The value is stored in the first bag along the child chain that has room.
    When the end of the chain is reached, a new child with capacity
    ``max_length - 1`` is created and the value is stored in it. Once that
    capacity would be zero or less, the chain is not extended and the value
    is discarded.

    Returns None for any jitclass instance type other than Bag, signalling
    to numba that this overload does not apply.
    """
    if self is not Bag.class_type.instance_type:
        return None

    def impl(self, value):
        print('insert ' + str(value))
        # Explicit work list instead of recursion. Each iteration pops one
        # (bag, value) item and pushes at most one, so the list never holds
        # more than a single entry.
        pending = typed.List()
        pending.append((self, value))
        while len(pending):
            node, value = pending.pop()
            if len(node.data) < node.max_length:
                node.data.append(value)
            elif node.child is None:
                child_length = node.max_length - 1
                if child_length > 0:
                    # Hand the value to the new child. Previously it was
                    # dropped here.
                    child = Bag(child_length)
                    node.child = child
                    pending.append((child, value))
                # Otherwise the chain cannot grow any further and the value
                # is dropped. Queuing it would loop forever, because each new
                # zero-capacity child is full on arrival.
            else:
                print('child insert')
                pending.append((node.child, value))

    return impl
@overload_method(types.misc.ClassInstanceType, 'add_child')
def ol_bag_add_child(inst):
    """Provide ``Bag.add_child``: return a new Bag one slot smaller than ``inst``.

    Returns None for any jitclass instance type other than Bag, signalling
    to numba that this overload does not apply.
    """
    if inst is not Bag.class_type.instance_type:
        return None

    def impl(inst):
        print('add_child')
        smaller_length = inst.max_length - 1
        return Bag(smaller_length)

    return impl
@njit
def test():
    """Fill a Bag(5) with the integers 0..9 and print its state before and after."""
    bag = Bag(5)

    # Initial state: empty data list and no child yet.
    print(bag.data)
    print(bag.child)

    for item in range(10):
        bag.insert(item)

    # Final state: this bag's values, its child, and the child's values.
    print(bag.data)
    print(bag.child)
    print(bag.child.data)


test()
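Both versions do more or less the same thing; the second one attaches `insert`, `child_insert` and `add_child` with `@overload_method` instead of defining them inside the jitclass. So why is there such a large difference in compile time?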