Hi @nelson2005,
I think `scipy.special.cython_special` exports `ndtr` as a Cython function, and Numba can bind to it directly. This is the API to use: https://numba.readthedocs.io/en/stable/extending/high-level.html#importing-cython-functions and it looks like you want the `__pyx_fuse_1ndtr` version for the `double` type:
```
In [28]: import scipy

In [29]: [x for x in scipy.special.cython_special.__pyx_capi__.keys() if 'ndtr' in x]
Out[29]:
['chndtr',
 'chndtridf',
 'chndtrinc',
 'chndtrix',
 'ndtri',
 '__pyx_fuse_0log_ndtr',
 '__pyx_fuse_1log_ndtr',
 '__pyx_fuse_0ndtr',
 '__pyx_fuse_1ndtr']

In [30]: scipy.special.cython_special.__pyx_capi__['__pyx_fuse_0ndtr']
Out[30]: <capsule object "__pyx_t_double_complex (__pyx_t_double_complex, int __pyx_skip_dispatch)" at 0x7fb05f2143f0>

In [31]: scipy.special.cython_special.__pyx_capi__['__pyx_fuse_1ndtr']
Out[31]: <capsule object "double (double, int __pyx_skip_dispatch)" at 0x7fb05f214420>
```
RE: the OP… I think the code is tripping up the compiler somewhere; it looks like the recursion might not be tracked correctly. Any chance you could open an issue, please?
Also, as `math.erf` and `math.erfc` are supported in Numba, I think this sort of thing could work for you:
```python
import math

import numpy as np
from numba import njit
from numba.extending import register_jitable

NPY_SQRT1_2 = 0.707106781186547524400844362104849039  # 1/sqrt(2)

# Translation of:
# https://github.com/scipy/scipy/blob/a8b66ec40e007cf230761305fa483370328b5f09/scipy/special/cephes/ndtr.c#L201-L224
# License: https://github.com/scipy/scipy/blob/a8b66ec40e007cf230761305fa483370328b5f09/LICENSES_bundled.txt#L44-L78
@register_jitable
def ndtr(a):
    if np.isnan(a):
        return np.nan
    x = a * NPY_SQRT1_2
    z = np.fabs(x)
    if z < NPY_SQRT1_2:
        # Near the centre, erf is accurate enough
        y = 0.5 + 0.5 * math.erf(x)
    else:
        # In the tails, 0.5 * erfc(|x|) is the lower-tail probability Phi(-|a|);
        # reflect it for positive a
        y = 0.5 * math.erfc(z)
        if x > 0:
            y = 1.0 - y
    return y


@njit
def jit_ndtr(x):
    return ndtr(x)


print(jit_ndtr(2.0))          # compiled by Numba
print(jit_ndtr.py_func(2.0))  # the same code run as plain Python

import scipy.special
print(scipy.special.ndtr(2.0))  # reference value from SciPy
```
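The reason for the two branches is floating-point cancellation: for large negative arguments, `0.5 + 0.5 * erf(x)` collapses to exactly zero in double precision, whereas `0.5 * erfc(-x)` keeps the tiny tail probability. A small stdlib-only check (no Numba needed) using the identity Φ(a) = ½·erfc(−a/√2) shows this; `ndtr_naive` and `ndtr_tail` are illustrative names, not part of any library.

```python
import math

SQRT1_2 = 1.0 / math.sqrt(2.0)


def ndtr_naive(a):
    # Single formula: fine near 0, loses the lower tail for large negative a
    return 0.5 + 0.5 * math.erf(a * SQRT1_2)


def ndtr_tail(a):
    # Lower-tail form: Phi(a) = 0.5 * erfc(-a / sqrt(2)), accurate for a << 0
    return 0.5 * math.erfc(-a * SQRT1_2)


naive = ndtr_naive(-10.0)
tail = ndtr_tail(-10.0)
print(naive)  # 0.0
print(tail)   # about 7.62e-24
```

Near zero the two forms agree to within rounding, which is why the translated `ndtr` only switches to `erfc` once `|x|` is at least 1/√2.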
Hope this helps?