How can I cache nopython-mode scipy.special.cython_special-functions?

Making this a separate discussion, just to keep track.

I distribute calculations involving scipy.special.cython_special-functions to a number of remote workers, and I think I would benefit from being able to cache the functions on the worker side beforehand.

  • Is this possible? What could be a way to go forward?

Tried:

With the method I use, the function is not cache-able, due to the ctypes pointers in cython_special.py being global variables.

cython_special.py
import ctypes
import scipy
from numba import njit
from numba.extending import get_cython_function_address

# Look up the address of the `huber` entry exported by scipy.special.cython_special
# and wrap it in a ctypes prototype taking two doubles and returning a double.
addr = get_cython_function_address('scipy.special.cython_special', 'huber')
functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double)
# NOTE: `chuber` is a module-level ctypes pointer; as described in the post above,
# this global is what stops the jitted function that uses it from being cached.
chuber = functype(addr)

@njit(cache=True)
def nchuber(delta, r):
    """Evaluate ``huber`` through the module-level ctypes pointer ``chuber``."""
    value = chuber(delta, r)
    return value

module.py
from numba import njit
import cython_special as cs

@njit(cache=True)
def call_nchuber(delta, r):
    """Forward ``(delta, r)`` to ``cs.nchuber`` and hand back its result."""
    result = cs.nchuber(delta, r)
    return result

MRE

Using LocalCluster here, for simplicity. See # 2278 for details.

from distributed import Client, LocalCluster
# Create a LocalCluster and connect a Client to it.
cluster=LocalCluster()
client=Client(cluster)
import module as m
import cython_special as cs

# Submit the jitted wrapper through the client, then fetch the outcome of that task.
submitted = client.submit(m.call_nchuber, 1.0, 4.0)
submitted.result()

MRE results in:

NumbaWarning: Cannot cache compiled function "nchuber" as it uses dynamic globals 
(such as ctypes pointers and large global arrays) 

Is it possible to convert the Scipy wrapper into a jitclass (or structref)?
By default, a jitclass does not appear to be pickleable or cachable.
There might be workarounds but I’m not sure if they work in a cluster (see Github issues below).

import ctypes
import numba as nb
from numba.extending import get_cython_function_address

import dask
import dask.distributed  # populate config with distributed defaults
# dask.config.get('distributed.worker.multiprocessing-method')
# dask.config.set({'distributed.worker.multiprocessing-method': 'spawn'})
dask.config.set({'distributed.worker.multiprocessing-method': 'forkserver'})
# dask.config.set({'distributed.worker.multiprocessing-method': 'fork'})
# from distributed import Client, LocalCluster
from dask.distributed import Client, LocalCluster

# Can we use a Numba object like jitclass instead?
class ScipyFunctionWrapper:
    """Callable wrapper around a jitted call to the Scipy ``huber`` function.

    The jitted function is built in ``__init__``. Pickling only records the
    class itself (see ``__reduce__``), so unpickling constructs a fresh
    instance and builds the jitted function again.
    """

    def __init__(self):
        self.func = self._get_chuber_function()

    def _get_chuber_function(self):
        """
        Build and return a jitted function that calls the Scipy function
        through a ctypes function pointer.
        """
        prototype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double)
        chuber = prototype(
            get_cython_function_address('scipy.special.cython_special', 'huber')
        )

        # How to cache?
        # (cache=True) => ValueError('ctypes objects containing pointers cannot be pickled
        @nb.njit
        def wrapped_chuber(delta, r):
            return chuber(delta, r)

        return wrapped_chuber

    def __call__(self, delta, r):
        """
        Make the instance itself callable; delegates to the jitted function.
        """
        jitted = self.func
        return jitted(delta, r)

    def __reduce__(self):
        """
        Tell pickle to rebuild the object by calling the class with no arguments.
        """
        return self.__class__, ()

def test_calc():
    """Build a wrapper and print the value it returns for (1.0, 4.0)."""
    wrapper = ScipyFunctionWrapper()
    value = wrapper(1.0, 4.0)
    print(f"Result: {value}")
    # Result: 3.5

def test_pickle():
    """Round-trip a wrapper through pickle and print the value the copy returns."""
    import pickle
    restored = pickle.loads(pickle.dumps(ScipyFunctionWrapper()))
    print(f"Result: {restored(1.0, 4.0)}")
    # Result: 3.5

