Does numba support returning closures from a @jitted function?

import numba as nb

@nb.njit
def foo(x):
    def bar(y):
        return x + y
    return bar

foo(3)(4)
TypingError: Failed in nopython mode pipeline (step: convert make_function into JIT functions)
Cannot capture the non-constant value associated with variable 'x' in a function that will escape.

I suspect the mention of "capture" is Numba's way of telling me it doesn't support compiling functions defined inside jitted code.

Hi @gael, you're right, this is not supported. Specifically, what is not supported is the combination of: the outer function being jitted, the inner function being returned, and the inner function using a closure variable. If you remove any of the three, it works.

No closure variable

import numba as nb
@nb.njit
def foo(x):
    def bar(y):
        return y
    return bar
foo(3)(4)
Out[3]: 4

Outer function is not jitted, but inner function is still jitted

import numba as nb
def foo(x):
    @nb.njit
    def bar(y):
        return y + x
    return bar
foo(3)(4)
Out[4]: 7

Inner function is not returned

import numba as nb
@nb.njit
def foo(x):
    
    def bar(y):
        return y + x
    
    return bar(4)
foo(3)
Out[6]: 7

Maybe one of those 3 options could still help you?

Luk

In principle those suggestions are perfectly sensible, and in terms of the code running to completion they work.

I just find it very frustrating that I can get bar to work if I write things out completely:

@jit
def baz(x, y):

    # CUT CUT CUT
    def bar(y):
        # Do something with x and y
        # potentially very long function
        # that I might want to reuse somewhere else
        return x + y
    # CUT CUT CUT

    s = 0
    for i in range(10):
        s += bar(y)
    return s

but not if I want to extract it outside… even when forcing Numba-side inlining…

@jit(inline="always")
def foo(x):

    def bar(y):
        # Do something with x and y
        # potentially very long function
        # that I might want to reuse somewhere else
        return x + y

    return bar

@jit
def baz(x, y):
    # CUT CUT CUT
    bar = foo(x)
    # CUT CUT CUT

    s = 0
    for i in range(10):
        s += bar(y)
    return s

With anything other than scalar types, LLVM cannot assume x is going to be constant throughout the loop and therefore won't generate vectorised code (well, maybe in reduction code like this). In the first instance, I don't believe bar will specialise over x anyway (it will just be inlined). The second case would force specialisation, or at least might be enough for LLVM to understand that x is constant…

The issue with “returning” closures that capture an argument of the outer function is that it essentially requires value dependent specialisation, whereas Numba only specialises on type.

In your two examples in Does numba support returning closures from a @jitted function? - #3 by gael, the first is a non-escaping closure, so Numba will just inline the Numba IR for the closure at the call site. In the second, and the example in the OP, the returned closure has to capture the runtime value of an argument to the outer function. When Numba compiles it does so based on types, not values, so it cannot differentiate between calling the OP foo(1) and foo(2) and so it cannot return bar as it has a value based dependency.
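To see the distinction in plain Python terms (no Numba needed): foo(1) and foo(2) have identical argument *types*, yet the closures they return behave differently because each captures a different runtime *value*, which is exactly what type-based compilation cannot distinguish:

```python
def foo(x):
    def bar(y):
        return x + y
    return bar

bar1, bar2 = foo(1), foo(2)
# Same types everywhere: both arguments are ints, both results are functions...
assert type(bar1) is type(bar2)
# ...but different behaviour, determined by the captured value of x:
print(bar1(10), bar2(10))  # 11 12
```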

That makes sense. Value-based specialisation was exactly what I was after. :sweat_smile:

In one application I do value specialization but I handle it outside Numba. Combining

def foo(x):
    @nb.njit
    def bar(y):
        return y + x
    return bar

A python dictionary all_bars[x] = foo(x) gets you value specialization, even if it’s partially outside jitted code.
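As a minimal sketch of that caching pattern in plain Python (the @nb.njit on bar is omitted here so the example runs without Numba installed; with Numba, each dictionary entry would hold a separately compiled specialisation):

```python
def foo(x):
    def bar(y):          # with Numba: decorate with @nb.njit
        return y + x
    return bar

all_bars = {}

def get_bar(x):
    # One specialisation per distinct value of x, created on first use.
    if x not in all_bars:
        all_bars[x] = foo(x)
    return all_bars[x]

print(get_bar(3)(4))  # 7, first call creates the specialisation
print(get_bar(3)(5))  # 8, cached closure reused
```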

@stuartarchibald, in theory generated_jit plus literal types could produce a form of value specialization, right? At least for integers and strings.
@gael, the danger of value specialization is that you have to make sure the number of distinct values stays low, or compilation time will explode.

Luk

A python dictionary all_bars[x] = foo(x) gets you value specialization, even if it’s partially outside jitted code

This is what I have been doing so far as a workaround. Not optimal, as it requires either passing the arguments that matter™ to an init function in Pythonland to generate the right specialisations, or hardcoding them.

To give you an idea of why hardcoding is not appealing to me:

I’m writing a small library to perform a specific analysis on density maps. Those maps can be of any dimension (1D, 2D, 3D, 100D, etc.), however for any specific application people will use a single dimension (or very few of them). Like, I use mainly 3D and, for a very weird reason, 11D. So the number of specialisations is not expected to explode, but having to hardcode them prior to using them is a major thorn in my side.

All is good when I use the density map itself: ndim is enough and can be passed around as a Literal. However some functions, used in other contexts, expect an array of shape (N, ndim) or (N, ndim, ndim), and so far I have resorted to passing around ndim as a Literal, or the original density map itself (but just for the type information, which looks a bit nuts).


in theory generated_jit plus literal types could produce a form of value specialization, right?

If you use literally, yes, because any unspecialised implementation would raise an error, forcing the compiler to become more specific.

otherwise compilation time would explode.

I am very aware of that. :+1: (see justification above)


Now I was just thinking, is there a way to force specific unboxing/boxing processes when jitting besides wrapping my native python types into specialised ones?

I saw the Array subclassing tutorial in numba-examples. Would be neat to be able to tell jit to unbox an input array into MyArrayType and to box it into a normal array once it gets back to Pythonland (or continue using MyArrayType as long as it stays in Numbaland).

I believe this would provide just enough flexibility without making it worse for everyone by default. (talking about arrays but could be applied to anything, now that it’s documented)

Assuming np.ndarrays, ndim is part of the type; can you just use that to dispatch, or does the implementation to dispatch to also depend on the runtime value of N?

This sort of thing might help?

from numba import njit
import numpy as np
from numba.extending import overload

def ndim_dispatcher(arr):
    pass

@overload(ndim_dispatcher)
def ol_ndim_dispatcher(arr):
    if arr.ndim == 3:
        def impl(arr):
            return "This is 3", arr.ndim
    elif arr.ndim == 2:
        def impl(arr):
            return "This is 2", arr.ndim
    elif arr.ndim == 1:
        def impl(arr):
            return "This is 1", arr.ndim
    else:
        def impl(arr):
            return "This is unspecialised", arr.ndim
    return impl

@njit
def foo(x):
    print(ndim_dispatcher(x))

for i in range(6, 0, -1):
    arr = np.ones((1,) * i)
    foo(arr)

Yes indeed! This rather contrived example demonstrates it (with @overload rather than @generated_jit, but the idea is essentially the same): https://github.com/numba/numba/blob/58426cf7658177b792b820ff66c0945985b72da0/numba/tests/test_mixed_tuple_unroller.py#L864-L911

Further, Numba can specialise on the “initial” value of some containers, for example string keyed heterogeneous dicts (there’s more tests around this one for further examples): https://github.com/numba/numba/blob/58426cf7658177b792b820ff66c0945985b72da0/numba/tests/test_dictobject.py#L1933-L1956
and a really extreme example mixing a load of literal value concepts here:
https://github.com/numba/numba/blob/58426cf7658177b792b820ff66c0945985b72da0/numba/tests/test_lists.py#L1560-L1618

It should be possible (though potentially complicated) to enhance the compiler pass that deals with escaping closures to handle cases like:

from numba import njit, types, literally
from numba.extending import overload

def dispatcher(x):
    pass

@overload(dispatcher)
def ol_dispatcher(x):
    if not isinstance(x, types.IntegerLiteral):
        def impl(x):
            literally(x)
    else:
        literal_x = x.literal_value
        def impl(x):
            def bar(y):
                return y + literal_x
            return bar
    return impl


@njit
def foo():
    one_impl = dispatcher(1)
    two_impl = dispatcher(2)

    return one_impl, two_impl

one, two = foo()
print(one(10))
print(two(10))

The issue at present is that when converting the code object that is the closure into a python function that Numba can @njit, the freevar literal_x cannot be found because there’s no code path to scan the freevars of the outer function. If this is something that would be useful to anyone, please do open a feature request at: Issues · numba/numba · GitHub. Thanks.