How can I cache nopython-mode `scipy.special.cython_special` functions?

Hey @ofk123,

The Dask discourse will definitely be helpful.
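Just to clarify: the idea was to use a class with a `__reduce__` method so that the wrapper can cope with worker processes being terminated. `__reduce__` tells pickle to rebuild the object by calling its constructor, so each worker can reconstruct the wrapper (and look up the function address again) whenever needed.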
The SciPy wrapper function needs to be cached, but it is crucial to separate the construction phase from the evaluation phase. Otherwise it will be too slow, as it currently is.

from numba import types, njit
from numba.types import f8, i8
from numba.core import cgutils
from numba.extending import get_cython_function_address, intrinsic

import dask
import dask.distributed  # populate config with distributed defaults
# Make Dask distributed workers use the 'forkserver' multiprocessing start
# method; this runs at import time, before the LocalCluster below is created.
# NOTE(review): presumably chosen so workers are fresh processes that rebuild
# ScipyFunctionWrapper via __reduce__ — confirm.
# To inspect the current value:
#   dask.config.get('distributed.worker.multiprocessing-method')
dask.config.set({'distributed.worker.multiprocessing-method': 'forkserver'})
from dask.distributed import Client, LocalCluster

@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    """Recover a first-class function value from its Numba type and address.

    Typing: ``func_type_ref`` is a reference to a ``types.FunctionType`` and
    ``addr`` is an integer address; the intrinsic returns a value of the
    referenced function type.

    Lowering: the integer address is cast to ``void*`` and stored in the
    ``addr`` field of the function type's struct representation.
    """
    target_fn_type = func_type_ref.instance_type
    # The type-reference argument is only used for typing; the address is data.
    signature = target_fn_type(func_type_ref, addr)

    def codegen(context, builder, sig, args):
        raw_addr = args[1]
        fn_struct = cgutils.create_struct_proxy(target_fn_type)(context, builder)
        voidptr_ty = context.get_value_type(types.voidptr)
        fn_struct.addr = builder.inttoptr(raw_addr, voidptr_ty)
        return fn_struct._getvalue()

    return signature, codegen

# Looked up once at import time: the raw address of
# scipy.special.cython_special.huber, and the Numba type of a
# (float64, float64) -> float64 function used to call it.
# NOTE: call_chuber does not read this module-level fn_addr (its parameter of
# the same name shadows it), and ScipyFunctionWrapper looks the address up
# again itself. Only fn_type is used, as a global inside call_chuber.
fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')
fn_type = types.FunctionType(f8(f8, f8))

# TODO: Can we separate function generation from evaluation?
@njit(f8(f8, f8, i8), cache=True)
def call_chuber(delta, r, fn_addr):
    """Call the (float64, float64) -> float64 function located at ``fn_addr``.

    In this file the address is that of
    ``scipy.special.cython_special.huber``. The address is a runtime argument
    rather than a constant baked into the compiled code. The cached machine
    code therefore does not depend on it, and each caller passes the address
    it looked up.
    """
    return _func_from_address(fn_type, fn_addr)(delta, r)

# TODO: Can we use a Numba object like structref instead?
class ScipyFunctionWrapper:
    """Picklable callable evaluating cython_special ``huber`` via ``call_chuber``.

    Pickling records only the class, so unpickling (for example on a Dask
    worker) runs ``__init__`` again. That re-run looks the function address up
    afresh in the receiving process.
    """

    def __init__(self):
        # Address of scipy.special.cython_special.huber in the current process.
        self.fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')
        # Jitted trampoline that calls a function through a raw address.
        self.func = call_chuber

    def __call__(self, delta, r):
        """Evaluate the wrapped function at ``(delta, r)``."""
        jitted, address = self.func, self.fn_addr
        return jitted(delta, r, address)

    def __reduce__(self):
        """Pickle as a bare constructor call; no instance state is carried."""
        return type(self), ()

# Smoke test: evaluate the wrapper on a local Dask cluster.
if __name__ == "__main__":
    import sys

    try:
        chuber = ScipyFunctionWrapper()
        with LocalCluster() as cluster, Client(cluster) as client:
            # Dask serializes `chuber` to ship it to a worker; __reduce__ makes
            # the worker rebuild it by calling ScipyFunctionWrapper() there.
            submitted = client.submit(chuber, 1.0, 4.0)
            result = submitted.result()
            print(f"Result: {result}")
            # Result: 3.5
    except Exception as e:
        # Report, then re-raise: this keeps the full traceback and gives a
        # non-zero exit status instead of silently "succeeding".
        print("An error occurred:", e, file=sys.stderr)
        raise