Mutual Recursion

I’m attempting to port ndtr.c to njitted Python. A couple of mutually recursive functions are tripping up the compiler. The program works in regular Python, but in nopython mode I get a cascading signature error like the one below. I’ve tried various things, such as explicitly specifying function signatures with @njit('f8(f8)'), to no avail. I assume I’m making a fundamental mistake and am hoping for a pointer to where.

Thanks

numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function ndtr at 0x0000026EEB7662F0>) found for signature:

ndtr(float64)

There are 2 candidate implementations:

  • Of which 2 did not match due to:
    Overload in function 'register_jitable.<locals>.wrap.<locals>.ov_wrap': File: numba\core\extending.py: Line 150.
    With argument(s): '(float64)':
    Rejected as the implementation raised a specific error:
    TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    No implementation of function Function(<function erf at 0x0000026EEB766950>) found for signature:

erf(float64)

import math
from numba import njit
import numpy as np
from numba.extending import register_jitable

MAXLOG = 7.09782712893383996843E2
M_PI = 3.1415926535897932384626433832795

P = [
    2.46196981473530512524E-10,
    5.64189564831068821977E-1,
    7.46321056442269912687E0,
    4.86371970985681366614E1,
    1.96520832956077098242E2,
    5.26445194995477358631E2,
    9.34528527171957607540E2,
    1.02755188689515710272E3,
    5.57535335369399327526E2
]

P_rev = tuple(reversed(P))

Q = [
    #
    # 1.00000000000000000000E0,
    #
    1.32281951154744992508E1,
    8.67072140885989742329E1,
    3.54937778887819891062E2,
    9.75708501743205489753E2,
    1.82390916687909736289E3,
    2.24633760818710981792E3,
    1.65666309194161350182E3,
    5.57535340817727675546E2
]

Q_rev = tuple(reversed(Q))


R = [
    5.64189583547755073984E-1,
    1.27536670759978104416E0,
    5.01905042251180477414E0,
    6.16021097993053585195E0,
    7.40974269950448939160E0,
    2.97886665372100240670E0
]

R_rev = tuple(reversed(R))

S = [
    #
    # 1.00000000000000000000E0,
    #
    2.26052863220117276590E0,
    9.39603524938001434673E0,
    1.20489539808096656605E1,
    1.70814450747565897222E1,
    9.60896809063285878198E0,
    3.36907645100081516050E0
]
S_rev = tuple(reversed(S))

T = [
    9.60497373987051638749E0,
    9.00260197203842689217E1,
    2.23200534594684319226E3,
    7.00332514112805075473E3,
    5.55923013010394962768E4
]
T_rev = tuple(reversed(T))

U = [
    #
    # 1.00000000000000000000E0,
    #
    3.35617141647503099647E1,
    5.21357949780152679795E2,
    4.59432382970980127987E3,
    2.26290000613890934246E4,
    4.92673942608635921086E4
]
U_rev = tuple(reversed(U))

UTHRESH = 37.519379347
NPY_SQRT1_2 = 0.707106781186547524400844362104849039

@njit
def polevl_rev(x, coefs):
    ans = 0
    x_power = 1
    for coef in coefs:
        ans += coef * x_power
        x_power = x_power * x
    return ans


@njit
def p1evl_rev(x, coefs):
    ans = 0
    x_power = 1
    for coef in coefs:
        ans += coef * x_power
        x_power = x_power * x
    return ans + x_power


@register_jitable
def ndtr(a):
    if np.isnan(a):
        return np.nan

    x = a * NPY_SQRT1_2
    z = np.fabs(x)

    if z < NPY_SQRT1_2:
        y = 0.5 + 0.5 * erf(x)
    else:
        y = 0.5 * erfc(z)

        if x > 0:
            y = 1.0 - y

    return y


@register_jitable
def under(a):
    return 2.0 if a < 0 else 0.0


@register_jitable
def erf_positive(x):
    """
    erf for positive number.  attempting to reduce recursion
    :param x:
    :return:
    """
    if x > 1.0:
        return 1.0 - erfc(x)

    z = x * x
    y = x * polevl_rev(z, T_rev) / p1evl_rev(z, U_rev)
    return y


@register_jitable
def erf(x):
    if np.isnan(x):
        return np.nan

    if x < 0.0:
        return -erf_positive(-x)

    return erf_positive(x)

@register_jitable
def erfc(a):
    if np.isnan(a):
        return np.nan

    if a < 0.0:
        x = -a
    else:
        x = a

    if x < 1.0:
        return 1.0 - erf(a)

    z = -a * a

    if z < -MAXLOG:
        return under(a)

    z = math.exp(z)

    if x < 8.0:
        p = polevl_rev(x, P_rev)
        q = p1evl_rev(x, Q_rev)
    else:
        p = polevl_rev(x, R_rev)
        q = p1evl_rev(x, S_rev)
    y = (z * p) / q

    if a < 0:
        y = 2.0 - y

    if y == 0.0:
        return under(a)

    return y


#@njit  # program works until I uncomment this line
def jit_ndtr(x):
    return ndtr(x)

jit_ndtr(2.0)

Hi @nelson2005,

I think scipy.special exports ndtr as a Cython export, and Numba can bind to it directly. This is the API to use: https://numba.readthedocs.io/en/stable/extending/high-level.html#importing-cython-functions. It looks like you want the __pyx_fuse_1ndtr version for the double type:

In [28]: import scipy

In [29]: [x for x in scipy.special.cython_special.__pyx_capi__.keys() if 'ndtr' in x]
Out[29]: 
['chndtr',
 'chndtridf',
 'chndtrinc',
 'chndtrix',
 'ndtri',
 '__pyx_fuse_0log_ndtr',
 '__pyx_fuse_1log_ndtr',
 '__pyx_fuse_0ndtr',
 '__pyx_fuse_1ndtr']

In [30]: scipy.special.cython_special.__pyx_capi__['__pyx_fuse_0ndtr']
Out[30]: <capsule object "__pyx_t_double_complex (__pyx_t_double_complex, int __pyx_skip_dispatch)" at 0x7fb05f2143f0>

In [31]: scipy.special.cython_special.__pyx_capi__['__pyx_fuse_1ndtr']
Out[31]: <capsule object "double (double, int __pyx_skip_dispatch)" at 0x7fb05f214420>

RE: the OP… I think the code is tripping up the compiler somewhere; it seems the recursion might not be tracked correctly. Any chance you could open an issue, please?
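
In the meantime, one way to sidestep the mutual recursion entirely is to restructure the port so that erf and erfc only call non-recursive kernels. The sketch below is a suggestion, not a confirmed fix for the error above. It reuses the coefficient tables from the question, kept highest-order first so plain Horner evaluation works. The helper names erf_small and erfc_large are hypothetical, and Numba is imported optionally so the same code also runs as plain Python.

```python
import math

try:
    from numba import njit
except ImportError:  # fall back to plain Python so the sketch still runs
    def njit(func):
        return func

MAXLOG = 7.09782712893383996843E2
NPY_SQRT1_2 = 0.707106781186547524400844362104849039

# Cephes coefficient tables from the question, highest-order term first.
P = (2.46196981473530512524E-10, 5.64189564831068821977E-1, 7.46321056442269912687E0,
     4.86371970985681366614E1, 1.96520832956077098242E2, 5.26445194995477358631E2,
     9.34528527171957607540E2, 1.02755188689515710272E3, 5.57535335369399327526E2)
Q = (1.32281951154744992508E1, 8.67072140885989742329E1, 3.54937778887819891062E2,
     9.75708501743205489753E2, 1.82390916687909736289E3, 2.24633760818710981792E3,
     1.65666309194161350182E3, 5.57535340817727675546E2)
R = (5.64189583547755073984E-1, 1.27536670759978104416E0, 5.01905042251180477414E0,
     6.16021097993053585195E0, 7.40974269950448939160E0, 2.97886665372100240670E0)
S = (2.26052863220117276590E0, 9.39603524938001434673E0, 1.20489539808096656605E1,
     1.70814450747565897222E1, 9.60896809063285878198E0, 3.36907645100081516050E0)
T = (9.60497373987051638749E0, 9.00260197203842689217E1, 2.23200534594684319226E3,
     7.00332514112805075473E3, 5.55923013010394962768E4)
U = (3.35617141647503099647E1, 5.21357949780152679795E2, 4.59432382970980127987E3,
     2.26290000613890934246E4, 4.92673942608635921086E4)

@njit
def polevl(x, coefs):
    # Horner evaluation of a polynomial, highest-order coefficient first.
    ans = 0.0
    for c in coefs:
        ans = ans * x + c
    return ans

@njit
def p1evl(x, coefs):
    # Same as polevl, but with an implicit leading coefficient of 1.0.
    ans = 1.0
    for c in coefs:
        ans = ans * x + c
    return ans

@njit
def erf_small(x):
    # Rational approximation for |x| <= 1; never calls erfc.
    z = x * x
    return x * polevl(z, T) / p1evl(z, U)

@njit
def erfc_large(a):
    # Approximation for |a| >= 1; never calls erf.
    x = abs(a)
    z = -a * a
    if z < -MAXLOG:  # exp(z) would underflow
        return 2.0 if a < 0.0 else 0.0
    z = math.exp(z)
    if x < 8.0:
        y = z * polevl(x, P) / p1evl(x, Q)
    else:
        y = z * polevl(x, R) / p1evl(x, S)
    return 2.0 - y if a < 0.0 else y

@njit
def erf(x):
    if math.isnan(x):
        return x  # propagate NaN
    if abs(x) > 1.0:
        return 1.0 - erfc_large(x)
    return erf_small(x)

@njit
def erfc(a):
    if math.isnan(a):
        return a  # propagate NaN
    if abs(a) < 1.0:
        return 1.0 - erf_small(a)
    return erfc_large(a)

@njit
def ndtr(a):
    if math.isnan(a):
        return a
    x = a * NPY_SQRT1_2
    z = abs(x)
    if z < NPY_SQRT1_2:
        return 0.5 + 0.5 * erf(x)
    y = 0.5 * erfc(z)
    return 1.0 - y if x > 0 else y

print(ndtr(2.0))
```

The call graph here is a plain tree (ndtr calls erf and erfc, which call only the kernels), so there is no cycle for type inference to resolve, while the approximations themselves are the same ones used in ndtr.c.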

Also, as math.erf and math.erfc are supported, I think this sort of thing could work for you:

import math
from numba import njit
import numpy as np
from numba.extending import register_jitable

NPY_SQRT1_2 = 0.707106781186547524400844362104849039

# Translation of:
# https://github.com/scipy/scipy/blob/a8b66ec40e007cf230761305fa483370328b5f09/scipy/special/cephes/ndtr.c#L201-L224
# License: https://github.com/scipy/scipy/blob/a8b66ec40e007cf230761305fa483370328b5f09/LICENSES_bundled.txt#L44-L78

@register_jitable
def ndtr(a):
    if np.isnan(a):
        return np.nan

    x = a * NPY_SQRT1_2
    z = np.fabs(x)

    if z < NPY_SQRT1_2:
        y = 0.5 + 0.5 * math.erf(x)
    else:
        y = 0.5 * math.erfc(z)

        if x > 0:
            y = 1.0 - y

    return y


@njit
def jit_ndtr(x):
    return ndtr(x)

print(jit_ndtr(2.0))
print(jit_ndtr.py_func(2.0))
import scipy.special
print(scipy.special.ndtr(2.0))

Hope this helps?

Thanks for the tips! We’re currently importing and using the Cython export, and were hoping to get a bit more speed by (hopefully) getting it all inlined with Numba. Regarding math.erf, I was aware of it but was hoping to use the same implementation as SciPy’s ndtr; math.erf seems to use a different algorithm than SciPy.

Issue opened here

Adding another link as reference