Rewriting Heap with structref instead of jitclass

import numpy as np
import numba as nb

# Storage mode read by the Heap template below: 'set', 'map', 'eval' or
# 'comp'. TypedHeap() overwrites it before re-executing Heap's source.
mode = 'set'

class Heap:
    """Array-backed binary min-heap used as a source template by TypedHeap.

    Every method branches on the module-level ``mode`` string. TypedHeap
    re-executes this class's source with ``mode`` fixed, so each generated
    class only ever follows one branch (emulating conditional compilation).

    Modes:
      * 'set'  -- items live in ``self.key`` and are ordered by value.
      * 'map'  -- ``self.key[i]`` / ``self.body[i]`` form a key/value pair;
                  ordered by key, and ``pop`` returns the value.
      * 'eval' -- items live in ``self.body`` and are ordered by
                  ``self.eval(item)``, a method supplied by a subclass.
      * 'comp' -- items live in ``self.body``; ``self.comp(a, b)`` (supplied
                  by a subclass) returns a value <= 0 when ``a`` may sit
                  above ``b``.
    """

    def __init__(self, ktype, dtype, cap=16):
        """Allocate room for ``cap`` items.

        ``ktype`` is the dtype of ``self.key`` ('set'/'map' modes) and
        ``dtype`` the dtype of ``self.body`` ('map'/'eval'/'comp' modes).
        """
        self.cap = cap
        if mode=='set':
            self.key = np.zeros(cap, dtype=ktype)
        if mode=='map':
            self.key = np.zeros(cap, dtype=ktype)
        if mode!='set':
            self.body = np.zeros(cap, dtype=dtype)
            # One-slot scratch buffer: pop() saves the removed body here
            # before the sift-down overwrites body[0].
            self.buf = np.zeros(1, dtype=dtype)
        self.size = 0

    def push(self, k, v=None):
        """Insert an item and restore the heap order by sifting it up.

        In 'set'/'map' modes ``k`` is the key (and ``v`` its value in 'map'
        mode); in 'eval'/'comp' modes ``k`` is the item stored in the body.
        Storage is doubled first when full.
        """
        if self.size == self.cap: self.expand()
        i = self.size

        if mode=='eval': body = self.body
        if mode=='comp': body = self.body
        if mode=='set': key = self.key
        if mode=='map': key, body = self.key, self.body

        # The new item's rank does not change while it sifts up.
        if mode=='eval': ek = self.eval(k)

        while i!=0:
            pi = (i-1)//2
            # ``stop`` is True once the parent may stay above the new item.
            # Compare directly instead of testing ``parent - k <= 0``: that
            # subtraction wraps around for unsigned dtypes (and can overflow
            # for signed ones), which silently broke the heap order.
            if mode=='set': stop = key[pi] <= k
            if mode=='map': stop = key[pi] <= k
            if mode=='eval': stop = self.eval(body[pi]) <= ek
            if mode=='comp': stop = self.comp(body[pi], k) <= 0

            if stop: break

            # Move the parent down into the hole.
            if mode=='set': key[i] = key[pi]
            if mode=='map': key[i] = key[pi]
            if mode!='set': body[i] = body[pi]
            i = pi

        if mode=='eval': body[i] = k
        if mode=='comp': body[i] = k
        if mode=='set': key[i] = k
        if mode=='map': key[i], body[i] = k, v

        self.size += 1

    def expand(self):
        """Double the capacity of every storage array."""
        if mode=='set':
            self.key = np.concatenate(
                (self.key, np.zeros(self.cap, self.key.dtype)))
        if mode=='map':
            self.key = np.concatenate(
                (self.key, np.zeros(self.cap, self.key.dtype)))
        if mode!='set':
            # Slots past ``size`` hold no live items, so duplicating the
            # array is enough to double its length.
            self.body = np.concatenate((self.body, self.body))
        self.cap *= 2

    def pop(self):
        """Remove and return the top item.

        Returns the key in 'set' mode and the body (value/item) in the other
        modes; returns None when the heap is empty.
        """
        if self.size == 0: return
        self.size -= 1
        size = self.size

        # Detach the top item; ``last`` (the final item) is sifted down
        # from the root to fill the hole.
        if mode=='set':
            key = self.key
            # Park the top key in the freed slot key[size]; returned below.
            key[0], key[size] = key[size], key[0]
            last = key[0]
        if mode=='map':
            key = self.key
            body = self.body
            self.buf[0] = body[0]
            last = key[size]
        if mode=='eval':
            body = self.body
            self.buf[0] = body[0]
            last = body[size]
            elast = self.eval(last)
        if mode=='comp':
            body = self.body
            self.buf[0] = body[0]
            last = body[size]

        i = 0
        while True:
            ci = 2 * i + 1
            if ci>=size: break

            # Pick the child that should come out first, then stop once
            # ``last`` may sit above it.
            if mode=='set':
                if ci+1<size and key[ci]>=key[ci+1]: ci+=1
                if last <= key[ci]: break
            if mode=='map':
                if ci+1<size and key[ci]>=key[ci+1]: ci+=1
                if last <= key[ci]: break
            if mode=='eval':
                if ci+1<size and self.eval(body[ci])>=self.eval(body[ci+1]): ci += 1
                if elast<=self.eval(body[ci]): break
            if mode=='comp':
                if ci+1<size and self.comp(body[ci], body[ci+1])>=0: ci += 1
                if self.comp(last, body[ci])<=0: break

            if mode=='set': key[i] = key[ci]
            if mode=='map': key[i] = key[ci]
            if mode!='set': body[i] = body[ci]
            i = ci

        if mode=='set': key[i] = last
        # The loop only writes indices below ``size``, so body[size] still
        # holds the value paired with ``last``.
        if mode=='map': key[i], body[i] = last, body[size]
        if mode=='eval': body[i] = last
        if mode=='comp': body[i] = last

        if mode!='set': return self.buf[0]
        return key[size]

    def top(self):
        """Return the top item without removing it (no emptiness check).

        Returns the body outside 'set' mode, otherwise the key.
        """
        if mode!='set':
            return self.body[0]
        return self.key[0]

    def topkey(self):
        """Return the key at the top ('set'/'map' modes; no emptiness check)."""
        return self.key[0]

    def topvalue(self):
        """Return the body at the top ('map'/'eval'/'comp'; no emptiness check)."""
        return self.body[0]

    def clear(self):
        """Forget all items; the allocated storage is kept."""
        self.size = 0

    def __setitem__(self, key, val):
        """``heap[key] = val`` is shorthand for ``heap.push(key, val)``."""
        self.push(key, val)

    def __len__(self):
        """Number of items currently stored."""
        return self.size

