Generic classes in Numba

Hi everyone!
Is there a way to create a generic class in Numba? I created a custom NumPy array-list container class, but I have to specify the exact element type it will be used with (`Transaction`) in the `@jitclass` spec, and I can only specify one. I'd like to be able to have multiple instances of the class with different types.

Here’s the code:

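```python
import numpy as np
import numba as nb
from numba.experimental import jitclass  # Numba < 0.49: from numba import jitclass

# Structured dtype: one record per transaction
Transaction = np.dtype([('type', np.int8), ('amount', np.float32), ('price', np.float32), ('time', np.float64)])


@jitclass([('capacity', nb.types.int32), ('length', nb.types.int32), ('array', nb.from_dtype(Transaction)[:])])
class NumpyArrayList():
    def __init__(self, capacity):
        self.capacity = capacity
        self.length = 0
        self.array = np.empty(capacity, dtype=Transaction)

    def append(self, element):
        # Double the backing array when it is full
        # (max(1, ...) lets a zero initial capacity grow as well)
        if self.length >= self.capacity:
            new_size = max(1, self.capacity * 2)

            new_array = np.empty(new_size, dtype=self.array.dtype)
            new_array[:self.length] = self.array

            self.capacity = new_size
            self.array = new_array

        # element holds (type, amount, price, time) in that order; copy it field by field
        self.array[self.length].type = element[0]
        self.array[self.length].amount = element[1]
        self.array[self.length].price = element[2]
        self.array[self.length].time = element[3]

        self.length += 1

    def get_np_array(self):
        # Return only the filled part of the backing array
        return self.array[:self.length]
```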

In case the Gitter history is lost, here are the replies from there:

uchytilc @uchytilc Feb 22 19:40
You should be able to put the desired Python class into a factory function, like so:

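```python
# NumpyArrayList must be the plain Python class from the question, without the
# @jitclass decorator; np, nb, jitclass and Transaction are as defined there.
# Note that __init__ and append still hard-code the Transaction dtype and fields.
def factory(capacity=nb.types.int32,
            length=nb.types.int32,
            array=nb.from_dtype(Transaction)):
    # `array` is the element type; `array[:]` turns it into a 1-D array type
    return jitclass([('capacity', capacity), ('length', length), ('array', array[:])])(NumpyArrayList)
```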

luk-f-a @luk-f-a Feb 22 21:08

Add `functools.lru_cache` to the factory to save compilation time when the same types are requested repeatedly.

You can probably get fancy and overload `__getitem__` (on a metaclass) or `__class_getitem__`, so that you can write `Factory[int32]()` if you like that syntax.
