I’m attempting to port ntdr.c to njitted python. There are a couple of mutually recursive functions that are tripping up the compiler. The program works in regular python but in nopython mode I get a cascading signature error like below. I’ve tried various bits, like explicitly specifying function signatures with @njit(‘f8(f8)’) to no avail. I assume I’m making a fundamental mistake and hoping for a pointer about where.
Thanks
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function ndtr at 0x0000026EEB7662F0>) found for signature:ndtr(float64)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function ‘register_jitable..wrap..ov_wrap’: File: numba\core\extending.py: Line 150.
With argument(s): ‘(float64)’:
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function erf at 0x0000026EEB766950>) found for signature:erf(float64)
import math
from numba import njit
import numpy as np
from numba.extending import register_jitable
MAXLOG = 7.09782712893383996843E2
M_PI = 3.1415926535897932384626433832795
P = [
2.46196981473530512524E-10,
5.64189564831068821977E-1,
7.46321056442269912687E0,
4.86371970985681366614E1,
1.96520832956077098242E2,
5.26445194995477358631E2,
9.34528527171957607540E2,
1.02755188689515710272E3,
5.57535335369399327526E2
]
P_rev = tuple(reversed(P))
Q = [
#
# 1.00000000000000000000E0,
#
1.32281951154744992508E1,
8.67072140885989742329E1,
3.54937778887819891062E2,
9.75708501743205489753E2,
1.82390916687909736289E3,
2.24633760818710981792E3,
1.65666309194161350182E3,
5.57535340817727675546E2
]
Q_rev = tuple(reversed(Q))
R = [
5.64189583547755073984E-1,
1.27536670759978104416E0,
5.01905042251180477414E0,
6.16021097993053585195E0,
7.40974269950448939160E0,
2.97886665372100240670E0
]
R_rev = tuple(reversed(R))
S = [
#
# 1.00000000000000000000E0,
#
2.26052863220117276590E0,
9.39603524938001434673E0,
1.20489539808096656605E1,
1.70814450747565897222E1,
9.60896809063285878198E0,
3.36907645100081516050E0
]
S_rev = tuple(reversed(S))
T = [
9.60497373987051638749E0,
9.00260197203842689217E1,
2.23200534594684319226E3,
7.00332514112805075473E3,
5.55923013010394962768E4
]
T_rev = tuple(reversed(T))
U = [
#
# 1.00000000000000000000E0,
#
3.35617141647503099647E1,
5.21357949780152679795E2,
4.59432382970980127987E3,
2.26290000613890934246E4,
4.92673942608635921086E4
]
U_rev = tuple(reversed(U))
UTHRESH = 37.519379347
NPY_SQRT1_2 = 0.707106781186547524400844362104849039
@njit
def polevl_rev(x, coefs):
ans = 0
x_power = 1
for coef in coefs:
ans += coef * x_power
x_power = x_power * x
return ans
@njit
def p1evl_rev(x, coefs):
ans = 0
x_power = 1
for coef in coefs:
ans += coef * x_power
x_power = x_power * x
return ans + x_power
@register_jitable
def ndtr(a):
if np.isnan(a):
return np.nan
x = a * NPY_SQRT1_2
z = np.fabs(x)
if z < NPY_SQRT1_2:
y = 0.5 + 0.5 * erf(x)
else:
y = 0.5 * erfc(z)
if x > 0:
y = 1.0 - y
return y
@register_jitable
def under(a):
return 2.0 if a < 0 else 0.0
@register_jitable
def erf_positive(x):
"""
erf for positive number. attempting to reduce recursion
:param x:
:return:
"""
if x > 1.0:
return 1.0 - erfc(x)
z = x * x
y = x * polevl_rev(z, T_rev) / p1evl_rev(z, U_rev)
return y
@register_jitable
def erf(x):
if np.isnan(x):
return np.nan
if x < 0.0:
return -erf_positive(-x)
return erf_positive(x)
@register_jitable
def erfc(a):
if np.isnan(a):
return np.nan
if a < 0.0:
x = -a
else:
x = a
if x < 1.0:
return 1.0 - erf(a)
z = -a * a
if z < -MAXLOG:
return under(a)
z = math.exp(z)
if x < 8.0:
p = polevl_rev(x, P_rev)
q = p1evl_rev(x, Q_rev)
else:
p = polevl_rev(x, R_rev)
q = p1evl_rev(x, S_rev)
y = (z * p) / q
if a < 0:
y = 2.0 - y
if y == 0.0:
return under(a)
return y
#@njit # program works until I uncomment this line
def jit_ndtr(x):
return ndtr(x)
jit_ndtr(2.0)