# Test local cluster
if __name__ == "__main__":
    try:
        chuber = ScipyFunctionWrapper()
        # A single ``with`` holding both managers is equivalent to nesting them:
        # the client is closed first, then the cluster.
        with LocalCluster() as cluster, Client(cluster) as client:
            submitted = client.submit(chuber, 1.0, 4.0)
            result = submitted.result()
            print(f"Result: {result}")
            # Result: 3.5
    except Exception as exc:
        print("An error occurred:", exc)

Here is a way to cache the scipy function call but it will probably not be possible to pickle the function like that.
Maybe both ideas can be combined.

from numba.extending import get_cython_function_address, intrinsic
from numba.core import cgutils
from numba.types import f8, i8
from numba import types, njit

@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    """Recover a function from its signature and address.

    ``func_type_ref`` is a type reference whose ``instance_type`` is the
    function type to produce; ``addr`` is the integer address to wrap.
    """
    func_type = func_type_ref.instance_type

    def codegen(context, builder, sig, args):
        # First argument is the type reference (not needed here); second is the address.
        _, addr_value = args
        func_struct = cgutils.create_struct_proxy(func_type)(context, builder)
        voidptr_ty = context.get_value_type(types.voidptr)
        # Turn the integer into a pointer and store it in the struct's ``addr`` field.
        func_struct.addr = builder.inttoptr(addr_value, voidptr_ty)
        return func_struct._getvalue()

    return func_type(func_type_ref, addr), codegen

# Address of `huber` as exported by scipy.special.cython_special, looked up at import time.
fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')
# Numba function type with signature f8(f8, f8).
fn_type = types.FunctionType(f8(f8, f8))

@njit(f8(f8, f8, i8), cache=True)
def call_chuber(delta, r, func_addr):
    """Rebuild the function at ``func_addr`` as ``fn_type`` and call it with ``(delta, r)``."""
    huber_fn = _func_from_address(fn_type, func_addr)
    return huber_fn(delta, r)

# Example: the address is passed in as an ordinary runtime argument.
print(f'result: {call_chuber(1.0, 4.0, fn_addr)}')

Interesting. Nice that you were able to cache it locally!

I see it does get my workers killed unfortunately. (I also tried setting the address-integer as a constant). I think the reason the cluster is reacting this way, is that fn_type is a global variable (?).

Do you think this is something I should bring up in the Dask discourse instead? I might try that.

MRE on "Cluster of workers"
# NOTE(review): the integer below is a function address written into the code as a
# constant. As pointed out later in this thread, such an address is only meaningful
# in the process where it was looked up, so it has to be retrieved again on the
# worker side rather than shipped along with the function.
@njit(f8(f8, f8), cache=True)
def call_chuber(delta, r):
    f = _func_from_address(fn_type, 140228704543888)
    return f(delta, r)

# Submit the jitted function through the client and ask for the result of that task.
submitted = client.submit(call_chuber, 1.0, 4.0)
submitted.result()

Results in

---------------------------------------------------------------------------
KilledWorker                              Traceback (most recent call last)
Cell In[3], line 1
----> 1 submitted.result()

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/client.py:281, in Future.result(self, timeout)
    279 if self.status == "error":
    280     typ, exc, tb = result
--> 281     raise exc.with_traceback(tb)
    282 elif self.status == "cancelled":
    283     raise result

KilledWorker: Attempted to run task call_chuber_csO-d89e80a3e21dc6bdc9616a565a839a62 on 3 different workers, but all those workers died while running it. The last worker that attempt to run the task was tcp://127.0.0.1:46605. Inspecting worker logs is often a good next step to diagnose what went wrong. For more information see https://distributed.dask.org/en/stable/killed.html.

Also:
When running locally (in a local notebook session, not on the cluster), I notice it is slightly slower than the non-cached version. But I think the approach is still worth looking into. Those Numba modules (cgutils, intrinsic) are a bit outside of my knowledge at the moment, but I am interested in learning them.

Hey @ofk123,

The Dask discourse will definitely be helpful.
Just to clarify, the idea was to utilize a class-like object to handle processes being terminated by using a “__reduce__” method. This approach would enable the workers to reconstruct the object if needed.
The Scipy wrapper function needs to be cached, but it’s crucial to separate the construction phase from the evaluation phase otherwise it will be too slow (as it is for now).

from numba import types, njit
from numba.types import f8, i8
from numba.core import cgutils
from numba.extending import get_cython_function_address, intrinsic

import dask
import dask.distributed  # populate config with distributed defaults
# dask.config.get('distributed.worker.multiprocessing-method')
dask.config.set({'distributed.worker.multiprocessing-method': 'forkserver'})
from dask.distributed import Client, LocalCluster

