Enhancing jit signature inference using Python type annotations

While Numba’s jit decorator provides a way to define function signatures explicitly, I’ve been wondering about the possibility of inferring a function signature from Python type annotations.
NumPy itself does not offer a way to express array dimensionality in type annotations, so Numba cannot infer the correct number of dimensions from them.
I’ve noticed that third-party packages like nptyping already allow NumPy array dimensions to be expressed in Python type annotations, as demonstrated in the following example:

import numpy as np
from typing import Any
from nptyping import NDArray

# nptyping aliases for 1-D and 2-D int32 arrays of arbitrary length
ArrInt32_1D = NDArray[(Any,), np.int32]
ArrInt32_2D = NDArray[(Any, Any), np.int32]

def foo(arr: ArrInt32_2D) -> ArrInt32_1D:
    return arr.flatten()

This approach keeps the type annotations concise. My question is whether Numba could potentially infer function signatures from such annotations if numpy.typing provided a way to express array dimensions.
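For comparison, this is roughly what has to be written today with an explicit Numba signature; a minimal sketch in which the array layouts are my own assumption and the dimensionality is spelled out by hand:

import numpy as np
import numba as nb

# Explicit, eagerly compiled signature: a 2-D int32 array in (any layout),
# a C-contiguous 1-D int32 array out.
@nb.njit('int32[::1](int32[:, :])')
def foo(arr):
    return arr.flatten()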
I’ve observed that jitclass in Numba is already partially capable of such inference, at least for simple types.


Would this cause the jitted function’s semantics to diverge from the plain Python behavior?

Hey @nelson2005 ,

let’s start with two simple examples that show the potentially beautiful effects (if intended) or horrible effects (if unintended) and continue from there.
There could be functions with several default or omitted arguments.
Just 3 optional parameters already lead to 27 function variations: each parameter can arrive as an int64, as an omitted default, or as an explicit None, so there are 3^3 = 27 signatures.

import numba as nb

@nb.njit([
    'int64(int64, int64, int64)',
    'int64(int64, int64, Omitted(None))',
    'int64(int64, int64, none)',
    'int64(int64, Omitted(None), int64)',
    'int64(int64, Omitted(None), Omitted(None))',
    'int64(int64, Omitted(None), none)',
    'int64(int64, none, int64)',
    'int64(int64, none, Omitted(None))',
    'int64(int64, none, none)',
    'int64(Omitted(None), int64, int64)',
    'int64(Omitted(None), int64, Omitted(None))',
    'int64(Omitted(None), int64, none)',
    'int64(Omitted(None), Omitted(None), int64)',
    'int64(Omitted(None), Omitted(None), Omitted(None))',
    'int64(Omitted(None), Omitted(None), none)',
    'int64(Omitted(None), none, int64)',
    'int64(Omitted(None), none, Omitted(None))',
    'int64(Omitted(None), none, none)',
    'int64(none, int64, int64)',
    'int64(none, int64, Omitted(None))',
    'int64(none, int64, none)',
    'int64(none, Omitted(None), int64)',
    'int64(none, Omitted(None), Omitted(None))',
    'int64(none, Omitted(None), none)',
    'int64(none, none, int64)',
    'int64(none, none, Omitted(None))',
    'int64(none, none, none)'])
def foo(a: int = None, b: int = None, c: int = None) -> int:
    a = a or 0
    b = b or 0
    c = c or 0
    return a+b+c

Or there could be functions with more complex data types or type containers.
Here the abstract parent type np.number covers 10 or even more concrete base types. Again this leads to multiple function variations.

import numpy as np

@nb.njit([
    'uint8(uint8)',
    'uint16(uint16)',
    'uint32(uint32)',
    'uint64(uint64)',
    'int8(int8)',
    'int16(int16)',
    'int32(int32)',
    'int64(int64)',
    'float32(float32)',
    'float64(float64)'])
def bar(a: np.number) -> np.number:
    return a+1

These Python functions look simple and innocent, but they generate a huge number of compiled function variations.
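As a side note, the 27 signature strings above could also be generated programmatically instead of being written out by hand; a minimal sketch, assuming Numba accepts the same string spellings used in the example:

from itertools import product

# Each of the three argument slots can be int64, Omitted(None) or none,
# giving 3^3 = 27 signature strings that can be passed as a list to @nb.njit.
variants = ['int64', 'Omitted(None)', 'none']
signatures = [f'int64({a}, {b}, {c})' for a, b, c in product(variants, repeat=3)]
assert len(signatures) == 27

That removes some of the boilerplate, but not the underlying combinatorial explosion, which is the point of the example.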

Pros:
Defining explicit function signatures for all the relevant input combinations can lead to a significant number of signature declarations. Inferring types from Python annotations could potentially simplify such code by reducing the need for explicit signatures, making it easier to maintain.
The use case would probably be ahead-of-time compilation, or cached code used in a package or library.

Cons:
However, as you pointed out, there are scenarios where inferring types might not be straightforward. For example, with type containers or multiple default parameters, the code Numba generates could become complex and might lead to code instability and long compilation times. It’s essential to consider how automatic type inference would handle such cases.
Another challenge is how (or whether) Numba would handle variable-length argument lists (*args) and keyword argument dictionaries (**kwargs) when inferring types.

It’s worth noting that Numba’s jitclass already supports inferring fields from Python type annotations:
“Fields of a jitclass can also be inferred from Python type annotations.”
The quote is from the documentation chapter Compiling Python classes with @jitclass.
This suggests that Numba can already perform type inference from annotations, at least for basic types in certain contexts.
Extending this functionality to function arguments could be an option, depending on the use case.
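A minimal sketch of that jitclass behavior, with the field types taken from the annotations and no explicit spec passed (the class itself is just an illustration):

from numba.experimental import jitclass

@jitclass
class Counter:
    count: int      # inferred as a Numba integer type, no spec list needed
    total: float    # inferred as a Numba float type

    def __init__(self):
        self.count = 0
        self.total = 0.0

    def add(self, value):
        self.count += 1
        self.total += value
        return self.total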

I thought about that before I posted. To be honest, it surprised me when I realized it back in the day, since type annotations are no-ops in the standard Python interpreter. I rationalized it in my mind as: a jitclass is Numba-only, it cannot be plain Python or used without Numba. :slight_smile:

The more Numba semantics deviate from Python semantics, the greater the risk that Numba becomes a separate language from Python. So far the list is fairly short.
Just something to think about.

In a world of high-performance computing and data science, Numba has been a valuable tool for accelerating Python code. However, it’s worth revisiting the assumption made in Numba Enhancement Proposal 5 (NBEP 5) regarding type inference. Back in 2016, it was valid to assume that Python was dynamically typed and relied on type inference due to the absence of user-declared variable types. But times have changed, and so has the Python ecosystem.
I use Python type annotations all the time (mainly to satisfy the code linters), and on top of that I use explicit signatures whenever I make use of Numba. This is somewhat redundant, but I want my Numba code to be compiled and cached.
The emergence of new programming languages like Mojo, which promise to maintain Python’s syntax while delivering substantial speed-ups (Python++), is a testament to the growing need for performance in Python. Furthermore, the availability of AI tools that can automatically add type annotations to existing or new Python code opens up new possibilities for improving performance effortlessly.
Many Large Language Models are well trained for Python coding. I am not sure how well they perform with Numba-specific code; they are probably less effective with such niche tools. That means there are obvious advantages, in terms of ease of adoption, to staying close to plain Python syntax.
Python has embraced type annotations, introduced in PEP 484 and expanded upon in subsequent PEPs, as a way to provide type hints. These type hints can serve as a source of information about the types of variables, function parameters, and return values. If we consider Python type annotations as another method to infer types in Numba (perhaps as an option in jit functions), we can potentially unlock even more performance gains.
Even though the Numba project has limited resources, such an option could lead to significant performance improvements without requiring a fundamental change in coding practices. As we strive to write high-performance Python code, maybe it’s worth exploring this opportunity to bridge the gap between dynamic typing and performance, especially in a world where AI tools can assist in the process.

@nelson2005 the list of Deviations from Python Semantics is rather short, which should be appreciated from my point of view.

# Why not use something like:
@nb.njit(use_py_annotation=True)  # hypothetical option
def bar(a: np.integer) -> np.integer:
    return a+1

# to achieve the same as this:
@nb.njit([
    'uint8(uint8)',
    'uint16(uint16)',
    'uint32(uint32)',
    'uint64(uint64)',
    'int8(int8)',
    'int16(int16)',
    'int32(int32)',
    'int64(int64)'])
def bar(a: np.integer) -> np.integer:
    return a+1