Hey @ofk123,
The Dask discourse will definitely be helpful.
Just to clarify, the idea was to use a class-like object that copes with worker processes being terminated by defining a `__reduce__` method. This approach would enable the workers to reconstruct the object if needed.
The SciPy wrapper function needs to be cached, but it’s crucial to separate the construction phase from the evaluation phase; otherwise it will be too slow (as it currently is).
from numba import types, njit
from numba.types import f8, i8
from numba.core import cgutils
from numba.extending import get_cython_function_address, intrinsic
import dask
import dask.distributed # populate config with distributed defaults
# To inspect the current value:
# dask.config.get('distributed.worker.multiprocessing-method')
# Set the worker multiprocessing method to 'forkserver' in the dask config.
dask.config.set({'distributed.worker.multiprocessing-method': 'forkserver'})
from dask.distributed import Client, LocalCluster
@intrinsic
def _func_from_address(typingctx, func_type_ref, addr):
    """Recover a function from its signature and address.

    ``func_type_ref`` is a type reference whose ``instance_type`` is the
    function type to produce, and ``addr`` is the integer address of the
    native function.  Returns the ``(signature, codegen)`` pair expected
    from an ``@intrinsic`` definition.
    """
    target_type = func_type_ref.instance_type
    signature = target_type(func_type_ref, addr)

    def codegen(context, builder, sig, args):
        # The first argument is the type reference and carries no runtime
        # value of interest; only the address value is used.
        _, raw_addr = args
        proxy = cgutils.create_struct_proxy(target_type)(context, builder)
        voidptr_llty = context.get_value_type(types.voidptr)
        # Reinterpret the integer address as a void pointer and store it
        # in the function struct's ``addr`` field.
        proxy.addr = builder.inttoptr(raw_addr, voidptr_llty)
        return proxy._getvalue()

    return signature, codegen
# Address of the `huber` function exported by scipy.special.cython_special.
# NOTE(review): this module-level value is not read by the code visible here;
# call_chuber takes its own `fn_addr` parameter and ScipyFunctionWrapper does
# its own lookup.
fn_addr = get_cython_function_address('scipy.special.cython_special', 'huber')
# Function type f8(f8, f8): two float64 arguments, float64 result.
fn_type = types.FunctionType(f8(f8, f8))
# TODO: Can we separate function generation from evaluation?
@njit(f8(f8, f8, i8), cache=True)
def call_chuber(delta, r, fn_addr):
    """Call the ``f8(f8, f8)`` function located at ``fn_addr`` with ``(delta, r)``.

    The callable is rebuilt from the integer address on every call and then
    invoked; its float64 result is returned.
    """
    native_fn = _func_from_address(fn_type, fn_addr)
    return native_fn(delta, r)
# TODO: Can we use a Numba object like structref instead?
class ScipyFunctionWrapper:
    """Picklable callable that evaluates ``huber`` via ``call_chuber``.

    The function address is looked up when an instance is constructed.
    Pickling does not copy instance state: the object is rebuilt by calling
    the class with no arguments, so the lookup runs again on unpickling.
    """

    def __init__(self):
        self.func = call_chuber
        self.fn_addr = get_cython_function_address(
            'scipy.special.cython_special', 'huber'
        )

    def __call__(self, delta, r):
        """Evaluate the wrapped function at ``(delta, r)`` and return the result."""
        address = self.fn_addr
        return self.func(delta, r, address)

    def __reduce__(self):
        """Describe pickling as "call the class again with no arguments"."""
        return (type(self), ())
# Test local cluster
if __name__ == "__main__":
    try:
        wrapper = ScipyFunctionWrapper()
        # Start a local cluster, connect a client to it, and run the wrapper
        # on a worker; both are shut down when the block exits.
        with LocalCluster() as cluster, Client(cluster) as client:
            future = client.submit(wrapper, 1.0, 4.0)
            result = future.result()
            print(f"Result: {result}")
            # Result: 3.5
    except Exception as e:
        print("An error occurred:", e)