Unexpected behaviour with literally

Going through the demo of the new features in 0.51, I found a curious example:

from numba import jit, njit, config, __version__, errors, literal_unroll, types
from numba.extending import overload
import numba
import numpy as np
assert tuple(int(x) for x in __version__.split('.')[:2]) >= (0, 51)

def demo_iv(x):
    pass

@overload(demo_iv)
def ol_demo_iv(x):
    # if the initial_value is not present, request literal value dispatch
    if x.initial_value is None:
        return lambda x: literally(x)
    else: # initial_value is present on the type
        print("type of x: {}. Initial value {}".format(x, x.initial_value))
        return lambda x: 1

@njit
def initial_value_capturing():
    l = [1, 2, 3, 4] # initial value [1, 2, 3, 4]
    l.append(5) # not part of the initial value
    print("out=", demo_iv(l))
    
initial_value_capturing() # out=1

versus (now with literally removed):

def demo_iv(x):
    pass

@overload(demo_iv)
def ol_demo_iv(x):
    # if the initial_value is not present, request literal value dispatch
    if x.initial_value is None:
        return lambda x: 2  # literally(x)
    else: # initial_value is present on the type
        print("type of x: {}. Initial value {}".format(x, x.initial_value))
        return lambda x: 1

@njit
def initial_value_capturing():
    l = [1, 2, 3, 4] # initial value [1, 2, 3, 4]
    l.append(5) # not part of the initial value
    print("out=", demo_iv(l))
    
initial_value_capturing() # out=2

In the first example, the code went through the second branch (producing out=1). However, removing a function call from the first branch makes the code go through that first branch instead (producing out=2). That’s unexpected.

Also, the changes happen at different levels. The branches are part of the typing, whereas the function call is part of the implementation body, i.e. it is not executed during typing. So the typing is looking inside the body of the implementation and changing its behaviour based on that. I know that this is what literally is supposed to do, but it’s hard to grok.

To make it more confusing, literally has not been imported. But it just works?

I would have found the behaviour more intuitive if literally were placed in the caller, to clearly indicate that the argument is to be interpreted as a literal.

@njit
def initial_value_capturing():
    l = [1, 2, 3, 4] # initial value [1, 2, 3, 4]
    l.append(5) # not part of the initial value
    print("out", demo_iv(literally(l)))

I think this is how it’s used in the documentation (https://numba.pydata.org/numba-doc/latest/developer/literal.html?highlight=literally)

I don’t know whether this is how it’s supposed to work and I need to understand the internal logic, or whether it is something unintended.

Hi @luk-f-a,

First, thanks for taking a look at the demo notebook for 0.51 and providing feedback, much appreciated.

As mentioned in the CHANGE_LOG and demo notebook, in Numba 0.51 the compiler has been altered to try compilation with non-literal (less specialised) types first. The reason for this change was to try to speed up compilation, as literal values would often trigger pointless over-specialisation (example coming up). I think the example above hits some of the implications of this change, so I’ll try to explain them…

To start with, as a small example, take this code:

from numba import njit
from numba.extending import overload

def bar(x):
    pass

@overload(bar)
def ol_bar(x):
    # Prints if the overload is resolved and the type it resolved for
    print(x)
    return lambda x: x

@njit
def foo():
    bar(1)
    bar(2)

foo()

In 0.50 it produces:

Literal[int](1)
Literal[int](2)

In 0.51 it produces:

int64

As demonstrated, in 0.51 there’s only one resolution of bar, and it is for the int64 type. This reduces specialisation compared with 0.50, where two versions of bar existed, one for the value 1 and one for the value 2! In 0.51 it’s still possible to specialise on the literal values by using literally; it’s just not the default.

To the example code in the OP: when the compiler tries to resolve the function demo_iv, the non-literal version of the list is tried first. This triggers the if x.initial_value is None branch, and a TypingError is then raised because literally is an undefined global. This causes the stack to unwind and the compiler to retry resolving demo_iv with literal versions of the arguments. In this case the initial_value gets captured as part of the literal specialisation of the list, and as a result the second branch is hit, hence out=1.
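The resolution order described above can be modelled in a few lines of plain Python. This is a toy model for intuition only, not Numba’s actual implementation: ToyTypingError, ToyListType and resolve are hypothetical names, and the two ol_* functions stand in for the two overloads from the OP.

```python
class ToyTypingError(Exception):
    """Stand-in for numba.core.errors.TypingError in this toy model."""


class ToyListType:
    """Hypothetical stand-in for a list type that may carry an initial value."""

    def __init__(self, initial_value=None):
        self.initial_value = initial_value


def resolve(template, literal_type):
    """Toy 0.51 resolution order: try the non-literal type first and,
    if the template fails to type, retry with the literal type."""
    non_literal_type = ToyListType(initial_value=None)  # the "unliteral" version
    try:
        return template(non_literal_type)
    except ToyTypingError:
        return template(literal_type)


def ol_with_undefined_literally(x):
    if x.initial_value is None:
        # Models `lambda x: literally(x)` with `literally` never imported:
        # the implementation cannot be typed, so typing fails.
        raise ToyTypingError("unknown global: literally")
    return 1


def ol_without_literally(x):
    if x.initial_value is None:
        return 2  # Models `lambda x: 2`, which types fine on the first attempt
    return 1


literal_list = ToyListType(initial_value=[1, 2, 3, 4])
print("out=", resolve(ol_with_undefined_literally, literal_list))  # out= 1
print("out=", resolve(ol_without_literally, literal_list))         # out= 2
```

The only difference between the two runs is whether the first attempt fails, which is exactly the difference between the two snippets in the OP.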

> Also, the changes happen at different levels. The branches are part of the typing, whereas the function call is part of the implementation body, i.e. it is not executed during typing. So the typing is looking inside the body of the implementation and changing its behaviour based on that. I know that this is what literally is supposed to do, but it’s hard to grok.

The branches in the overload are part of the typing logic. However, Numba has to be able to complete type inference on the implementation in order to determine the return type of the function for use in the caller (the print("out="... in this case). Therefore the implementation is “executed” to the extent that it must be typable (i.e. survive type inference). As noted previously, the fact that the first branch contains an untypable function results in a TypingError, which unwinds the stack and triggers a retry with literal types. Equivalent behaviour could also be triggered by e.g. return lambda x: object() or just raising a numba.core.errors.TypingError in the block.

> To make it more confusing, literally has not been imported. But it just works?

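Hopefully the above has explained that: literally only appears to work because the TypingError from the undefined name triggers the retry with literal types. If the notebook doesn’t contain the import, it’s an invalid example and needs patching.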

> I would have found the behaviour more intuitive if literally were placed in the caller, to clearly indicate that the argument is to be interpreted as a literal.

This is a good point. The counter-argument is that if you were, e.g., a library author trying to replicate an already existing API with numba.jit compiling the implementation, that an argument needs to be a literal value is a detail of the implementation that should not be the concern of the user (who could reasonably expect the jitted API to match the non-jitted one and not want or need to wrap arguments in literally).

> I don’t know whether this is how it’s supposed to work and I need to understand the internal logic, or whether it is something unintended.

These are advanced and complicated features that are right at the forefront of what can be done with Numba’s current type inference mechanism. They also highlight how the current type inference mechanism has been really stretched and is close to its limit. Writing features such as these has helped solidify ideas about what would need to be accommodated in a more advanced type inference mechanism.

Hope this helps?

Hi Stuart, thanks a lot for the explanation! It’s all much clearer now.

Regarding

"library author trying to replicate an already existing API with numba.jit compiling the implementation, that an argument needs to be a literal value is a detail of the implementation that should not be the concern of the user",

I fully agree that it’s more important for the receiving function to decide. I didn’t realize that when I wrote the above.

Maybe in future versions of the type inference, the “literally signal” could be done via some kind of partial typing in the signature:

@njit("x:Literal")
def foo(x, y):
    ....

This syntax could apply to all subtyping relations in general. But that’s a topic for another day.

Luk