Using closure to reuse code

Hi, I’m trying to compose njit-ed functions that use closure values.
Is there any way to make this pattern work, while keeping h defined outside f?

@numba.njit
def h(g):
    return g(1)  # call whatever callable was passed in with the fixed argument 1
@numba.njit
def f(a,b):
    def g(i):
        return b[i]  # b is f's runtime argument, captured by the inner function
    return h(g)  # g is handed to another function, i.e. it escapes f
f(1,np.zeros(3))  # this call produces the TypingError quoted below
TypingError: Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'b' in a function that will escape.

File "<ipython-input-179-3da212f9c5bf>", line 6:
def f(a,b):
    def g(i):
    ^

‘b’ here is not mutated throughout the algorithm, and ‘g’ could be inlined.

I arrived at an initial solution using @jitclass to emulate functools.partial.
I tried to improve the interface with dynamic class generation and the use of @generated_jit.
Example below:

import pyrsistent
from numba.experimental import jitclass

from REDACTED.util.indented_fstring import deindent
from weakref import WeakValueDictionary
_numba_closure_cache=WeakValueDictionary()
def numba_closure_func(name,func,const=None,variable=()):
    """Build (and cache) a jitclass that emulates ``functools.partial`` for ``func``.

    The generated class stores every key of ``const`` as an attribute set in
    ``__init__`` and exposes a ``call`` method that invokes
    ``func(<const attributes...>, <variable arguments...>)``.

    Args:
        name: Name of the generated class.
        func: Function invoked by ``call``; it is visible to the generated
            source as the global ``func``.
        const: Mapping of attribute name to the type used for that field in
            the jitclass spec. ``None`` (the default) means no bound values.
        variable: Names of the arguments that ``call`` accepts and forwards
            to ``func`` after the bound attributes.

    Returns:
        The jitclass built from the generated class. Repeated calls with an
        equal ``(func, const, variable)`` triple return the cached class for
        as long as it is still referenced elsewhere.
    """
    const = {} if const is None else const
    key = pyrsistent.freeze((func,const,variable))
    # The cache holds weak references, so look up and fetch in one step: an
    # entry can vanish between an ``in`` test and the subsequent indexing.
    try:
        return _numba_closure_cache[key]
    except KeyError:
        pass

    spec = list(const.items())

    # Assemble each argument list from its parts so that an empty ``const``
    # or ``variable`` never leaves a dangling comma in the generated source.
    init_args = ','.join(['self', *const])
    init_body = ';'.join(f'self.{k}={k}' for k in const) if const else 'pass'
    call_args = ','.join(['self', *variable])
    func_args = ','.join([*(f"self.{k}" for k in const), *variable])
    cls_def = deindent(f"""\
        class {name}:
            def __init__({init_args}):
                {init_body}
            def call({call_args}):
                return func({func_args})
        """)
    # Diagnostic output showing what is about to be compiled.
    print(spec)
    print(cls_def)
    # NOTE: the source is executed with exec(); ``name``, the keys of
    # ``const`` and the entries of ``variable`` must be trusted identifiers.
    lg = dict(func=func)
    ld = dict()
    exec(cls_def,lg,ld)
    kls = ld[name]
    klsj = jitclass(spec)(kls)
    _numba_closure_cache[key]=klsj
    return klsj

import numpy as np

@numba.njit(inline='always')
def G_func(b,i):
    """Return the element of ``b`` at index ``i``."""
    picked = b[i]
    return picked
    
@numba.njit(inline='always')
def h(g):
    """Invoke ``g.call`` with the fixed argument 1 and return its result."""
    outcome = g.call(1)
    return outcome

@numba.njit(inline='always')
def g_func(b,i):
    """Return the element of ``b`` at index ``i``."""
    item = b[i]
    return item

@numba.generated_jit(nopython=True)
def f(a,b):
    """Select the implementation of ``f`` for the given argument types.

    The closure class is built here with ``b`` as the field type for its
    ``b`` attribute; the returned implementation binds the runtime value
    of ``b`` into an instance and hands that instance to ``h``.
    """
    closure_cls = numba_closure_func('G',g_func,variable=('i',),const={'b':b})
    def impl(a,b):
        bound = closure_cls(b)
        return h(bound)
    return impl

f(1,np.zeros(3))
>>>#output
#print of spec and class def
[('b', array(float64, 1d, C))]
class G:
    def __init__(self,b):
        self.b=b
    def call(self,i):
        return func(self.b,i)
# correct result
0.0

Hi there! Was this not an option for you?

# Variant that passes b explicitly instead of capturing it in a closure.
@numba.njit
def h(g, b):
    # Forward b to g so that g has no captured variables.
    return g(1, b)


@numba.njit
def f(a, b):
    def g(i, b):
        # b arrives as a parameter here rather than from the enclosing scope.
        return b[i]
    return h(g, b)


f(1, np.zeros(3))

That would carry g’s signature into h’s, making h less general: I would have to write a separate h for each possible signature (that is, each number of arguments) of g.