How can I cache nopython-mode scipy.special.cython_special functions?

Making this a separate discussion, just to keep track.

I distribute calculations involving scipy.special.cython_special functions to a number of remote workers, and I think I would benefit from being able to cache the compiled functions on the worker side beforehand.

  • Is this possible? What could be a way to go forward?

Tried:

With the method I use, the function is not cacheable, because the ctypes pointers in cython_special.py are global variables.

cython_special.py
import ctypes
import scipy
from numba import njit
from numba.extending import get_cython_function_address

# Look up the C function pointer exported by scipy.special.cython_special
# and wrap it with ctypes so it can be called from nopython mode.
addr = get_cython_function_address('scipy.special.cython_special', 'huber')
functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double)
chuber = functype(addr)

@njit(cache=True)
def nchuber(delta, r):
    return chuber(delta, r)

module.py
from numba import njit
import cython_special as cs

@njit(cache=True)
def call_nchuber(delta, r):
    return cs.nchuber(delta, r)

MRE

Using a LocalCluster here for simplicity; see #2278 for details.

from distributed import Client, LocalCluster
cluster = LocalCluster()
client = Client(cluster)
import module as m
import cython_special as cs

submitted = client.submit(m.call_nchuber, 1.0, 4.0)
submitted.result()

MRE results in:

NumbaWarning: Cannot cache compiled function "nchuber" as it uses dynamic globals 
(such as ctypes pointers and large global arrays) 

Is it possible to convert the Scipy wrapper into a jitclass (or structref)?
By default, a jitclass does not appear to be pickleable or cacheable (see the quick check below).
There might be workarounds, but I'm not sure whether they work in a cluster (see the GitHub issues below).
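A quick check of that claim (a minimal sketch; the Dummy class is made up for illustration):

import pickle
from numba import float64
from numba.experimental import jitclass

@jitclass([('delta', float64)])
class Dummy:
    def __init__(self, delta):
        self.delta = delta

try:
    pickle.dumps(Dummy(1.0))
except Exception as e:
    # Pickling a jitclass instance fails out of the box.
    print(type(e).__name__, e)

And here is the plain-Python wrapper class I am experimenting with instead: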

import ctypes
import numba as nb
from numba.extending import get_cython_function_address

import dask
import dask.distributed  # populate config with distributed defaults
# dask.config.get('distributed.worker.multiprocessing-method')
# dask.config.set({'distributed.worker.multiprocessing-method': 'spawn'})
dask.config.set({'distributed.worker.multiprocessing-method': 'forkserver'})
# dask.config.set({'distributed.worker.multiprocessing-method': 'fork'})
# from distributed import Client, LocalCluster
from dask.distributed import Client, LocalCluster

# Can we use a Numba object like jitclass instead?
class ScipyFunctionWrapper:
    def __init__(self):
        self.func = self._get_chuber_function()

    def _get_chuber_function(self):
        """
        Obtain the ctypes function pointer to the Scipy function.
        """
        addr = get_cython_function_address('scipy.special.cython_special', 'huber')
        functype = ctypes.CFUNCTYPE(ctypes.c_double, ctypes.c_double, ctypes.c_double)
        chuber = functype(addr)

        # How to cache?
        @nb.njit  # with cache=True: ValueError: ctypes objects containing pointers cannot be pickled
        def wrapped_chuber(delta, r):
            return chuber(delta, r)
        return wrapped_chuber

    def __call__(self, delta, r):
        """
        Allows the class instance to be called directly to execute the wrapped function.
        """
        return self.func(delta, r)

    def __reduce__(self):
        """
        Method to define how the object should be pickled.
        """
        return (self.__class__, ())

def test_calc():
    chuber = ScipyFunctionWrapper()
    print(f"Result: {chuber(1.0, 4.0)}")
    # Result: 3.5

def test_pickle():
    import pickle
    chuber = ScipyFunctionWrapper()
    pickled_chuber = pickle.dumps(chuber)
    unpickled_chuber = pickle.loads(pickled_chuber)
    print(f"Result: {unpickled_chuber(1.0, 4.0)}")
    # Result: 3.5

# Test local cluster
if __name__ == "__main__":
    try:
        chuber = ScipyFunctionWrapper()
        with LocalCluster() as cluster:
            with Client(cluster) as client:
                submitted = client.submit(chuber, 1.0, 4.0)
                result = submitted.result()
                print(f"Result: {result}")
                # Result: 3.5
    except Exception as e:
        print("An error occurred:", e)

Here is a way to cache the Scipy function call, but it will probably not be possible to pickle the function like that.
Maybe both ideas can be combined.

from numba.extending import get_cython_function_address, intrinsic
from numba.core import cgutils
from numba.types import f8, i8
from numba import types, njit

@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    '''Recover a function from its signature and address.'''
    func_type = func_type_ref.instance_type
    def codegen(context, builder, sig, args):
        _, addr = args
        sfunc = cgutils.create_struct_proxy(func_type)(context, builder)
        llty = context.get_value_type(types.voidptr)
        addr_ptr = builder.inttoptr(addr, llty)
        sfunc.addr = addr_ptr
        return sfunc._getvalue()
    sig = func_type(func_type_ref, addr)
    return sig, codegen

fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')
fn_type = types.FunctionType(f8(f8, f8))

@njit(f8(f8, f8, i8), cache=True)
def call_chuber(delta, r, func_addr):
    f = _func_from_address(fn_type, func_addr)
    return f(delta, r)

# Example
print(f'result: {call_chuber(1.0, 4.0, fn_addr)}')

Interesting. Nice that you were able to cache it locally!

Unfortunately, I see it gets my workers killed. (I also tried setting the address integer as a constant.) I think the reason the cluster reacts this way is that fn_type is a global variable (?).

Do you think this is something I should bring up on the Dask discourse instead? I might try that.

MRE on "Cluster of workers"
# Note: the address constant below was looked up in an earlier client-side run.
@njit(f8(f8, f8), cache=True)
def call_chuber(delta, r):
    f = _func_from_address(fn_type, 140228704543888)
    return f(delta, r)

submitted = client.submit(call_chuber, 1.0, 4.0)
submitted.result()

Results in

---------------------------------------------------------------------------
KilledWorker                              Traceback (most recent call last)
Cell In[3], line 1
----> 1 submitted.result()

File /srv/conda/envs/notebook/lib/python3.10/site-packages/distributed/client.py:281, in Future.result(self, timeout)
    279 if self.status == "error":
    280     typ, exc, tb = result
--> 281     raise exc.with_traceback(tb)
    282 elif self.status == "cancelled":
    283     raise result

KilledWorker: Attempted to run task call_chuber_csO-d89e80a3e21dc6bdc9616a565a839a62 on 3 different workers, but all those workers died while running it. The last worker that attempt to run the task was tcp://127.0.0.1:46605. Inspecting worker logs is often a good next step to diagnose what went wrong. For more information see https://distributed.dask.org/en/stable/killed.html.

Also:
When running locally (in a local notebook session, not on the cluster), I notice it is slightly slower than the non-cached version. But I think the approach is still worth looking into. Those Numba modules (cgutils, intrinsic) are a bit outside my knowledge at present, but I am interested in learning them.

Hey @ofk123,

The Dask discourse will definitely be helpful.
Just to clarify, the idea was to use a class-like object whose __reduce__ method lets it survive processes being terminated. This approach enables the workers to reconstruct the object if needed.
The Scipy wrapper function needs to be cached, but it is crucial to separate the construction phase from the evaluation phase; otherwise it will be too slow (as it is now).

from numba import types, njit
from numba.types import f8, i8
from numba.core import cgutils
from numba.extending import get_cython_function_address, intrinsic

import dask
import dask.distributed  # populate config with distributed defaults
# dask.config.get('distributed.worker.multiprocessing-method')
dask.config.set({'distributed.worker.multiprocessing-method': 'forkserver'})
from dask.distributed import Client, LocalCluster

@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    '''Recover a function from its signature and address.'''
    func_type = func_type_ref.instance_type

    def codegen(context, builder, sig, args):
        _, addr = args
        sfunc = cgutils.create_struct_proxy(func_type)(context, builder)
        llty = context.get_value_type(types.voidptr)
        addr_ptr = builder.inttoptr(addr, llty)
        sfunc.addr = addr_ptr
        return sfunc._getvalue()
    sig = func_type(func_type_ref, addr)
    return sig, codegen

fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')
fn_type = types.FunctionType(f8(f8, f8))

# TODO: Can we separate function generation from evaluation?
@njit(f8(f8, f8, i8), cache=True)
def call_chuber(delta, r, fn_addr):
    f = _func_from_address(fn_type, fn_addr)
    return f(delta, r)

# TODO: Can we use a Numba object like structref instead?
class ScipyFunctionWrapper:
    def __init__(self):
        self.fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')
        self.func = call_chuber

    def __call__(self, delta, r):
        """Allows the class instance to be called directly to execute the wrapped function."""
        return self.func(delta, r, self.fn_addr)

    def __reduce__(self):
        """Method to define how the object should be pickled."""
        return (self.__class__, ())

# Test local cluster
if __name__ == "__main__":
    try:
        chuber = ScipyFunctionWrapper()
        with LocalCluster() as cluster:
            with Client(cluster) as client:
                submitted = client.submit(chuber, 1.0, 4.0)
                result = submitted.result()
                print(f"Result: {result}")
                # Result: 3.5
    except Exception as e:
        print("An error occurred:", e)

Hi @ofk123

Perhaps you can show the actual code that you are running both locally and on the remote site, rather than (I assume) a simplified version of it. The worker logs might also be helpful.

Based on the somewhat limited information you gave, I assume that with the solution suggested by @Oyibo you simply send the function address to the remote workers. That address is just the location of the function in the memory of the process that looked it up; it is meaningless in any other process. You would need to retrieve the address on the remote side. This is also the reason why you cannot cache the ctypes pointer: it is highly unsafe.
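Something along these lines should work (an untested sketch: remote_chuber is a name I made up, and call_chuber and client are the ones defined earlier in this thread). The lookup happens on the worker, so the address is valid there:

from numba.extending import get_cython_function_address

def remote_chuber(delta, r):
    # Runs on the worker: the looked-up address is valid in that process.
    fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')
    return call_chuber(delta, r, fn_addr)

submitted = client.submit(remote_chuber, 1.0, 4.0)
print(submitted.result())  # Result: 3.5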

Hey @sschaer,

I've been experimenting with different options for caching external functions with Numba, and I could use some advice.
Numba offers several ways to load external functions, including get_cython_function_address (via SciPy), ctypes, cffi, WrapperAddressProtocol, and ExternalFunction. During testing I noticed that caching fails with get_cython_function_address, ctypes, and cffi, producing the warning "NumbaWarning: Cannot cache compiled function". With WrapperAddressProtocol and ExternalFunction, however, caching seems to work as intended: there are no warnings, and cache files are generated successfully. The WrapperAddressProtocol does seem to be slower than the other methods, though; a sketch of what I tried is below.
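For reference, here is roughly what my WrapperAddressProtocol test looked like (a minimal sketch; CHuberWAP and apply_fn are names I made up, and I am assuming the raw cython_special pointer is compatible with Numba's first-class function calling convention):

from numba import njit, types
from numba.types import f8
from numba.extending import get_cython_function_address

class CHuberWAP(types.WrapperAddressProtocol):
    '''Expose the cython_special 'huber' pointer via the wrapper address protocol.'''

    def __wrapper_address__(self):
        # Resolved at call time, so the address is valid in the current process.
        return get_cython_function_address('scipy.special.cython_special', 'huber')

    def signature(self):
        return f8(f8, f8)

@njit(cache=True)
def apply_fn(fn, delta, r):
    # fn is typed as a first-class function; no raw pointer is baked into the cache.
    return fn(delta, r)

print(apply_fn(CHuberWAP(), 1.0, 4.0))  # Result: 3.5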
Is this the preferable approach?
llvmlite.binding.load_library_permanently(find_library(<library>))
followed by
numba.types.ExternalFunction(<funcname>, ...)
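Spelled out, I mean something like the following (a sketch using libm's hypot as a stand-in, which is an assumption on my part; a cython_special routine would instead need its library path and mangled export name):

import ctypes.util
from llvmlite import binding
from numba import njit, types
from numba.types import f8

# Make the shared library's symbols resolvable by the JIT-compiled code.
# find_library may return None on some platforms; adjust the path as needed.
binding.load_library_permanently(ctypes.util.find_library('m'))

# Declare the external symbol by name and signature. No raw pointer ends up
# in the compiled code, so cache=True works without warnings.
c_hypot = types.ExternalFunction('hypot', f8(f8, f8))

@njit(cache=True)
def nhypot(x, y):
    return c_hypot(x, y)

print(nhypot(3.0, 4.0))  # Result: 5.0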
What would you recommend in general for handling external functions with Numba when you want to cache them?