@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    """Recover a function from its signature and address.

    ``func_type_ref`` is a type reference whose ``instance_type`` is the
    function type to produce; ``addr`` is the integer address to wrap.
    """
    func_type = func_type_ref.instance_type

    def codegen(context, builder, sig, args):
        # First argument is the type reference (not needed here); second is the address.
        _, addr_value = args
        func_struct = cgutils.create_struct_proxy(func_type)(context, builder)
        voidptr_ty = context.get_value_type(types.voidptr)
        # Turn the integer into a pointer and store it in the struct's ``addr`` field.
        func_struct.addr = builder.inttoptr(addr_value, voidptr_ty)
        return func_struct._getvalue()

    return func_type(func_type_ref, addr), codegen

# Address of `huber` as exported by scipy.special.cython_special, looked up at import time.
fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')
# Numba function type with signature f8(f8, f8).
fn_type = types.FunctionType(f8(f8, f8))

# TODO: Can we separate function generation from evaluation?
@njit(f8(f8, f8, i8), cache=True)
def call_chuber(delta, r, fn_addr):
    """Rebuild the function at ``fn_addr`` as ``fn_type`` and call it with ``(delta, r)``."""
    huber_fn = _func_from_address(fn_type, fn_addr)
    return huber_fn(delta, r)

# TODO: Can we use a Numba object like structref instead?
class ScipyFunctionWrapper:
    """Callable that pairs the jitted ``call_chuber`` with a locally resolved address.

    The address is looked up in ``__init__`` and ``__reduce__`` pickles the object
    as a bare constructor call, so every unpickled copy performs its own lookup.
    """

    def __init__(self):
        self.func = call_chuber
        self.fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')

    def __call__(self, delta, r):
        """Make the instance callable; passes the stored address to the jitted function."""
        address = self.fn_addr
        return self.func(delta, r, address)

    def __reduce__(self):
        """Tell pickle to rebuild the object by calling the class with no arguments."""
        return self.__class__, ()

# Test local cluster
if __name__ == "__main__":
    try:
        chuber = ScipyFunctionWrapper()
        # A single ``with`` holding both managers is equivalent to nesting them:
        # the client is closed first, then the cluster.
        with LocalCluster() as cluster, Client(cluster) as client:
            submitted = client.submit(chuber, 1.0, 4.0)
            result = submitted.result()
            print(f"Result: {result}")
            # Result: 3.5
    except Exception as exc:
        print("An error occurred:", exc)

Hi @ofk123

Perhaps you can show the actual code that you are running both locally and on the remote site, rather than a (I assume) simplified version of it. The worker logs might also be helpful.

Based on the somewhat limited information you gave, I assume that with the solution suggested by @Oyibo you simply send the function address to the remote workers. This is literally the physical location of that function in your computer’s memory. You would need to retrieve that address on the remote side. This is also the reason why you cannot cache the ctypes pointer. It is highly unsafe.

Hey @sschaer,

I’ve been experimenting with different caching options for external functions using Numba, and I could use some advice.
Numba offers several methods for loading external functions, including get_cython_function_address (via scipy), ctypes, cffi, WrapperAddressProtocol, and ExternalFunction. I’ve noticed during testing that caching fails with get_cython_function_address, ctypes, and cffi, resulting in a warning: “NumbaWarning: Cannot cache compiled function”. However, when using WrapperAddressProtocol and ExternalFunction, caching seems to work as intended without any warnings, and cache files are generated successfully. However, the WrapperAddressProtocol seems to be slower than the other methods.
Is this the preferable approach?
llvmlite.binding.load_library_permanently(find_library(<library>))
followed by
numba.types.ExternalFunction(<funcname>, ...)
What would you recommend in general when handling external functions using Numba and you want to cache them?

1 Like

Hi @sschaer,
The MRE above is an example of how it is called on a local cluster. On a remote cluster I use:

from dask_gateway import Gateway, GatewayCluster
from distributed import Client
# Same submission as the local MRE, but the client is connected to a GatewayCluster.
cluster=GatewayCluster()
client=Client(cluster)

import module as m
import cython_special as cs
# Submit the jitted wrapper through the client, then fetch the outcome of that task.
submitted = client.submit(m.call_nchuber, 1.0, 4.0)
submitted.result()

But it is of course easier to discuss over the real code. In the near future I will be too occupied, but I will try to give it a shot at a later time, and let you know on here.