def istype(obj):
    """Return True if *obj* is a NumPy dtype, or a type NumPy accepts as one.

    TypedHeap uses this to tell dtype arguments apart from key/compare
    callables.
    """
    if isinstance(obj, np.dtype):
        return True
    if not isinstance(obj, type):
        return False
    converted = np.dtype(obj)
    return isinstance(converted, np.dtype)
                 
def TypedHeap(ktype, vtype=None, jit=True):
    """Create a heap class specialised for one storage mode.

    The mode is chosen from the arguments:
      * ``ktype`` is a dtype and ``vtype`` is None   -> 'set'
      * ``ktype`` and ``vtype`` are both dtypes       -> 'map'
      * ``ktype`` is a callable taking 2 parameters  -> 'eval' (key function,
        called as ``ktype(self, item)``)
      * ``ktype`` is any other callable               -> 'comp' (comparator,
        called as ``ktype(self, a, b)``)
    In the 'eval' and 'comp' modes ``vtype`` is the dtype of the stored items.

    Returns a class whose constructor takes the initial capacity; it is a
    numba jitclass unless ``jit`` is False.

    Raises:
        TypeError: if ``ktype`` is a callable and ``vtype`` is None.
    """
    import inspect
    # The module-level ``mode`` is still updated as before, but the class
    # built below no longer depends on it after this call returns.
    global mode
    if not istype(ktype):
        n = len(inspect.signature(ktype).parameters)
        mode = 'eval' if n==2 else 'comp'
    elif vtype is None: mode = 'set'
    else: mode = 'map'

    if mode in ('eval', 'comp') and vtype is None:
        raise TypeError(
            "TypedHeap: vtype (the item dtype) is required when ktype is a callable")

    # Re-execute the Heap template in a private namespace, so its methods read
    # a copy of ``mode`` that later TypedHeap() calls cannot change. Reading
    # the class back from an explicit dict (instead of exec-ing into
    # ``locals()``) also works on Python 3.13+, where ``locals()`` inside a
    # function returns a fresh snapshot on every call (PEP 667).
    namespace = dict(globals())
    namespace['mode'] = mode
    exec(inspect.getsource(Heap), namespace)
    base = namespace['Heap']

    fields = [('size', nb.uint32), ('cap', nb.uint32)]
    if mode in {'set', 'map'}:
        fields.append(('key', nb.from_dtype(ktype)[:]))
    if mode in {'map', 'eval', 'comp'}:
        fields += [
              ('body', nb.from_dtype(vtype)[:]),
              ('buf', nb.from_dtype(vtype)[:])]

    # eval/comp heaps keep no key array, so the template ignores ktype there.
    init_ktype = None if mode in ('eval', 'comp') else ktype

    class TypedHeap(base):
        # Take __init__ from the re-executed template: the module-level
        # Heap.__init__ reads the global ``mode`` at call (or numba compile)
        # time, which a later TypedHeap() call may already have changed.
        _init_ = base.__init__
        if mode=='eval': eval = ktype
        if mode=='comp': comp = ktype

        def __init__(self, cap):
            self._init_(init_ktype, vtype, cap)

    if not jit: return TypedHeap
    return nb.experimental.jitclass(fields)(TypedHeap)

def print_heap(arr):
    """Print the array-backed binary tree *arr* sideways.

    The root is printed at column 0, each level is indented by four spaces,
    and the right subtree appears above its parent, the left subtree below.
    """
    total = len(arr)
    pending = []          # (index, depth) pairs waiting to be printed
    node, depth = 0, 0
    while pending or node < total:
        # Walk down the right spine first, remembering each node on the way.
        while node < total:
            pending.append((node, depth))
            node, depth = 2 * node + 2, depth + 1
        node, depth = pending.pop()
        print('    ' * depth + str(arr[node]))
        # Then continue with the left child of the node just printed.
        node, depth = 2 * node + 1, depth + 1
    
