Hi, I’m trying to compose njit-ed functions that use closure values.
Is there any way to make this pattern work while keeping `h` defined outside `f`?
@numba.njit
def h(g):
    return g(1)  # h receives g as a first-class function argument and calls it
@numba.njit
def f(a,b):
    def g(i):  # g closes over f's runtime argument b (a non-constant value)
        return b[i]
    return h(g)  # handing g to h makes it escape f -> the TypingError quoted below
f(1,np.zeros(3))
TypingError: Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'b' in a function that will escape.
File "<ipython-input-179-3da212f9c5bf", line 6:
def f(a,b):
def g(i):
^
‘b’ here is not mutated throughout the algorithm, and ‘g’ could be inlined.
I arrived at an initial solution using `@jitclass` to emulate `functools.partial`.
I tried to improve the interface with dynamic class generation and the use of `@generated_jit`.
Example below:
import pyrsistent
from numba.experimental import jitclass
from REDACTED.util.indented_fstring import deindent
from weakref import WeakValueDictionary
# Cache of generated jitclasses; held weakly so unused classes can be collected.
_numba_closure_cache = WeakValueDictionary()


def _closure_class_source(name, const_names, variable):
    """Return Python source for a class that binds ``const_names`` and forwards to ``func``.

    The generated class stores each const as an attribute in ``__init__``
    (positional, in ``const_names`` order) and exposes
    ``call(self, *variable)``, which returns
    ``func(self.<const>..., <variable>...)``. The name ``func`` must be
    provided as a global when the source is executed.
    """
    init_params = "".join("," + k for k in const_names)
    init_body = ";".join(f"self.{k}={k}" for k in const_names) or "pass"
    call_params = "".join("," + v for v in variable)
    # Join consts and call-time args as ONE list so an empty const set does
    # not leave a leading comma (``func(,i)`` would be a SyntaxError).
    func_args = ",".join([f"self.{k}" for k in const_names] + list(variable))
    return "\n".join([
        f"class {name}:",
        f"    def __init__(self{init_params}):",
        f"        {init_body}",
        f"    def call(self{call_params}):",
        f"        return func({func_args})",
        "",
    ])


def numba_closure_func(name, func, const=None, variable=(), *, verbose=True):
    """Build (or fetch from cache) a jitclass that partially applies ``func``.

    The returned jitclass is constructed with the ``const`` values
    (positionally, in key order) and has a ``call(*variable)`` method that
    returns ``func(*const_values, *variable)``. This emulates
    ``functools.partial`` in a form that can be passed into other compiled
    functions.

    Args:
        name: Class name for the generated class; must be a valid identifier.
        func: Callable invoked by ``call`` (an ``njit`` function in the
            example usage).
        const: Mapping of attribute name -> Numba type, used as the jitclass
            ``spec``. ``None`` (the default) means no bound values.
        variable: Names of the arguments that ``call`` takes at call time.
        verbose: If true (the default, matching the previous behavior),
            print the spec and the generated class source.

    Returns:
        The jitclass. Results are cached per
        ``(func, ordered const items, variable)``.

    Raises:
        ValueError: If ``name`` or any const/variable name is not a usable
            Python identifier, or uses a name reserved by the template.
    """
    import keyword

    const = {} if const is None else dict(const)
    variable = tuple(variable)
    const_names = tuple(const)

    # Every name below is interpolated into source passed to exec(), so
    # validate it up front instead of failing later with an obscure error.
    for ident in (name, *const_names, *variable):
        if (not isinstance(ident, str) or not ident.isidentifier()
                or keyword.iskeyword(ident)):
            raise ValueError(f"not a valid identifier: {ident!r}")
    # 'self' would duplicate a parameter; a call-time arg named 'func'
    # would shadow the wrapped function inside call().
    reserved = {"self", "func"}.intersection((*const_names, *variable))
    if reserved:
        raise ValueError(f"reserved name(s) used: {sorted(reserved)}")

    # Key on the ORDERED const items: the constructor takes consts
    # positionally in key order, so dicts with the same items in a different
    # order must not share a cached class.
    key = (func, tuple(const.items()), variable)
    # Single .get() rather than `in` followed by `[]`: a weakly-held entry
    # can disappear between the two operations.
    cached = _numba_closure_cache.get(key)
    if cached is not None:
        return cached

    spec = list(const.items())
    cls_def = _closure_class_source(name, const_names, variable)
    if verbose:
        print(spec)
        print(cls_def)

    exec_globals = {"func": func}
    exec_locals = {}
    exec(cls_def, exec_globals, exec_locals)
    jitted = jitclass(spec)(exec_locals[name])
    _numba_closure_cache[key] = jitted
    return jitted
import numpy as np
@numba.njit(inline='always')
def G_func(b,i):
    """Return ``b[i]``; same body as ``g_func``, and not used by ``f`` in this example."""
    element = b[i]
    return element
@numba.njit(inline='always')
def h(g):
    """Invoke the bound closure object ``g`` via its ``call`` method with index 1."""
    result = g.call(1)
    return result
@numba.njit(inline='always')
def g_func(b,i):
    """Return element ``i`` of ``b``; ``b`` is bound by the generated closure class."""
    element = b[i]
    return element
@numba.generated_jit(nopython=True)
def f(a,b):
    # Runs at typing time: here `b` is a Numba type (the printed spec shows
    # array(float64, 1d, C)), which becomes the jitclass spec for attribute 'b'.
    # NOTE(review): generated_jit is deprecated in recent Numba releases in
    # favor of numba.extending.overload — confirm against the installed version.
    G = numba_closure_func('G',g_func,const={'b':b},variable=('i',))
    def _f(a,b):  # implementation returned to Numba; mirrors f's parameters
        g = G(b)  # bind the array into a jitclass instance (partial application)
        return h(g)
    return _f
f(1,np.zeros(3))
>>>#output
#print of spec and class def
[('b', array(float64, 1d, C))]
class G:
def __init__(self,b):
self.b=b
def call(self,i):
return func(self.b,i)
# correct result
0.0
Hi there! Was this not an option for you?
@numba.njit
def h(g, b):
    """Call ``g(1, b)``; ``b`` is forwarded explicitly rather than captured by ``g``."""
    return g(1, b)


@numba.njit
def f(a, b):
    def g(idx, arr):
        # No free variables: the array is passed in rather than captured from f.
        return arr[idx]

    return h(g, b)


f(1, np.zeros(3))
That would carry g’s signature into h’s, making h less general: I would have to write a separate h for each possible signature of g (signature meaning the number of arguments).