Can you elaborate on how you managed to cache a special function using ‘numba.types.ExternalFunction(<funcname>, …)’?

For example, it is not too complicated to create some Cython wrapper for full support of special functions. Example I wrote a long time ago

Of course using stack arrays, otherwise the performance will be very bad.

@max9111 Unfortunately, I don’t know where to find the shared library/libraries to load SciPy’s special functions.
I’ve put together some code that uses GSL (GNU Scientific Library) instead of SciPy to cache a special function within Numba jitted function, but I’m not entirely sure if it’s working perfectly as intended. It seems like it does though…
Here’s the code:

from ctypes.util import find_library

import scipy.special as sp  # to verify the result
from llvmlite.binding import load_library_permanently
from numba import njit
from numba.core import types, typing

# We use special functions from GSL - GNU Scientific Library
# make sure that GSL is installed
# sudo apt install libgsl-dev
# gsl-config --version
# 2.7.1

# Locate the GSL shared library (adjust the name if necessary), e.g. 'libgsl.so.27'.
# find_library() returns None when nothing is found, so fail with a clear message
# here instead of handing None on to llvmlite.
gsl_library_path = find_library('gsl')
if gsl_library_path is None:
    raise RuntimeError(
        "GSL shared library not found; install it first "
        "(e.g. 'sudo apt install libgsl-dev')"
    )
load_library_permanently(gsl_library_path)

# Define the GSL Bessel function signature (gsl_sf_bessel_J0)
c_func_name = 'gsl_sf_bessel_J0'
return_type = types.float64  # The return type of gsl_sf_bessel_J0 is double
arg_type = types.float64     # The argument type of gsl_sf_bessel_J0 is double
c_sig = typing.signature(return_type, arg_type)
# Numba type that names the C function together with its signature; presumably the
# name is resolved against the library loaded above — TODO confirm.
c_func = types.ExternalFunction(c_func_name, c_sig)

# Define a Numba-compiled function to call the GSL Bessel function
# Caching should result in two functions
# gsl_special_function_in_numba.compute_bessel_j0_gsl-27.py312.1.nbc
# gsl_special_function_in_numba.compute_bessel_j0_gsl-27.py312.nbi
@njit(cache=True)
def compute_bessel_j0_gsl(x):
    """Return the value of the external C function ``gsl_sf_bessel_J0`` at ``x``.

    The call goes through ``c_func``, the ``ExternalFunction`` declared for that
    symbol with a float64 -> float64 signature.
    """
    return c_func(x)

# Test value
x = 5.0

# GSL & SciPy results
gsl_result = compute_bessel_j0_gsl(x)
scipy_result = sp.jv(0, x)  # Bessel function of the first kind, order 0

# Output results
# The sample output recorded below shows the two values differing only in the last digits.
print(f"GSL Bessel J0({x}) = {gsl_result}")
print(f"SciPy Bessel J0({x}) = {scipy_result}")
# GSL Bessel J0(5.0) = -0.17759677131433826
# SciPy Bessel J0(5.0) = -0.17759677131433835

@Oyibo

I solved it in the following way.

1. Create a C DLL which exposes the functions that should be called from Numba

//FileName Wrapper_test.c
// Compile using clang -shared -O3 -IC:\Python\include -LC:\Python\libs Wrapper_test.c -o Wrapper_test.dll
#include "Python.h"


/* API is put in front of the functions meant to be called from outside:
   it expands to __declspec(dllexport) when _WIN32 is defined and to nothing otherwise. */
#ifdef _WIN32
#    define API __declspec(dllexport)
#else
#    define API
#endif

/* Two-int record. NOTE(review): not referenced by any function shown in this
   snippet -- confirm whether it is still needed. */
struct bint {
  int n;
  int derivative;
};

/* Complex value as a pair of doubles; this is the argument and return type
   declared for the cy_wofz function pointer. */
struct complex128 {
  double real;
  double imag;
};

/* Look up `function_name` in the `__pyx_capi__` mapping of the Cython module
   `module_name` and return the pointer stored in its capsule.
   Returns NULL on failure; when the name is missing from the mapping, a
   ValueError naming both the function and the module is set. */
