Hi,
I am working on stateful recursive functions, e.g. exponential moving average or first order IIR filter. I’d like to build them with guvectorize such that I can also use them with dask and get all the broadcasting semantics. My problem is that all of these functions have two outputs, the signal and the last state, such that one can pass the last state to the next function call. Yet, it seems the state is never written to the output for my numba implementation. Here is a minimal example which compares against a pure numpy implementation.
def ema_numpy(x, state, alpha):
y = np.empty_like(x)
N = x.shape[-1]
for i in range(0, N):
y[..., i] = state + (x[..., i] - state) * alpha
state = y[..., i]
final_state = state
return y, final_state
@guvectorize(
[
"void(f8[:], f8, f8, f8[:], f8)",
"void(f4[:], f4, f4, f4[:], f4)",
],
"(k), (), () -> (k), ()",
nopython=True,
target="cpu",
)
def ema_numba(x, state, alpha, y, final_state):
N = x.shape[-1]
for i in range(0, N):
y[..., i] = state + (x[..., i] - state) * alpha
state = y[..., i]
final_state = state # this line seems not to do anything
init_state = 0.0
alpha = 0.01
np.random.seed(1337)
x = np.random.normal(size=(12,))
x1 = x[:6]
x2 = x[6:]
rmse = lambda x,y: np.sqrt(np.mean(np.abs(x-y)**2))
# compare running stateless
np_out, _ = ema_numpy(x, init_state, alpha)
numba_out, _ = ema_numba(x, init_state, alpha)
print("rmse stateless:", rmse(np_out, numba_out.squeeze()))
np.testing.assert_almost_equal(np_out, numba_out.squeeze())
# run numpy stateful implementation
np_out1, state = ema_numpy(x1, init_state, alpha)
np_out2, _ = ema_numpy(x2, state, alpha)
print("init_state", init_state, "state", state)
print("rmse numpy stateful:", rmse(np_out, np.concatenate((np_out1, np_out2))))
np.testing.assert_almost_equal(np_out, np.concatenate((np_out1, np_out2)))
# run numba stateful implementation
numba_out1, state = ema_numba(x1, init_state, alpha)
numba_out2, _ = ema_numba(x2, state, alpha)
print("init_state", init_state, "state", state)
print("rmse numba stateful:", rmse(np_out, np.concatenate((numba_out1, numba_out2))))
np.testing.assert_almost_equal(np_out, np.concatenate((numba_out1, numba_out2)))
Output:
stateless
rmse: 0.0
numpy stateful
init_state 0.0 state -0.04978762936563727
rmse: 0.0
numba stateful
init_state 0.0 state 0.0
rmse: 0.03399832826443144
As one can see, for the numba implementation the init_state == state, which is strange as the state should change.
Thanks!