Pass list of DIFFERENT jitclasses or jitclass methods to a jitted function

I am trying to make a jitted function that takes a list of instances of different jitclasses, or a list of methods of these jitclasses, so that it can call those methods in a chain. I found the threads below, but they all fall short:

  • This workaround of chaining functions does not work on jitclass methods:
    [using a list/tuple of jitted functions · Issue #2542 · numba/numba]

  • This one works for a list of jit functions with the same signature, but not jitclass methods.
    [Typed list of jitted functions in jitclass - Numba / Support: How do I do …? - Numba Discussion]

  • This only works for a list of the instances of the same jitclass, not different classes.
    [How do I create a jitclass that takes a list of jitclass objects? - Numba / Support: How do I do …? - Numba Discussion]

So is there no way to do this?

I suspect the answer is no, so let me explain what I am trying to do in case there is an entirely different solution.

I am trying to build a library that can simulate a chain of digital filters. Each filter is basically a system that has internal state, and at each simulation step, it takes in a number, does math on it using the internal state, and outputs a number (and updates the internal state).
For each filter type, say FilterA and FilterB, I defined a jitclass with a self.states attribute and a step() method that receives and returns a float. The user of this library should be able to build a chain of filters, e.g. [FilterA, FilterB, FilterA], and then call a jitted function that iterates through an array of input samples, passing each sample through the chain of filters:

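@jit
def simulate(in_samples, filter_chain):
  out_samples = []
  for sample in in_samples:
    output = sample
    # filter_chain mixes different jitclass types: this is the part Numba cannot type
    for filt in filter_chain:
      output = filt.step(output)
    out_samples.append(output)
  return out_samples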

So you can see that I need to pass to simulate() either the filter_chain (a list of different jitclass instances), or the FilterX.step() methods (a list of methods of different jitclass instances).
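For reference, the intended behavior is straightforward in plain, non-jitted Python, because `filt.step` is looked up at run time on whatever object `filt` happens to be. The sketch below assumes two hypothetical filters, `SmootherFilter` (a one-pole smoother) and `FIRFilter` (a small FIR filter), whose state arrays have different lengths; their names and maths are illustrative only, not the poster's actual FilterA/FilterB. This run-time dispatch over mixed types is exactly what Numba cannot compile when the list mixes jitclass types.

```python
import numpy as np


class SmootherFilter:
    """Hypothetical one-pole smoother: y[n] = a*x[n] + (1 - a)*y[n-1]."""

    def __init__(self, a):
        self.a = a
        self.states = np.zeros(1)  # previous output

    def step(self, x):
        y = self.a * x + (1.0 - self.a) * self.states[0]
        self.states[0] = y
        return y


class FIRFilter:
    """Hypothetical FIR filter: y[n] = sum_k coeff[k] * x[n-k]."""

    def __init__(self, coeff):
        self.coeff = np.asarray(coeff, dtype=np.float64)
        self.states = np.zeros(len(self.coeff))  # delay line of past inputs

    def step(self, x):
        # Shift the delay line by one sample and insert the newest input.
        self.states[1:] = self.states[:-1]
        self.states[0] = x
        return float(np.dot(self.coeff, self.states))


def simulate(in_samples, filter_chain):
    out_samples = []
    for sample in in_samples:
        output = sample
        for filt in filter_chain:
            output = filt.step(output)  # resolved at run time (duck typing)
        out_samples.append(output)
    return out_samples


chain = [SmootherFilter(0.5), FIRFilter([1.0, 1.0]), SmootherFilter(0.5)]
print(simulate([1.0, 1.0, 1.0], chain))
```

Each filter object carries its own state, so calling `simulate` again continues from where the previous call stopped.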

I tried to avoid jitclass by using generators or closures, but ran into Numba limitations there as well (no send() support for generators, and no way to assign to a nonlocal variable inside a closure).

Is there another way to do this that I am not seeing?

This is definitely a case where dynamic behavior in Python can’t be replicated in exactly the same way in a compiled implementation. In C++ you could do this with polymorphic types: ‘step’ would be a virtual method whose implementation differs between classes. To my knowledge Numba doesn’t have this, so we’ll need to replicate the basic idea ourselves. The key insight is that a method is just a function whose first argument is “self”. So, for instance:

If .state is the same kind of thing in each class, say a numpy array, then you could write multiple step() functions that all have the same signature, and keep a list of tuples, each holding a state object and a step function cast to a first-class function (which I think is explained in one of the threads above). In this case you could forget about jitclasses entirely, which I would advocate for: they become a real pain in larger projects because they can’t be cached, so Numba can’t reuse compiled code from previous runs to cut compile times.
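A minimal plain-Python illustration of that idea, with no Numba involved: in Python itself `obj.step(x)` is just `type(obj).step(obj, x)`, so each "method" can become a free function that takes its state explicitly. Every filter then shares the signature `step(state, x)`, and the chain becomes a list of `(state, step_function)` pairs. The two filters (`accumulate_step`, `delay_step`) are hypothetical. In a Numba version the step functions would be `@njit` functions and the list a typed List whose item type is a tuple of a float64 array and a first-class function type.

```python
import numpy as np


def accumulate_step(state, x):
    """Hypothetical filter: running sum of its inputs (state[0] = sum so far)."""
    state[0] += x
    return state[0]


def delay_step(state, x):
    """Hypothetical filter: outputs the input from len(state) steps ago."""
    y = state[-1]  # scalar copy, unaffected by the shift below
    state[1:] = state[:-1]
    state[0] = x
    return y


def simulate(in_samples, filter_chain):
    """filter_chain is a list of (state, step_function) pairs."""
    out = np.empty(len(in_samples))
    for i, sample in enumerate(in_samples):
        y = sample
        for state, step in filter_chain:
            y = step(state, y)  # same call signature for every filter
        out[i] = y
    return out


chain = [(np.zeros(1), accumulate_step), (np.zeros(2), delay_step)]
print(simulate([1.0, 2.0, 3.0, 4.0], chain))
```

The state arrays may have different lengths; only their type (1D float64) and the step signature must match.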

If states really are different kinds of things between different classes then you’re getting into a situation where you need to upcast and downcast classes or cast them from pointers. I could share special intrinsic functions that could help you go in that direction, but it’s very non-standard and opens up unsafe possibilities where any mistake in your code will give you a nice hard-to-debug segfault.

Hi Danny,

I was hoping you would see this :slight_smile:

Can you please show an example of the 1st option?

The state can always be a 1D numpy array, but its length varies. There are also other parameters that I currently pass to __init__ and that are unique to each jitclass.

For example,
a = FilterA(coeff=(1, 2, 3), order=4)
would initialize a.states to an array of length order. The coeff values are used inside step() along with states, but are not modified.

b = FilterB(x=((1, 3), (3, 5)), y=(3, 3, 1), order=4)
could initialize states to an array of length order*2.
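One way to keep every filter's types identical despite different constructor arguments is to flatten the constants into a read-only float64 `params` array that sits alongside the mutable `state` array. The sketch below is plain NumPy; `make_filter_a`, `make_filter_b` and `unpack_filter_b` are hypothetical helpers mirroring the FilterA/FilterB arguments above, and storing the 2D shape of `x` at the front of `params` is just one possible layout convention.

```python
import numpy as np


def make_filter_a(coeff, order):
    # Hypothetical: constants (coeff) go into params, mutable memory into state.
    params = np.asarray(coeff, dtype=np.float64).ravel()
    state = np.zeros(order, dtype=np.float64)
    return params, state


def make_filter_b(x, y, order):
    x_arr = np.asarray(x, dtype=np.float64)
    # Layout (an assumption): [rows, cols, x flattened row-major..., y...]
    params = np.concatenate((
        np.array(x_arr.shape, dtype=np.float64),
        x_arr.ravel(),
        np.asarray(y, dtype=np.float64).ravel(),
    ))
    state = np.zeros(order * 2, dtype=np.float64)
    return params, state


def unpack_filter_b(params):
    # Recover the 2D x and 1D y from the flat params layout above.
    rows, cols = int(params[0]), int(params[1])
    x = params[2:2 + rows * cols].reshape(rows, cols)
    y = params[2 + rows * cols:]
    return x, y


pb, sb = make_filter_b(((1, 3), (3, 5)), (3, 3, 1), 4)
print(pb, sb.shape)
print(unpack_filter_b(pb))
```

Every filter is then described by two 1D float64 arrays, at the cost of each step function having to know its own layout.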

Maybe something like this:

import numpy as np
from numba import njit
from numba import types
from numba.types import i8, f8
from numba.typed import List


# --------------
# : FilterA
@njit(cache=True)
def FilterA_step(state, sample):
    state += sample[0]
    return sample*state[0]

@njit(cache=True)
def FilterA(a, order):
    state = np.empty(order, dtype=np.float64)
    state[:order] = a
    return (state, FilterA_step)

# --------------
# : FilterB
@njit(cache=True)
def FilterB_step(state, sample):
    state *= sample[0]
    return sample+state[0]

@njit(cache=True)
def FilterB(a, b, order):
    state = np.empty(order * 2, dtype=np.float64)  # room for both a and b
    state[:order] = a
    state[order:order*2] = b
    return (state, FilterB_step)


# --------------
# : simulate
@njit(cache=True)
def simulate(in_samples, filter_chain):
    out_samples = np.empty_like(in_samples)

    for i, sample in enumerate(in_samples):
        output = sample
        for state, step_func in filter_chain:
            output = step_func(state, output)
        out_samples[i, :] = output
    return out_samples

filter_step_sig = f8[:](f8[:], f8[:])
filter_step_t = types.FunctionType(filter_step_sig)

# Make filters
filters = List.empty_list(types.Tuple((f8[:], filter_step_t)))
filters.append(FilterA(1, 4))
filters.append(FilterB(1, 2, 4))


# Make Data
data = np.random.uniform(0.0, 1.0, size=(10, 4))
print(data)

# Apply Simulate
filtered_data = simulate(data, filters)
print(filtered_data)

The line List.empty_list(types.Tuple((f8[:], filter_step_t))) forces the list to hold tuples of a float64 numpy array and a generic first-class function. Without it, the second append() will throw an error: the first append() makes the list infer its item type from a literal function pointer (a pointer to one particular function, not a generic function with a particular signature). Unfortunately, I’m not sure how to make the syntax nicer than this; I don’t know of a simple way to allow writing it as [FilterA(…), FilterB(…)] without adding lots of non-standard hacks.

Thank you for writing that up! I see what you mean now.
I think this could get the job done, but I do get this warning when running the example:
NumbaWarning: Cannot cache compiled function “FilterA” as it uses dynamic globals (such as ctypes pointers and large global arrays)
Do you get the same? It would be nice to get the cache working here if I am giving up on jitclass.

It looks like there has been a regression in functionality between the version I was using (0.58.1) and the newest version (0.60.0): function references no longer seem to be cache-safe. It’s a little less pretty, but you could do this instead:

...
@njit(cache=True)
def FilterA(a, order):
    state = np.empty(order, dtype=np.float64)
    state[:order] = a
    return state
...
@njit(cache=True)
def FilterB(a, b, order):
    state = np.empty(order * 2, dtype=np.float64)  # room for both a and b
    state[:order] = a
    state[order:order*2] = b
    return state
...

filters = List.empty_list(types.Tuple((f8[:], filter_step_t)))
filters.append((FilterA(1, 4), FilterA_step))
filters.append((FilterB(1, 2, 4), FilterB_step))

It might be worth mentioning in an issue on GitHub if you have the time. I haven’t been keeping close tabs on the change logs lately so I couldn’t tell you if this was a purposeful change in behavior or a regression (a dev mistake).

Sure, I will file a bug on this.
This solution ended up a bit more clunky than I could stomach in the end, mostly because the ‘step’ function for each Filter type needs access not just to ‘state’ but also to run-time constants like a, b and order in the example above. You end up needing a heterogeneous container, something like a Python dict that can hold scalars, 1D and 2D arrays, which Numba does not support. Putting everything in, say, a 2D array (i.e. scalars becoming a 1x1 matrix) would work, I guess, but it makes the code pretty ugly.
For my problem, I resorted to a different approach that is less performant. But thanks for the help. Hopefully this thread is useful to someone.
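One alternative that was not tried in this thread is to dispatch on an integer tag instead of passing function pointers. Each filter is described by a tag, a read-only params array and a mutable state array; `simulate` picks the step function with an if/else. Because the params and states are passed as tuples of 1D float64 arrays (a homogeneous tuple, which Numba can index with a run-time index), no first-class function objects cross the jit boundary. I would expect this to avoid the dynamic-globals caching warning, but I have not verified that, and `cache=True` is left out so the sketch also runs from an interactive session. All names (`KIND_SMOOTHER`, `smoother_step`, `fir_step`) are hypothetical, and the small fallback decorator only lets the sketch run as plain Python if Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # Numba not installed: run the same code as plain Python
    def njit(*args, **kwargs):
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]
        return lambda func: func

# Hypothetical integer tags, one per filter type.
KIND_SMOOTHER = 0
KIND_FIR = 1


@njit
def smoother_step(params, state, x):
    # params[0] = smoothing factor a; state[0] = previous output
    y = params[0] * x + (1.0 - params[0]) * state[0]
    state[0] = y
    return y


@njit
def fir_step(params, state, x):
    # params = FIR coefficients; state = delay line of past inputs
    for k in range(state.shape[0] - 1, 0, -1):
        state[k] = state[k - 1]
    state[0] = x
    acc = 0.0
    for k in range(params.shape[0]):
        acc += params[k] * state[k]
    return acc


@njit
def simulate(in_samples, kinds, params, states):
    # kinds: int64 array of tags; params/states: tuples of 1D float64 arrays
    out = np.empty_like(in_samples)
    for i in range(in_samples.shape[0]):
        y = in_samples[i]
        for j in range(kinds.shape[0]):
            if kinds[j] == KIND_SMOOTHER:
                y = smoother_step(params[j], states[j], y)
            else:
                y = fir_step(params[j], states[j], y)
        out[i] = y
    return out


kinds = np.array([KIND_SMOOTHER, KIND_FIR, KIND_SMOOTHER], dtype=np.int64)
params = (np.array([0.5]), np.array([1.0, 1.0]), np.array([0.5]))
states = (np.zeros(1), np.zeros(2), np.zeros(1))
print(simulate(np.ones(3), kinds, params, states))
```

The trade-off is that `simulate` has to be edited whenever a new filter type is added, and constants that are not 1D still need to be flattened into `params`.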