static void * import_cython_function(const char *module_name, const char *function_name)
{
    void *address = NULL;

    PyObject *mod = PyImport_ImportModule(module_name);
    if (!mod)
        return NULL;

    /* The module reference is only needed long enough to fetch the mapping. */
    PyObject *capi_map = PyObject_GetAttrString(mod, "__pyx_capi__");
    Py_DECREF(mod);
    if (!capi_map)
        return NULL;

    PyObject *capsule = PyMapping_GetItemString(capi_map, (char *)function_name);
    Py_DECREF(capi_map);
    if (!capsule) {
        /* Replace the lookup error with a more descriptive one. */
        PyErr_Clear();
        PyErr_Format(PyExc_ValueError,
                     "No function '%s' found in __pyx_capi__ of '%s'",
                     function_name, module_name);
        return NULL;
    }

    /* 2.7+ => Cython exports a PyCapsule */
    const char *capsule_name = PyCapsule_GetName(capsule);
    if (capsule_name)
        address = PyCapsule_GetPointer(capsule, capsule_name);

    Py_DECREF(capsule);
    return address;
}
///////////End of standard declarations////////

///////////Define function pointers////////////
/* Both pointers start out null and are assigned in init().
   The trailing int parameter is presumably Cython's skip-dispatch flag for
   cpdef functions -- TODO confirm against the generated SciPy sources. */
double (*cy_voigt_profile)(double,double,double,int)=0;
struct complex128 (*cy_wofz)(struct complex128,int)=0;

//////////Init has to be called before any other function is called//////
/* Looks up "voigt_profile" and "wofz" in scipy.special.cython_special and stores
   the results in the two global function pointers.
   NOTE(review): import_cython_function() returns NULL on failure and the result is
   stored unchecked, so a failed lookup leaves a null pointer that the wrappers
   below would call through. */
API void init(){
  cy_voigt_profile = import_cython_function("scipy.special.cython_special","voigt_profile");
  cy_wofz = import_cython_function("scipy.special.cython_special","wofz");
}

//////////Numba compatible function//////////
/* Three-double wrapper around the cy_voigt_profile pointer; the extra trailing
   argument that pointer expects is always passed as 1. */
API double nb_voigt_profile(double in1,double in2,double in3)
{
    double result = cy_voigt_profile(in1, in2, in3, 1);
    return result;
}

API void nb_wofz(double in1_real,double in1_imag,double *out1_real,double *out1_imag)
{   
    struct complex128 c_in1;
    c_in1.real = in1_real;
    c_in1.imag = in1_imag;
    
    struct complex128 cout1 = cy_wofz(c_in1,1);
    
    out1_real[0] = cout1.real;
    out1_imag[0] = cout1.imag;
}

2. Wrap the functions in Numba

import numba as nb
from numba.core import types, typing
from llvmlite import binding
from numba import types
from numba.extending import intrinsic
from numba.core import cgutils

@intrinsic
def val_to_ptr(typingctx, data):
    """Store a value in a freshly allocated slot and return a pointer to it."""
    def impl(context, builder, signature, args):
        # alloca_once_value reserves a slot holding args[0]; its pointer is the result.
        ptr = cgutils.alloca_once_value(builder,args[0])
        return ptr
    # The signature has the form CPointer(T)(T).
    # NOTE(review): nb.typeof(data).instance_type presumably just recovers the
    # argument's own Numba type T here — confirm.
    sig = types.CPointer(nb.typeof(data).instance_type)(nb.typeof(data).instance_type)
    return sig, impl

@intrinsic
def ptr_to_val(typingctx, data):
    """Load and return the value that a ``CPointer`` argument points to."""
    pointee = data.dtype

    def impl(context, builder, signature, args):
        return builder.load(args[0])

    # Signature: pointee(CPointer(pointee)).
    return pointee(types.CPointer(pointee)), impl

#####################################################################


# Load the wrapper DLL built in step 1.
binding.load_library_permanently('Wrapper_test.dll')

# Declare each C function by symbol name and signature; c_sig is reused as a scratch name.
c_sig = types.void()
nb_init = types.ExternalFunction('init', c_sig)
c_sig = types.double(types.double, types.double, types.double)
nb_voigt_profile = types.ExternalFunction('nb_voigt_profile', c_sig)

c_sig = types.void(types.double, types.double, types.CPointer(types.double), types.CPointer(types.double))
nb_wofz = types.ExternalFunction('nb_wofz', c_sig)


# NOTE(review): no `nb_wofz_c` is defined in the C source shown in step 1 (the name
# only appears in the Clang output further down) — confirm the DLL exports it.
c_sig = types.void(types.CPointer(types.double), types.CPointer(types.double), types.CPointer(types.double), types.CPointer(types.double))
nb_wofz_c = types.ExternalFunction('nb_wofz_c', c_sig)

