Using a stencil to reduce 3D to 2D array

I’m trying to reduce a 3D array to a 2D array. The output value at relative index (0, 0) of the 2D array is calculated from the differences between the centre element and every other kernel element (i, j), i.e. arr[h, 0, 0] - arr[h, i, j], summed over the h elements and assigned to out_arr[0, 0].
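
Written out, the quantity that the loop body of test_jit in this post computes appears to be the following. The symbols s_0 and s_1 stand for shape[0] and shape[1], and Δi, Δj for the window offsets; this notation is introduced here for illustration only and is not part of the original code.

```latex
\mathrm{out}[i, j] \;=\;
\sum_{\substack{-s_0 \le \Delta i \le s_0 \\ -s_1 \le \Delta j \le s_1 \\ (\Delta i, \Delta j) \ne (0, 0)}}
\exp\!\Big( -\sum_{h} \big|\, \mathrm{arr}[h, i, j] - \mathrm{arr}[h, i + \Delta i, j + \Delta j] \,\big| \Big)
```

So the sum over h happens inside the exponential, and the results are then summed over all neighbours in the window.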

This is simple in a jitted function, but as a learning exercise I am also trying to implement it as a stencil.

Simplified jitted version:

import numpy as np
from numba import njit, stencil

@njit()
def test_jit(arr, shape=(3, 3)):
    # shape holds the half-widths of the window along y and x
    z, y, x = arr.shape
    out_arr = np.zeros((y, x))

    # skip the border so the whole window stays inside the array
    for i in range(shape[0], y-shape[0]):
        for j in range(shape[1], x-shape[1]):
            ij = 0
            for i_off in range(-shape[0], shape[0]+1):
                for j_off in range(-shape[1], shape[1]+1):
                    if (i_off, j_off) == (0, 0):  # don't calc arr[i,j] - arr[i,j]
                        continue
                    ij += np.exp(-np.sum(np.abs(arr[:, i, j] - arr[:, i+i_off, j+j_off])))
            out_arr[i, j] = ij
    return out_arr

in_arr = np.random.randint(0,1000,size=(5, 256, 256))/1000
outarr = test_jit(in_arr)
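
For cross-checking, here is a plain-NumPy sketch of what I understand the jitted function to compute. It assumes that shape gives the half-widths of the window, as in test_jit. The name reduce_reference and the parameter half are made up for this example. It replaces the two inner pixel loops with shifted slices and is not meant as a Numba or stencil solution.

```python
import numpy as np

def reduce_reference(arr, half=(3, 3)):
    """Sum over window neighbours of exp(-sum_h |centre - neighbour|)."""
    z, y, x = arr.shape
    hy, hx = half  # hypothetical name for the half-widths (shape in test_jit)
    out = np.zeros((y, x))

    # Interior block for which the whole window fits, same bounds as test_jit.
    centre = arr[:, hy:y - hy, hx:x - hx]

    for di in range(-hy, hy + 1):
        for dj in range(-hx, hx + 1):
            if di == 0 and dj == 0:
                continue  # skip centre minus centre
            # Same interior block, shifted by (di, dj).
            shifted = arr[:, hy + di:y - hy + di, hx + dj:x - hx + dj]
            # Collapse the h axis first, then accumulate over neighbours.
            out[hy:y - hy, hx:x - hx] += np.exp(-np.abs(centre - shifted).sum(axis=0))
    return out

in_arr = np.random.randint(0, 1000, size=(5, 64, 64)) / 1000
ref = reduce_reference(in_arr)
print(ref.shape)  # (64, 64)
```

The border rows and columns stay at zero, matching the loop bounds of the jitted version.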

However, the stencil version fails; I think it expects a 3D output:

import numpy as np
from numba import njit, stencil

def test_stencil(arr, shape=(3, 3)):
    z, y, x = arr.shape 
    out_arr = np.zeros((y, x))  

    @stencil(
        neighborhood=((-z, z), (-y, y), (-x, x))
    )
    def _test(arr):
        val = 0

        for i in range(-y, y + 1):
            for j in range(-x, x + 1):
                if (i, j) == (0, 0):  # don't calc arr[i,j] - arr[i,j]
                    continue

                ij = 0
                for h in range(-z, z + 1):
                    ij += np.abs(arr[h, 0, 0] - arr[h, i, j])

                val += np.exp(-ij)

        return val
    
    _test(arr, out=out_arr)
    return out_arr


in_arr = np.random.randint(0,1000,size=(5, 256, 256))/1000
outarr = test_stencil(in_arr)
print(outarr)

Exception:

No implementation of function Function(<built-in function setitem>) found for signature:

 >>> setitem(array(float64, 2d, C), UniTuple(int64 x 3), float64)

There are 16 candidate implementations:
    - Of which 14 did not match due to:
    Overload of function 'setitem': File: <numerous>: Line N/A.
      With argument(s): '(array(float64, 2d, C), UniTuple(int64 x 3), float64)':
     No match.
    - Of which 2 did not match due to:
    Overload in function 'SetItemBuffer.generic': File: numba\core\typing\arraydecl.py: Line 176.
      With argument(s): '(array(float64, 2d, C), UniTuple(int64 x 3), float64)':
     Rejected as the implementation raised a specific error:
       NumbaTypeError: cannot index array(float64, 2d, C) with 3 indices: UniTuple(int64 x 3)
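
The signature in the error, setitem(array(float64, 2d, C), UniTuple(int64 x 3), float64), shows a 2D array being assigned through a 3-element index. This is consistent with the stencil producing one output element per element of the 3D input. A minimal plain-NumPy analogue of that mismatch is sketched below. It is only an illustration with hypothetical names, not what Numba does internally.

```python
import numpy as np

out_2d = np.zeros((4, 4))   # 2D output, like out_arr
index_3d = (0, 1, 2)        # hypothetical 3-element index, like UniTuple(int64 x 3)

try:
    out_2d[index_3d] = 1.0  # three indices into a two-dimensional array
    failed = False
except IndexError as exc:
    failed = True
    print(exc)
```

NumPy rejects the assignment at run time, while Numba rejects the equivalent operation while typing the generated stencil code.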

How would I implement this 3D -> 2D reduction in a Numba stencil?