Nested numba functions slow down execution and cannot be inlined

We are using Numba to speed up a rather complex code base. While this works well in general, we noticed the following undesirable behavior:

When calling a nested Numba function, the code executes much slower (up to 3x) than code with the same functionality where we manually "inlined" the call, i.e. copied the body of the function into the outer function. Setting njit(inline='always') does recover some performance, but is still not as fast as the code without the function call.

Unfortunately, we have not been able to reproduce the behavior in a toy example. But the idea is as follows:

@njit(nogil=True, inline='always')
def inner_function(a, b, c):
    <inner function code>

@njit
def outer_function():
    for _ in range(100):
        inner_function(a, b, c)

@njit
def monolith_function():
    for _ in range(100):
        <inner function code>

The monolith_function will run much faster than outer_function.

Has anyone observed similar behavior? What could be factors in the larger code base that might affect performance in such a way?

We have tried Numba versions 0.51, 0.55, and 0.56.4, as well as Python 3.9 and 3.10.


Hi @ckk,

I wonder if this is related to issues with when optimisation occurs and what it targets, as described in this discussion: Compilation pipeline, compile time and vectorization. Also, the inline="always" functionality will likely lead to the use of additional variables etc. to "wire up" the inlined call site; it may be that there's something in there which prevents an optimisation from occurring (e.g. the complexity gets too high).
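
If you want to see what that wiring looks like, one option is to dump the Numba IR after the inlining pass via NUMBA_DEBUG_PRINT_AFTER. This is just a sketch; the pass name "inline_inlinables" is an assumption and may differ between Numba versions, in which case "all" can be used instead.

import os
# Must be set before Numba is imported; use "all" if the pass name differs.
os.environ["NUMBA_DEBUG_PRINT_AFTER"] = "inline_inlinables"

from numba import njit

@njit(inline="always")
def inner(a, b, c):
    return a + b + c

@njit
def outer(a, b, c):
    return inner(a, b, c)

outer(1, 2, 3)  # the dump shows the renamed locals used to wire up the call site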

Hope this helps.


@stuartarchibald Thanks, the discussion does indeed sound related. I am not sure if our case is caused by the vectorisation, but the order of optimisation passes could be the culprit.

I did find another related post on StackOverflow which also has a toy example outlining the issue.
On my machine (Apple M1 Pro) the nested function takes 979.00 us vs. 87.00 us for the non-nested one:

import time

import numba as nb


@nb.njit(cache=False, no_cpython_wrapper=True)
def fct_4(a, b):
    x = a ^ b
    setBits = 0
    while x > 0:
        setBits += x & 1
        x >>= 1
    return setBits


@nb.njit(cache=False, no_cpython_wrapper=True)
def fct_3(c, set_1, set_2):
    h = 2
    if c not in set_1 and c not in set_2:
        if fct_4(0, c) <= h:
            set_1.add(c)
        else:
            set_2.add(c)


@nb.njit(cache=False, no_cpython_wrapper=True)
def fct_2(c, set_1, set_2):
    fct_3(c, set_1, set_2)


@nb.njit(cache=False)
def fct_1_nested(set_1, set_2):
    for x1 in range(100000):
        c = 2
        fct_2(c, set_1, set_2)


@nb.njit(cache=False)
def fct_1(set_1, set_2):
    for x1 in range(100000):
        c = 2
        h = 2
        if c not in set_1 and c not in set_2:
            if fct_4(0, c) <= h:
                set_1.add(c)
            else:
                set_2.add(c)


if __name__ == "__main__":
    s1 = set(range(10))
    s2 = set(range(1000))
    fct_1(s1, s2)
    fct_1_nested(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1_nested(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")

    fct_1_nested(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")```

I just ran that example on the branch with only global optimization mentioned here, but it seems to have only a minor impact:

%timeit fct_1_nested(s1, s2)
%timeit fct_1(s1, s2)

# Default optimizations in numba
2.52 ms ± 11.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
116 µs ± 376 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

# merged-compile branch
2.32 ms ± 18.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
120 µs ± 925 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

The merged compilation is a little bit faster, but it doesn’t look like it is the main reason for the performance difference.

Also, the inline="always" functionality will likely lead to the use of additional variables etc. to "wire up" the inlined call site; it may be that there's something in there which prevents an optimisation from occurring (e.g. the complexity gets too high).

@stuartarchibald seems to be on to something.
I have extended the example with a function fct_1_redundant which is more or less equivalent to what the inlining pass produces.

import time

import numba as nb
import warnings

warnings.filterwarnings("ignore")


@nb.njit(cache=False, no_cpython_wrapper=True)
def fct_4(a, b):
    x = a ^ b
    setBits = 0
    while x > 0:
        setBits += x & 1
        x >>= 1
    return setBits


@nb.njit(cache=False, no_cpython_wrapper=True)
def fct_3(c, set_1, set_2):
    h = 2
    if c not in set_1 and c not in set_2:
        if fct_4(0, c) <= h:
            set_1.add(c)
        else:
            set_2.add(c)


@nb.njit(cache=False, no_cpython_wrapper=True)
def fct_2(c, set_1, set_2):
    fct_3(c, set_1, set_2)


@nb.njit(cache=False)
def fct_1_nested(set_1, set_2):
    for x1 in range(100000):
        c = 2
        fct_2(c, set_1, set_2)


@nb.njit(cache=False)
def fct_1(set_1, set_2):
    for x1 in range(100000):
        c = 2
        h = 2
        if c not in set_1 and c not in set_2:
            if fct_4(0, c) <= h:
                set_1.add(c)
            else:
                set_2.add(c)
                
@nb.njit(cache=False)
def fct_1_redundant(set_1, set_2):
    for x1 in range(100000):
        c = 2
        
        fct_2_c = c
        fct_2_set_1 = set_1
        fct_2_set_2 = set_2
        
        fct_3_c = fct_2_c
        fct_3_set_1 = fct_2_set_1
        fct_3_set_2 = fct_2_set_2      
        
        fct_3_h = 2
        if fct_3_c not in fct_3_set_1 and fct_3_c not in fct_3_set_2:
            if fct_4(0, fct_3_c) <= fct_3_h:
                fct_3_set_1.add(fct_3_c)
            else:
                fct_3_set_2.add(fct_3_c)                     


if __name__ == "__main__":
    s1 = set(range(10))
    s2 = set(range(1000))
    fct_1(s1, s2)
    fct_1_nested(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")

    fct_1_nested(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1_nested(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")
        
    fct_1_redundant(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1_redundant(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")

Interestingly, the new function is just as slow as the nested version!
Consequently, the additional variable assignments seem to be the main reason for the large performance difference.

Unfortunately, this seems to be non-trivial to resolve.
We could implement a compiler pass that does not generate additional assignments during inlining.
However, this would be problematic if the value of an argument of the inlined function is changed in the function body.
Detecting whether the value of a variable is changed, and then deciding whether an additional assignment is necessary, also doesn't seem straightforward to me.
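
To make that last point a bit more concrete, here is a rough sketch of the kind of check such a pass would need: does the function ever rebind one of its parameters? This operates on the Python AST purely for illustration (the real analysis would have to work on Numba IR with SSA and aliasing information), and the helper and function names below are made up:

import ast
import inspect


def rebound_parameters(func):
    """Return the parameter names that are re-assigned inside the body."""
    fn_def = ast.parse(inspect.getsource(func)).body[0]
    params = {a.arg for a in fn_def.args.args}
    rebound = set()
    for node in ast.walk(fn_def):
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, (ast.AugAssign, ast.AnnAssign)):
            targets = [node.target]
        else:
            continue
        for target in targets:
            if isinstance(target, ast.Name) and target.id in params:
                rebound.add(target.id)
    return rebound


def popcount_xor(a, b):  # same shape as fct_4 above, undecorated
    x = a ^ b
    set_bits = 0
    while x > 0:
        set_bits += x & 1
        x >>= 1
    return set_bits


print(rebound_parameters(popcount_xor))  # -> set(): a and b are never rebound,
                                         # so forwarding them without an alias would be safe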


Further investigation revealed that only the additional variables for the two sets result in increased runtime.
No runtime increase can be observed if we only create the additional variables for the integer c.

In the extended example below, the function fct_1_redundant_only_one_set only has additional variables for one of the two sets. This results in a little more than half the runtime of the fully redundant version.
The function fct_1_redundant_no_sets only has additional variables for the integer c.
It has the same runtime as the original function despite the additional variable assignments.

So this issue is likely only problematic for more complex objects passed into the inlined function (pointers?).

import time

import numba as nb
import warnings

warnings.filterwarnings("ignore")


@nb.njit(cache=False, no_cpython_wrapper=True)
def fct_4(a, b):
    x = a ^ b
    setBits = 0
    while x > 0:
        setBits += x & 1
        x >>= 1
    return setBits


@nb.njit(cache=False, no_cpython_wrapper=True)
def fct_3(c, set_1, set_2):
    h = 2
    if c not in set_1 and c not in set_2:
        if fct_4(0, c) <= h:
            set_1.add(c)
        else:
            set_2.add(c)


@nb.njit(cache=False, no_cpython_wrapper=True)
def fct_2(c, set_1, set_2):
    fct_3(c, set_1, set_2)


@nb.njit(cache=False)
def fct_1_nested(set_1, set_2):
    for x1 in range(100000):
        c = 2
        fct_2(c, set_1, set_2)


@nb.njit(cache=False)
def fct_1(set_1, set_2):
    for x1 in range(100000):
        c = 2
        h = 2
        if c not in set_1 and c not in set_2:
            if fct_4(0, c) <= h:
                set_1.add(c)
            else:
                set_2.add(c)


@nb.njit(cache=False)
def fct_1_redundant(set_1, set_2):
    for x1 in range(100000):
        c = 2

        fct_2_c = c
        fct_2_set_1 = set_1
        fct_2_set_2 = set_2

        fct_3_c = fct_2_c
        fct_3_set_1 = fct_2_set_1
        fct_3_set_2 = fct_2_set_2

        fct_3_h = 2
        if fct_3_c not in fct_3_set_1 and fct_3_c not in fct_3_set_2:
            if fct_4(0, fct_3_c) <= fct_3_h:
                fct_3_set_1.add(fct_3_c)
            else:
                fct_3_set_2.add(fct_3_c)


@nb.njit(cache=False)
def fct_1_redundant_only_one_set(set_1, set_2):
    for x1 in range(100000):
        c = 2

        fct_2_c = c
        fct_2_set_1 = set_1

        fct_3_c = fct_2_c
        fct_3_set_1 = fct_2_set_1

        fct_3_h = 2
        if fct_3_c not in fct_3_set_1 and fct_3_c not in set_2:
            if fct_4(0, fct_3_c) <= fct_3_h:
                fct_3_set_1.add(fct_3_c)
            else:
                set_2.add(fct_3_c)


@nb.njit(cache=False)
def fct_1_redundant_no_sets(set_1, set_2):
    for x1 in range(100000):
        c = 2

        fct_2_c = c

        fct_3_c = fct_2_c

        fct_3_h = 2
        if fct_3_c not in set_1 and fct_3_c not in set_2:
            if fct_4(0, fct_3_c) <= fct_3_h:
                set_1.add(fct_3_c)
            else:
                set_2.add(fct_3_c)


if __name__ == "__main__":
    s1 = set(range(10))
    s2 = set(range(1000))
    fct_1(s1, s2)
    fct_1_nested(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")

    fct_1_nested(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1_nested(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")

    fct_1_redundant(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1_redundant(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")

    fct_1_redundant_only_one_set(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1_redundant_only_one_set(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")

    fct_1_redundant_no_sets(s1, s2)
    for _ in range(6):
        start = time.process_time_ns()
        fct_1_redundant_no_sets(s1, s2)
        print(f"{(time.process_time_ns() - start) / 1000:.2f} us")


Interestingly, the LLVM assembly code (after optimization) is not all that different between fct_1 and fct_1_redundant.

Only two blocks are significantly different.
For fct_1:

B77:                                              ; preds = %lookup.body.if, %lookup.body.1.if, %for.body.if, %for.body.if.1, %for.body.if.2, %for.body.1.if, %for.body.1.if.1, %for.body.1.if.2, %lookup.end.2.endif
  %.95 = icmp ugt i64 %.40.0128, 1
  br i1 %.95, label %B10.if, label %B76

B10.if:                                           ; preds = %entry, %B77
  %.40.0128 = phi i64 [ 1000, %entry ], [ %.104, %B77 ]
  %.104 = add nsw i64 %.40.0128, -1
  %sunkaddr181 = getelementptr i8, i8* %arg.set_1.0, i64 24
  %1 = bitcast i8* %sunkaddr181 to { i64, i64, i64, i64, i8, { i64, i64 } }**
  %.6.i94 = load { i64, i64, i64, i64, i8, { i64, i64 } }*, { i64, i64, i64, i64, i8, { i64, i64 } }** %1, align 8
  %.149 = getelementptr inbounds { i64, i64, i64, i64, i8, { i64, i64 } }, { i64, i64, i64, i64, i8, { i64, i64 } }* %.6.i94, i64 0, i32 5
  %.176 = getelementptr inbounds { i64, i64, i64, i64, i8, { i64, i64 } }, { i64, i64, i64, i64, i8, { i64, i64 } }* %.6.i94, i64 0, i32 2
  %.177 = load i64, i64* %.176, align 8
  %.181 = and i64 %.177, 2
  %.190 = getelementptr { i64, i64 }, { i64, i64 }* %.149, i64 %.181, i32 0
  %.191 = load i64, i64* %.190, align 8
  switch i64 %.191, label %for.body.endif.endif [
    i64 2, label %for.body.if
    i64 -1, label %B30
  ]

vs. for fct_1_redundant:

B101:                                             ; preds = %lookup.body.if, %lookup.body.1.if, %for.body.if, %for.body.if.1, %for.body.if.2, %for.body.1.if, %for.body.1.if.1, %for.body.1.if.2, %lookup.end.2.endif
  %fct_3_set_1.sroa.0.1 = phi i8* [ null, %lookup.end.2.endif ], [ %arg.set_1.0, %for.body.1.if.2 ], [ %arg.set_1.0, %for.body.1.if.1 ], [ %arg.set_1.0, %for.body.1.if ], [ %arg.set_1.0, %for.body.if.2 ], [ %arg.set_1.0, %for.body.if.1 ], [ %arg.set_1.0, %for.body.if ], [ %arg.set_1.0, %lookup.body.1.if ], [ %arg.set_1.0, %lookup.body.if ]
  %fct_3_set_2.sroa.0.1 = phi i8* [ null, %lookup.end.2.endif ], [ %arg.set_2.0, %for.body.1.if.2 ], [ %arg.set_2.0, %for.body.1.if.1 ], [ %arg.set_2.0, %for.body.1.if ], [ %arg.set_2.0, %for.body.if.2 ], [ %arg.set_2.0, %for.body.if.1 ], [ %arg.set_2.0, %for.body.if ], [ %arg.set_2.0, %lookup.body.1.if ], [ %arg.set_2.0, %lookup.body.if ]
  tail call void @NRT_decref(i8* %fct_3_set_2.sroa.0.1)
  tail call void @NRT_decref(i8* %fct_3_set_1.sroa.0.1)
  %.95 = icmp ugt i64 %.40.0128, 1
  br i1 %.95, label %B10.if, label %B100

B10.if:                                           ; preds = %entry, %B101
  %.40.0128 = phi i64 [ 1000, %entry ], [ %.104, %B101 ]
  %.104 = add nsw i64 %.40.0128, -1
  tail call void @NRT_incref(i8* %arg.set_1.0)
  tail call void @NRT_incref(i8* %arg.set_2.0)
  %sunkaddr181 = getelementptr i8, i8* %arg.set_1.0, i64 24
  %1 = bitcast i8* %sunkaddr181 to { i64, i64, i64, i64, i8, { i64, i64 } }**
  %.6.i94 = load { i64, i64, i64, i64, i8, { i64, i64 } }*, { i64, i64, i64, i64, i8, { i64, i64 } }** %1, align 8
  %.165 = getelementptr inbounds { i64, i64, i64, i64, i8, { i64, i64 } }, { i64, i64, i64, i64, i8, { i64, i64 } }* %.6.i94, i64 0, i32 5
  %.192 = getelementptr inbounds { i64, i64, i64, i64, i8, { i64, i64 } }, { i64, i64, i64, i64, i8, { i64, i64 } }* %.6.i94, i64 0, i32 2
  %.193 = load i64, i64* %.192, align 8
  %.197 = and i64 %.193, 2
  %.206 = getelementptr { i64, i64 }, { i64, i64 }* %.165, i64 %.197, i32 0
  %.207 = load i64, i64* %.206, align 8
  switch i64 %.207, label %for.body.endif.endif [
    i64 2, label %for.body.if
    i64 -1, label %B54
  ]

Compared with fct_1, the LLVM assembly code for fct_1_redundant contains two additional NRT_incref and NRT_decref calls.
Furthermore, it contains the phi expressions for %fct_3_set_1.sroa.0.1 and %fct_3_set_2.sroa.0.1.
Interestingly, neither %fct_3_set_1.sroa.0.1 nor %fct_3_set_2.sroa.0.1 is used anywhere in the LLVM assembly code except for the NRT_decref directly afterwards, and there is no matching NRT_incref.
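
For anyone who wants to reproduce this kind of comparison, a small sketch (assuming the functions from the example above have already been compiled by calling them once): Dispatcher.inspect_llvm() returns the LLVM IR Numba keeps per signature, so a rough count of the NRT_incref/NRT_decref occurrences shows the extra refcount traffic in the redundant variant.

# Rough sketch: count NRT refcount calls in the LLVM IR per compiled signature.
# The counts also include the declarations, so treat them as a relative measure.
for fn in (fct_1, fct_1_redundant):
    for sig, llvm_ir in fn.inspect_llvm().items():
        print(fn.py_func.__name__, sig,
              "NRT_incref:", llvm_ir.count("NRT_incref"),
              "NRT_decref:", llvm_ir.count("NRT_decref"))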

It looks like there might be something weird going on here.
