Generated_jit with default arguments

I’m trying to create a flexible function with default arguments as below.

import numpy as np
from numba import njit, generated_jit, types


@njit(cache=True)
def interp_decay(x0, x_list, y_list, intercept_limit=0.0, slope_limit=0.0, lower_extrap=False):
    # Make a decay extrapolation
    slope_at_top = (y_list[-1] - y_list[-2]) / (x_list[-1] - x_list[-2])
    level_diff = intercept_limit + slope_limit * x_list[-1] - y_list[-1]
    slope_diff = slope_limit - slope_at_top

    # Parameters of the decay term A * exp(-B * (x - x_list[-1]))
    decay_extrap_A = level_diff
    decay_extrap_B = -slope_diff / level_diff

    i = np.maximum(np.searchsorted(x_list[:-1], x0), 1)
    alpha = (x0 - x_list[i - 1]) / (x_list[i] - x_list[i - 1])
    y0 = (1.0 - alpha) * y_list[i - 1] + alpha * y_list[i]

    if not lower_extrap:
        below_lower_bound = x0 < x_list[0]
        y0[below_lower_bound] = np.nan

    above_upper_bound = x0 > x_list[-1]
    x_temp = x0[above_upper_bound] - x_list[-1]

    y0[above_upper_bound] = (
            intercept_limit
            + slope_limit * x0[above_upper_bound]
            - decay_extrap_A * np.exp(-decay_extrap_B * x_temp)
    )

    return y0


@njit(cache=True)
def interp_linear(x0, x_list, y_list, intercept_limit=None, slope_limit=None, lower_extrap=False):
    i = np.maximum(np.searchsorted(x_list[:-1], x0), 1)
    alpha = (x0 - x_list[i - 1]) / (x_list[i] - x_list[i - 1])
    y0 = (1.0 - alpha) * y_list[i - 1] + alpha * y_list[i]

    if not lower_extrap:
        below_lower_bound = x0 < x_list[0]
        y0[below_lower_bound] = np.nan

    return y0


@generated_jit(nopython=True, cache=True)
def LinearInterpFast(x0, x_list, y_list, intercept_limit=None, slope_limit=None, lower_extrap=False):
    if isinstance(intercept_limit, types.NoneType) and isinstance(slope_limit, types.NoneType):
        return interp_linear
    else:
        return interp_decay

The above raises the following error:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Internal error at <numba.core.typeinfer.CallConstraint object at 0x000001EF89078940>.
generated implementation CPUDispatcher(<function interp_decay at 0x000001EF88ADF9D0>) should be compatible with signature '(x0, x_list, y_list, intercept_limit=None, slope_limit=None, lower_extrap=False)', but has signature '(x0, x_list, y_list, intercept_limit=0.0, slope_limit=0.0, lower_extrap=False)'
During: resolving callee type: type(CPUDispatcher(<function LinearInterpFast at 0x000001EF88ADFF70>))
During: typing of call at [*elided*] (189)

Enable logging at debug level for details.

File "[*elided*]", line 189:
def _solveConsIndShockLinearNumba(
    <source elided>
    )
    cFuncNextUnc = LinearInterpFast(mNrmNext.flatten(), mNrmUnc, cNrmUnc)
    ^

I would like to be able to call LinearInterpFast as:
LinearInterpFast(x0,x,y) -> basic linear interpolation
LinearInterpFast(x0,x,y,intercept,slope) -> linear interpolation with exponential decay extrapolation

Any suggestions as to what I’m doing wrong?


Hi, have you tried using a normal njit instead of generated_jit, with an if statement to distinguish the cases as you would in plain Python?

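@njit
def LinearInterpFast(x0, x_list, y_list, intercept_limit=None, slope_limit=None, lower_extrap=False):
    # When both limits are None, only the linear branch needs to be compiled
    if intercept_limit is None and slope_limit is None:
        return interp_linear(x0, x_list, y_list, None, None, lower_extrap)
    else:
        return interp_decay(x0, x_list, y_list, intercept_limit, slope_limit, lower_extrap)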
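
For reference, here is a minimal self-contained sketch of that pattern. The error message in the question says the implementation returned by generated_jit must have the same signature as the wrapper, including default values; with a plain njit wrapper that constraint goes away, because the wrapper just calls the helpers. The trimmed helper signatures, the sample arrays, and the pure-Python fallback for njit are my own assumptions for illustration, not code from the original post.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch still runs without Numba.
    def njit(func):
        return func


@njit
def interp_linear(x0, x_list, y_list, lower_extrap):
    # Piecewise-linear interpolation; the last segment is extended above x_list[-1].
    i = np.maximum(np.searchsorted(x_list[:-1], x0), 1)
    alpha = (x0 - x_list[i - 1]) / (x_list[i] - x_list[i - 1])
    y0 = (1.0 - alpha) * y_list[i - 1] + alpha * y_list[i]
    if not lower_extrap:
        y0[x0 < x_list[0]] = np.nan
    return y0


@njit
def interp_decay(x0, x_list, y_list, intercept_limit, slope_limit, lower_extrap):
    # Same interpolation inside the grid, exponential decay towards the
    # limiting line intercept_limit + slope_limit * x above the grid.
    slope_at_top = (y_list[-1] - y_list[-2]) / (x_list[-1] - x_list[-2])
    level_diff = intercept_limit + slope_limit * x_list[-1] - y_list[-1]
    slope_diff = slope_limit - slope_at_top
    decay_A = level_diff
    decay_B = -slope_diff / level_diff

    y0 = interp_linear(x0, x_list, y_list, lower_extrap)
    above = x0 > x_list[-1]
    x_temp = x0[above] - x_list[-1]
    y0[above] = (
        intercept_limit
        + slope_limit * x0[above]
        - decay_A * np.exp(-decay_B * x_temp)
    )
    return y0


@njit
def LinearInterpFast(x0, x_list, y_list, intercept_limit=None, slope_limit=None, lower_extrap=False):
    # Plain `is None` checks on the arguments select the implementation.
    if intercept_limit is None and slope_limit is None:
        return interp_linear(x0, x_list, y_list, lower_extrap)
    else:
        return interp_decay(x0, x_list, y_list, intercept_limit, slope_limit, lower_extrap)


# Hypothetical sample data, only for illustration
x = np.array([0.0, 1.0, 2.0, 3.0])
y = np.array([0.0, 1.0, 1.5, 1.75])

y_linear = LinearInterpFast(np.array([0.5, 1.5, 2.5]), x, y)
y_decay = LinearInterpFast(np.array([0.5, 4.0]), x, y, 2.0, 0.0)
```

The first call uses the defaults and takes the linear branch; the second passes both limits and takes the decay branch. Passing only one of the two limits would send a None into interp_decay, so a real implementation would probably want to guard against that.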

Yes, that works! I guess I’m still not 100% sure how njit and generated_jit work…

Thank you so much!