if __name__ == '__main__':
    # Build one heap per mode and push a single item into each:
    #   set  -- keys only
    #   map  -- key -> value
    #   eval -- items ordered by a key function f(self, x)
    #   comp -- items ordered by a comparator f(self, a, b)
    demos = (
        ((np.int32,), (1,)),
        ((np.int32, np.int32), (1, 10)),
        ((lambda self, x: x, np.int32), (1,)),
        ((lambda self, x1, x2: x1 - x2, np.int32), (1,)),
    )
    for factory_args, push_args in demos:
        IntHeap = TypedHeap(*factory_args)
        ints = IntHeap(16)
        ints.push(*push_args)

https://github.com/VectorElectron/structron
This is a set of STL-style containers I implemented with Numba whose performance can rival C++. The basic implementation is already complete. However, jitclass does not support caching compiled code, and after extensive research I concluded that refactoring with structref might be a solution.

The challenge is that my jitclass implementation uses inheritance to support different data types and uses global variables to emulate conditional compilation, enabling four modes: set, map, eval, and comp. Since I am also unfamiliar with structref, I decided to start with a relatively simple container—the heap—as an example. I’m posting it here in the hope that anyone interested can help with the refactoring, and I’ll share my subsequent refactoring results in this thread.

1 Like

Hi @nelson2005, I believe you’re an expert in this area!

Here’s enough to get you started. As an aside, I’m not sure whether your canonical implementation is meant to keep the elements in sorted order; the example provided does not (a binary min-heap only guarantees that each parent is no larger than its children).

Special shout-out to @milton for help with the details, and to @DannyWeitekamp as the OG of structrefs.

Comments and questions are welcome. Note: put these pieces in their own module and call them from a separate driver script. Object creation and method calls work in both jitted and plain-Python contexts.

import numba as nb
import numpy as np
from numba.core.extending import overload_method
from numba.core.types import StructRef
from numba.experimental import structref

@structref.register
class SetHeapTypeTemplate(StructRef):
    """Numba-side struct type behind the SetHeap proxy.

    Its fields (``size``, ``key``) are declared by the
    ``structref.define_constructor`` call for SetHeap; ``push`` is attached
    via ``overload_method``.
    """
    pass

class SetHeap(structref.StructRefProxy):
    """Python-side proxy for a SetHeapTypeTemplate struct (a min-heap of keys).

    Fields: ``size`` -- number of stored keys; ``key`` -- the backing array,
    which ``push`` grows when it is full.
    """
    def __new__(cls, size, key):
        # Delegates to the constructor registered by define_constructor,
        # which takes the field values (size, key) in that order.
        return structref.StructRefProxy.__new__(cls, size, key)

    # Runs as a jitted function: the proxy is unboxed to the struct type, so
    # ``self.push(k)`` below dispatches to the ``push`` overload registered
    # for SetHeapTypeTemplate, not back to this Python method.
    @nb.njit(cache=True)
    def push(self, k):
        self.push(k)

def push_impl(self, k):
    """Insert ``k`` into the min-heap stored in ``self.key[:self.size]``.

    ``self`` is any object exposing a mutable ``size`` count and a 1-D
    ``key`` array (the SetHeap struct when reached through ``push``). The
    array is doubled in length when full (grown to length 1 if empty).
    """
    if self.size == len(self.key):
        # Slots past ``size`` are never read, so they can stay uninitialised.
        # max(..., 1) lets an empty backing array grow as well.
        grow = max(len(self.key), 1)
        self.key = np.concatenate((self.key, np.empty(grow, self.key.dtype)))
    key = self.key
    i = self.size

    # Sift up: move parents down until one that may stay above ``k`` is found.
    while i != 0:
        pi = (i - 1) // 2
        # Compare directly: ``key[pi] - k <= 0`` wraps around for unsigned
        # dtypes (and can overflow for signed ones), breaking the heap order.
        if key[pi] <= k: break

        key[i] = key[pi]
        i = pi

    key[i] = k
    self.size += 1

@overload_method(SetHeapTypeTemplate, 'push', inline='always')
def ol_push(self, k):
    # Typing-time hook: for any argument types, the jitted implementation of
    # SetHeapTypeTemplate.push is push_impl (inline='always' asks numba to
    # inline it at each call site).
    return push_impl

# Register the SetHeap constructor, taking the fields (size, key) in this
# order, and the boxing that turns SetHeapTypeTemplate values returned from
# jitted code back into SetHeap proxies in Python.
structref.define_constructor(SetHeap, SetHeapTypeTemplate, ['size', 'key'])
structref.define_boxing(SetHeapTypeTemplate, SetHeap)

@nb.njit
def print_arr(heap):
    """Print the heap's backing ``key`` array from jitted code.

    The whole array is printed, including any unused slots past ``heap.size``.
    """
    print(heap.key)
1 Like