########Has to be called first#####
# Calling a Numba type object such as an ExternalFunction directly from the
# interpreter only builds a signature object; it does not run the C symbol.
# The C `init` therefore has to be invoked from jitted code.
@nb.njit
def _run_init():
    """Invoke the DLL's ``init`` so that its function pointers get assigned."""
    nb_init()

_run_init()
###################################
@nb.njit(parallel=False)
def numba_wofz(in1):
    """Call the C wrapper ``nb_wofz`` for the complex value ``in1``.

    The wrapper writes its result through two double pointers, so two scratch
    doubles are set up first and read back after the call.
    """
    re_ptr = val_to_ptr(nb.double(0.))
    im_ptr = val_to_ptr(nb.double(0.))
    nb_wofz(in1.real, in1.imag, re_ptr, im_ptr)

    re_part = ptr_to_val(re_ptr)
    im_part = ptr_to_val(im_ptr)
    return re_part + 1j * im_part

With this method, functions that use complex numbers can be wrapped, and caching is possible. Of course, don’t forget to call the init method first.

It would be very interesting to get to a solution where global variables (in this case, function pointer addresses) can be set with some simple Python code, and where these functions can be wrapped in Numba in a more direct way.
Returning structures is platform dependent, although an implementation for one major platform (x86) would be very interesting.

For example, the bitcode from Clang looks like

source_filename = "Wrapper_test.c"
target datalayout = "e-m:w-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-pc-windows-msvc19.35.32216"

%struct.complex128 = type { double, double }

$"??_C@_04IIOAPFIG@wofz?$AA@" = comdat any
@cy_wofz = dso_local local_unnamed_addr global void (%struct.complex128*, %struct.complex128*, i32)* null, align 8

@"??_C@_04IIOAPFIG@wofz?$AA@" = linkonce_odr dso_local unnamed_addr constant [5 x i8] c"wofz\00", comdat, align 1
%152 = tail call fastcc i8* @import_cython_function(i8* noundef getelementptr inbounds ([5 x i8], [5 x i8]* @"??_C@_04IIOAPFIG@wofz?$AA@", i64 0, i64 0))


store i8* %152, i8** bitcast (void (%struct.complex128*, %struct.complex128*, i32)** @cy_wofz to i8**), align 8, !tbaa !5
  
; Function Attrs: argmemonly mustprogress nofree nosync nounwind willreturn
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1

; Function Attrs: argmemonly mustprogress nofree nounwind willreturn
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #2

; Function Attrs: argmemonly mustprogress nofree nosync nounwind willreturn
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1

; Function Attrs: nounwind uwtable
; nb_wofz_c: the struct result is written through the sret pointer %0; the struct
; argument arrives by pointer in %1.
define dso_local dllexport void @nb_wofz_c(%struct.complex128* noalias sret(%struct.complex128) align 8 %0, %struct.complex128* nocapture noundef readonly %1) local_unnamed_addr #0 {
  ; local temporary for a copy of the argument struct
  %3 = alloca %struct.complex128, align 8
  ; fetch the function pointer currently stored in the global @cy_wofz
  %4 = load void (%struct.complex128*, %struct.complex128*, i32)*, void (%struct.complex128*, %struct.complex128*, i32)** @cy_wofz, align 8, !tbaa !5
  %5 = bitcast %struct.complex128* %3 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull %5) #4
  %6 = bitcast %struct.complex128* %1 to i8*
  ; copy the 16-byte argument from %1 into the local temporary %3
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* noundef nonnull align 8 dereferenceable(16) %5, i8* noundef nonnull align 8 dereferenceable(16) %6, i64 16, i1 false), !tbaa.struct !12
  ; indirect call through the loaded pointer: result via %0 (sret), input is the
  ; local copy, last argument is the constant 1
  call void %4(%struct.complex128* sret(%struct.complex128) align 8 %0, %struct.complex128* noundef nonnull %3, i32 noundef 1) #4
  call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %5) #4
  ret void
}
1 Like

@max9111 thank you for the example.
If it were possible to expose the low-level special functions from scipy.special.cython_special as C symbols in a shared library (e.g., libscipy_special.so), similar to how libraries like GSL (libgsl.so) and LAPACK (libopenblas.so) expose their functions, it would enable low-level access to SciPy’s special functions directly within Numba-compiled functions. This approach would also allow us to use Numba’s caching capabilities.
Unfortunately, this is not the case at the moment and you need compiled C or Cython code as glue.