Performance discrepancy between jit/generated_jit/overload

Hi! I am busy trying to compose a relatively complicated piece of numba code using generated_jit/overload. However, whilst moving code over to generated_jit I noticed that performance was getting steadily worse. I then coded up the following example, which demonstrates a fairly large discrepancy between jit, generated_jit and overload code. I am not quite sure whether this qualifies as a bug and wanted to be sure that I am not doing something obviously wrong. The example is as follows:

from numba.extending import overload
from numba import jit, types, generated_jit, literally
import numpy as np
import time


def dot2x2_impl(v1, v2, md):
    """Multiply two 2x2 matrices stored as flat length-4 sequences.

    Element ``k`` of each input is entry ``(k // 2, k % 2)`` (row-major), and
    the product is returned as a 4-tuple in the same order. ``md`` is
    accepted only so this can serve as an implementation for the
    mode-dispatching variants; it is not used.
    """
    a00, a01, a10, a11 = v1[0], v1[1], v1[2], v1[3]
    b00, b01, b10, b11 = v2[0], v2[1], v2[2], v2[3]
    return (
        a00*b00 + a01*b10,
        a00*b01 + a01*b11,
        a10*b00 + a11*b10,
        a10*b01 + a11*b11,
    )


@generated_jit(nogil=True, nopython=True, cache=True, inline="always")
def gen_f(v1, v2, md):
    """Pick an implementation for ``(v1, v2, md)`` from the Numba type of ``md``.

    This body runs in Python during typing and sees Numba types (e.g.
    ``unicode_type`` or ``Literal[str](full)``), not runtime values. The
    function it returns is the one Numba compiles.

    NOTE(review): further down this thread, removing ``inline="always"`` from
    this decorator was reported to bring the timing in line with plain jit.
    """

    # md is not a compile-time literal yet: return a stub that calls
    # literally(md), asking Numba to re-type the call with md as a Literal.
    if not isinstance(md, types.Literal):
        return lambda v1, v2, md: literally(md)

    # md is a Literal: dispatch on its compile-time value. Any value other
    # than "full" falls through, and the generator returns None.
    if md.literal_value == "full":
        return dot2x2_impl


@jit(nogil=True, nopython=True, cache=True)
def ovr_f(v1, v2, md):
    """Jitted entry point that forwards to ``__v1_mul_v2``.

    Inside compiled code, ``__v1_mul_v2`` resolves to whatever the
    ``@overload(__v1_mul_v2, ...)`` function returns for these argument types.
    """

    return __v1_mul_v2(v1, v2, md)


def __v1_mul_v2(v1, v2, md):
    """Pure-Python stub; its compiled behaviour comes from ``@overload``.

    Called from plain Python it does nothing and returns None.
    """
    return


@overload(__v1_mul_v2, inline="always")
def __v1_mul_v2_impl(v1, v2, md):
    """Overload supplying the compiled implementation of ``__v1_mul_v2``.

    Runs in Python during typing with the Numba types of the arguments and
    returns the function to compile. ``inline="always"`` requests that the
    result be inlined at the call site.
    """

    # md is not a compile-time literal yet: request literal re-typing.
    if not isinstance(md, types.Literal):
        return lambda v1, v2, md: literally(md)

    # Dispatch on the literal value. Anything but "full" falls through and
    # returns None (no implementation offered for that value).
    if md.literal_value == "full":
        return dot2x2_impl


@jit(nogil=True, nopython=True, cache=True)
def jit_f(v1, v2, md):
    """Plain-jit baseline: 2x2 matrix product of two flat length-4 inputs.

    Same arithmetic, in the same order, as ``dot2x2_impl``; ``md`` is
    accepted but unused.
    """
    a00, a01, a10, a11 = v1[0], v1[1], v1[2], v1[3]
    b00, b01, b10, b11 = v2[0], v2[1], v2[2], v2[3]
    return (
        a00*b00 + a01*b10,
        a00*b01 + a01*b11,
        a10*b00 + a11*b10,
        a10*b01 + a11*b11,
    )


if __name__ == "__main__":

    n_run = 1000

    v1 = np.ones((4,))
    v2 = np.ones((4,))

    # Warm-up calls: trigger compilation before timing and show that all
    # three variants produce the same result.
    print(jit_f(v1, v2, "full"))
    print(gen_f(v1, v2, "full"))
    print(ovr_f(v1, v2, "full"))

    def bench(fn):
        """Return the mean elapsed seconds per call of fn(v1, v2, "full").

        Uses time.perf_counter(), a monotonic high-resolution clock intended
        for timing short intervals. time.time() is wall-clock time that can
        be adjusted and may have coarse resolution on some platforms.
        """
        t0 = time.perf_counter()
        for _ in range(n_run):
            fn(v1, v2, "full")
        return (time.perf_counter() - t0) / n_run

    print("--------jit--------")
    print(bench(jit_f))

    print("---generated_jit---")
    print(bench(gen_f))

    print("------overload-----")
    print(bench(ovr_f))

On my laptop, this produces:

(2.0, 2.0, 2.0, 2.0)
(2.0, 2.0, 2.0, 2.0)
(2.0, 2.0, 2.0, 2.0)
--------jit--------
2.2418498992919923e-06
---generated_jit---
0.0018690543174743651
------overload-----
0.010494077205657959

I am trying to understand the reason for this discrepancy and options to circumvent it. My instinct is that somehow the evaluation of the md.literal_value == something becomes expensive, but I have no evidence to support this claim.

Hi @JSKenyon,

welcome to the board! :slight_smile:

It feels to me as if what you are observing here are not performance problems of the compiled code, but different call overheads for different methods.

I have tried to modify your tests (thanks for the MWE!) a little:
I have defined 3 functions that call the actual implementations with the right arguments, and then njitted them such that calling the top level function from Python should have the same complexity in all cases.

For reasons I cannot figure out, this does not work for the gen_f case: the compilation always fails, so its timings are not really comparable to those of jit_f and ovr_f. I wonder whether I am just doing this incorrectly, or whether it is a bug that is triggered when combining jit, generated_jit and literally. It might be worth investigating.

Without further ado, here are my timings:

(2.0, 2.0, 2.0, 2.0)
(2.0, 2.0, 2.0, 2.0)
(2.0, 2.0, 2.0, 2.0)
--------jit--------
1.3805700291413814e-07
---generated_jit---
0.0024459507330029735
------overload-----
2.2707700554747134e-07

And here is the modified code I used (I also took the liberty of removing the jit_f code duplication).

from numba.extending import overload
from numba import jit, types, generated_jit, literally
import numpy as np
from timeit import timeit

def dot2x2_impl(v1, v2, md):
    """Multiply two 2x2 matrices stored as flat length-4 sequences.

    Element ``k`` of each input is entry ``(k // 2, k % 2)`` (row-major), and
    the product is returned as a 4-tuple in the same order. ``md`` is
    accepted only so this can serve as an implementation for the
    mode-dispatching variants; it is not used.
    """
    a00, a01, a10, a11 = v1[0], v1[1], v1[2], v1[3]
    b00, b01, b10, b11 = v2[0], v2[1], v2[2], v2[3]
    return (
        a00*b00 + a01*b10,
        a00*b01 + a01*b11,
        a10*b00 + a11*b10,
        a10*b01 + a11*b11,
    )

# Plain-jit baseline compiled directly from the shared kernel, so every
# variant benchmarked below runs exactly the same Python source.
jit_f = jit(nogil=True, nopython=True, cache=True)(dot2x2_impl)

@generated_jit(nogil=True, nopython=True, cache=True, inline="always")
def gen_f(v1, v2, md):
    """Pick an implementation for ``(v1, v2, md)`` from the Numba type of ``md``.

    This body runs in Python during typing and sees Numba types, not runtime
    values. The function it returns is the one Numba compiles.

    NOTE(review): later in this thread, removing ``inline="always"`` from this
    decorator was reported to bring the timing in line with plain jit.
    """
    # md is not a compile-time literal yet: request literal re-typing.
    if not isinstance(md, types.Literal):
        return lambda v1, v2, md: literally(md)

    # Dispatch on the literal value. Any other value returns None.
    if md.literal_value == "full":
        return dot2x2_impl


@jit(nogil=True, nopython=True, cache=True)
def ovr_f(v1, v2, md):
    """Jitted entry point that forwards to ``__v1_mul_v2``.

    Inside compiled code, ``__v1_mul_v2`` resolves to whatever the
    ``@overload(__v1_mul_v2, ...)`` function returns for these argument types.
    """

    return __v1_mul_v2(v1, v2, md)


def __v1_mul_v2(v1, v2, md):
    """Pure-Python stub; its compiled behaviour comes from ``@overload``.

    Called from plain Python it does nothing and returns None.
    """
    return


@overload(__v1_mul_v2, inline="always")
def __v1_mul_v2_impl(v1, v2, md):
    """Overload supplying the compiled implementation of ``__v1_mul_v2``.

    Runs in Python during typing with the Numba types of the arguments and
    returns the function to compile. ``inline="always"`` requests that the
    result be inlined at the call site.
    """

    # md is not a compile-time literal yet: request literal re-typing.
    if not isinstance(md, types.Literal):
        return lambda v1, v2, md: literally(md)

    # Dispatch on the literal value. Anything but "full" falls through and
    # returns None (no implementation offered for that value).
    if md.literal_value == "full":
        return dot2x2_impl


@jit(nogil=True, nopython=True, cache=True)
def test_jit_f():
    """Benchmark wrapper: call jit_f from compiled code with a constant mode.

    Each test_* wrapper is called from Python with no arguments, so the
    Python-side call overhead is the same for all variants.
    """
    # NOTE(review): v1/v2 are module globals bound in the __main__ block;
    # Numba may freeze global arrays as compile-time constants. Confirm this
    # matches the benchmarking intent.
    return jit_f(v1, v2, "full")

# Decorator deliberately commented out: compiling this wrapper with @jit
# failed when calling gen_f (reportedly related to inline="always" on gen_f).
# This wrapper therefore runs as plain Python, and its timing is not directly
# comparable to test_jit_f / test_ovr_f.
# @jit(nogil=True, nopython=True, cache=True)
def test_gen_f():
    return gen_f(v1, v2, "full")

@jit(nogil=True, nopython=True, cache=True)
def test_ovr_f():
    """Benchmark wrapper: call ovr_f from compiled code with a constant mode.

    Mirrors test_jit_f so that both are timed with the same call overhead.
    """
    # NOTE(review): v1/v2 are module globals bound in the __main__ block;
    # Numba may freeze global arrays as compile-time constants. Confirm this
    # matches the benchmarking intent.
    return ovr_f(v1, v2, "full")


if __name__ == "__main__":

    n_run = 1000

    # The test_* wrappers read v1 and v2 as module globals, so they have to
    # be bound here, at module level, before the first wrapper call.
    v1 = np.ones((4,))
    v2 = np.ones((4,))

    # The first call to each wrapper also triggers its compilation.
    for wrapper in (test_jit_f, test_gen_f, test_ovr_f):
        print(wrapper())

    # Time each wrapper and report the mean seconds per call.
    for banner, wrapper in (
        ("--------jit--------", test_jit_f),
        ("---generated_jit---", test_gen_f),
        ("------overload-----", test_ovr_f),
    ):
        print(banner)
        total = timeit(wrapper, number=n_run)
        print(f"{total/n_run}")

PS: Another point I found strange while print-debugging your examples is that the body of gen_f is executed on every call from Python. I don't know exactly how generated_jit is designed, but I found that a bit odd. I suspect this has something to do with the literal typing. Maybe someone else can explain what's going on there.

As additional info concerning my confusion about generated_jit, here is the output I get when adding a print(md) statement on the first line of gen_f (and reducing n_run)

(2.0, 2.0, 2.0, 2.0)
unicode_type
Literal[str](full)
(2.0, 2.0, 2.0, 2.0)
(2.0, 2.0, 2.0, 2.0)
--------jit--------
5.23999915458262e-07
---generated_jit---
unicode_type
unicode_type
unicode_type
unicode_type
unicode_type
unicode_type
unicode_type
unicode_type
unicode_type
unicode_type
0.003825946699362248
------overload-----
4.282002919353545e-07

For some reason the body of gen_f is entered on every call, despite the fact that the specialisation should already exist. (If I do the same for the generated_jit example in the docs https://numba.readthedocs.io/en/stable/user/generated-jit.html, the body is only executed the first time a certain signature is called)

Thanks @Hannes!

I can get the generated_jit version to perform in much the same way by removing inline="always" from the generated_jit decorator. I think it is not properly supported, particularly when called from another jitted function.

Overall I think I am happier now. I think some of the problems I was seeing in the more complicated code I am actually implementing stemmed from excessive use of generated_jit - I am now moving towards a factory function based implementation that seems to perform better.

Hi,

ah, yes that does the trick and all the timings are similar then - great :slight_smile:

/cc @stuartarchibald Is the behaviour observed here a known limitation or an issue that should be raised?

Cheers
Hannes

Hi @JSKenyon,

Thanks for sharing this, it’s an unusual problem :slight_smile:

Thanks for debugging this so far @Hannes, I can observe similar behaviours to those you have encountered, which is reassuring :slight_smile:

Augmenting the script as follows:

from numba.extending import overload
from numba import jit, types, generated_jit, literally
import numpy as np
import time


def dot2x2_impl(v1, v2, md):
    """Multiply two 2x2 matrices stored as flat length-4 sequences.

    Element ``k`` of each input is entry ``(k // 2, k % 2)`` (row-major), and
    the product is returned as a 4-tuple in the same order. ``md`` is
    accepted only so this can serve as an implementation for the
    mode-dispatching variants; it is not used.
    """
    a00, a01, a10, a11 = v1[0], v1[1], v1[2], v1[3]
    b00, b01, b10, b11 = v2[0], v2[1], v2[2], v2[3]
    return (
        a00*b00 + a01*b10,
        a00*b01 + a01*b11,
        a10*b00 + a11*b10,
        a10*b01 + a11*b11,
    )


@generated_jit(nopython=True)
def gen_f(v1, v2, md):
    """Pick an implementation for ``(v1, v2, md)`` from the Numba type of ``md``.

    This body runs in Python during typing and sees Numba types, not runtime
    values. The function it returns is the one Numba compiles.
    """

    # Debug trace: printed every time this generator body is entered.
    print("GEN F")
    # md is not a compile-time literal yet: request literal re-typing.
    if not isinstance(md, types.Literal):
        return lambda v1, v2, md: literally(md)

    # Dispatch on the literal value. Any other value returns None.
    if md.literal_value == "full":
        return dot2x2_impl


@jit(nopython=True)
def ovr_f(v1, v2, md):
    """Jitted entry point that forwards to ``__v1_mul_v2``.

    Inside compiled code, ``__v1_mul_v2`` resolves to whatever the
    ``@overload(__v1_mul_v2, ...)`` function returns for these argument types.
    """

    return __v1_mul_v2(v1, v2, md)


def __v1_mul_v2(v1, v2, md):
    """Pure-Python stub; its compiled behaviour comes from ``@overload``.

    Called from plain Python it does nothing and returns None.
    """
    return


@overload(__v1_mul_v2, inline="always")
def __v1_mul_v2_impl(v1, v2, md):
    """Overload supplying the compiled implementation of ``__v1_mul_v2``.

    Runs in Python during typing with the Numba types of the arguments and
    returns the function to compile. The print calls trace which branch is
    taken on each typing pass.
    """
    # Debug trace: the argument types this overload is being typed with.
    print('ol', v1, v2, md)
    # md is not a compile-time literal yet: request literal re-typing.
    if not isinstance(md, types.Literal):
        print("literal escape")
        return lambda v1, v2, md: literally(md)

    # md is a Literal: dispatch on its value.
    if md.literal_value == "full":
        print("correct")
        return dot2x2_impl
    # Unsupported literal value: no implementation is returned (None).
    print("fall through")


@jit(nopython=True)
def jit_f(v1, v2, md):
    """Plain-jit baseline: 2x2 matrix product of two flat length-4 inputs.

    Same arithmetic, in the same order, as ``dot2x2_impl``; ``md`` is
    accepted but unused.
    """
    a00, a01, a10, a11 = v1[0], v1[1], v1[2], v1[3]
    b00, b01, b10, b11 = v2[0], v2[1], v2[2], v2[3]
    return (
        a00*b00 + a01*b10,
        a00*b01 + a01*b11,
        a10*b00 + a11*b10,
        a10*b01 + a11*b11,
    )


if __name__ == "__main__":

    n_run = 100

    v1 = np.ones((4,))
    v2 = np.ones((4,))

    # Warm-up calls: trigger compilation and show that all variants agree.
    print(jit_f(v1, v2, "full"))
    print(gen_f(v1, v2, "full"))
    print(ovr_f(v1, v2, "full"))

    # Look at the LLVM IR control flow graphs
    jit_f.inspect_cfg(jit_f.signatures[0]).display(view=True)
    gen_f.inspect_cfg(gen_f.signatures[0]).display(view=True)
    ovr_f.inspect_cfg(ovr_f.signatures[0]).display(view=True)

    def show(x, item='inspect_asm'):
        """Print dispatcher x's `item` inspection output for its first signature."""
        print(str(x).center(80, '-'))
        fn = getattr(x, item)
        print(fn(x.signatures[0]))

    # Look at the disassembly control flow graphs
    show(jit_f, 'inspect_disasm_cfg')
    show(gen_f, 'inspect_disasm_cfg')
    show(ovr_f, 'inspect_disasm_cfg')

    def bench(fn):
        """Return the mean elapsed seconds per call of fn(v1, v2, "full").

        Uses time.perf_counter(), a monotonic high-resolution clock intended
        for timing short intervals. time.time() is wall-clock time that can
        be adjusted and may have coarse resolution on some platforms.
        """
        t0 = time.perf_counter()
        for _ in range(n_run):
            fn(v1, v2, "full")
        return (time.perf_counter() - t0) / n_run

    print("--------jit--------")
    print(bench(jit_f))

    # In this case it's evident that a new version is being compiled each time
    print("---generated_jit---")
    print(bench(gen_f))

    # In this case it's evident that a new version is being compiled each time
    # and that the compilation is slower than with generated_jit (expected,
    # there's more to do to resolve overloads).
    print("------overload-----")
    print(bench(ovr_f))

This shows a few things.

  1. That the LLVM IR generated for the function is the same no matter which method is used to generate it.

  2. That the machine code generated for the function is the same no matter which method is used to generate it (dump is below):

    ---------------CPUDispatcher(<function jit_f at 0x7f2f854a5310>)----------------
    [0x08000040]>  # method.__main__.jit_f_241_Array_double__1__C__mutable__aligned___Array_double__1__C__mutable__aligned___unicode_type (int64_t arg7, int64_t arg9, int64_t arg10, int64_t arg_8h, int64_t arg_40h);
    ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
    │  0x8000040                                                                                                                                                                                             │
    │ ; [02] -r-x section size 1208 named .text                                                                                                                                                              │
    │   ;-- section..text:                                                                                                                                                                                   │
    │   ;-- reloc.__main__::jit_f$241(Array<double, 1, C, mutable, aligned>, Array<double, 1, C, mutable, aligned>, unicode_type):                                                                           │
    │   ;-- .text:                                                                                                                                                                                           │
    │   ;-- __main__::jit_f$241(Array<double, 1, C, mutable, aligned>, Array<double, 1, C, mutable, aligned>, unicode_type):                                                                                 │
    │   ;-- rip:                                                                                                                                                                                             │
    │ 74: method.__main__.jit_f_241_Array_double__1__C__mutable__aligned___Array_double__1__C__mutable__aligned___unicode_type (int64_t arg7, int64_t arg9, int64_t arg10, int64_t arg_8h, int64_t arg_40h); │
    │ ; arg int64_t arg_8h @ rsp+0x8                                                                                                                                                                         │
    │ ; arg int64_t arg_40h @ rsp+0x40                                                                                                                                                                       │
    │ ; arg int64_t arg7 @ xmm0                                                                                                                                                                              │
    │ ; arg int64_t arg9 @ xmm2                                                                                                                                                                              │
    │ ; arg int64_t arg10 @ xmm3                                                                                                                                                                             │
    │ mov rax, qword [arg_8h]                                                                                                                                                                                │
    │ mov rcx, qword [arg_40h]                                                                                                                                                                               │
    │ movupd xmm0, xmmword [rcx]                                                                                                                                                                             │
    │ movupd xmm1, xmmword [rcx + 0x10]                                                                                                                                                                      │
    │ movddup xmm2, qword [rax]                                                                                                                                                                              │
    │ mulpd xmm2, xmm0                                                                                                                                                                                       │
    │ movddup xmm3, qword [rax + 8]                                                                                                                                                                          │
    │ mulpd xmm3, xmm1                                                                                                                                                                                       │
    │ ; arg10                                                                                                                                                                                                │
    │ addpd xmm3, xmm2                                                                                                                                                                                       │
    │ movddup xmm2, qword [rax + 0x10]                                                                                                                                                                       │
    │ mulpd xmm2, xmm0                                                                                                                                                                                       │
    │ movddup xmm0, qword [rax + 0x18]                                                                                                                                                                       │
    │ mulpd xmm0, xmm1                                                                                                                                                                                       │
    │ ; arg9                                                                                                                                                                                                 │
    │ addpd xmm0, xmm2                                                                                                                                                                                       │
    │ movupd xmmword [rdi], xmm3                                                                                                                                                                             │
    │ movupd xmmword [rdi + 0x10], xmm0                                                                                                                                                                      │
    │ xor eax, eax                                                                                                                                                                                           │
    │ ret                                                                                                                                                                                                    │
    └────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
    
    ---------------CPUDispatcher(<function gen_f at 0x7f2f857973a0>)----------------
    [0x08000040]>  # method.__main__.dot2x2_impl_243_Array_double__1__C__mutable__aligned___Array_double__1__C__mutable__aligned___Literal_5bstr_5d_28full_29 (int64_t arg7, int64_t arg9, int64_t arg10, int64_t arg_8h, int64_t arg_40h);
    ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
    │  0x8000040                                                                                                                                                                                                                 │
    │ ; [02] -r-x section size 984 named .text                                                                                                                                                                                   │
    │   ;-- section..text:                                                                                                                                                                                                       │
    │   ;-- reloc.__main__::dot2x2_impl$243(Array<double, 1, C, mutable, aligned>, Array<double, 1, C, mutable, aligned>, Literal$5bstr$5d$28full$29):                                                                           │
    │   ;-- .text:                                                                                                                                                                                                               │
    │   ;-- __main__::dot2x2_impl$243(Array<double, 1, C, mutable, aligned>, Array<double, 1, C, mutable, aligned>, Literal$5bstr$5d$28full$29):                                                                                 │
    │   ;-- rip:                                                                                                                                                                                                                 │
    │ 74: method.__main__.dot2x2_impl_243_Array_double__1__C__mutable__aligned___Array_double__1__C__mutable__aligned___Literal_5bstr_5d_28full_29 (int64_t arg7, int64_t arg9, int64_t arg10, int64_t arg_8h, int64_t arg_40h); │
    │ ; arg int64_t arg_8h @ rsp+0x8                                                                                                                                                                                             │
    │ ; arg int64_t arg_40h @ rsp+0x40                                                                                                                                                                                           │
    │ ; arg int64_t arg7 @ xmm0                                                                                                                                                                                                  │
    │ ; arg int64_t arg9 @ xmm2                                                                                                                                                                                                  │
    │ ; arg int64_t arg10 @ xmm3                                                                                                                                                                                                 │
    │ mov rax, qword [arg_8h]                                                                                                                                                                                                    │
    │ mov rcx, qword [arg_40h]                                                                                                                                                                                                   │
    │ movupd xmm0, xmmword [rcx]                                                                                                                                                                                                 │
    │ movupd xmm1, xmmword [rcx + 0x10]                                                                                                                                                                                          │
    │ movddup xmm2, qword [rax]                                                                                                                                                                                                  │
    │ mulpd xmm2, xmm0                                                                                                                                                                                                           │
    │ movddup xmm3, qword [rax + 8]                                                                                                                                                                                              │
    │ mulpd xmm3, xmm1                                                                                                                                                                                                           │
    │ ; arg10                                                                                                                                                                                                                    │
    │ addpd xmm3, xmm2                                                                                                                                                                                                           │
    │ movddup xmm2, qword [rax + 0x10]                                                                                                                                                                                           │
    │ mulpd xmm2, xmm0                                                                                                                                                                                                           │
    │ movddup xmm0, qword [rax + 0x18]                                                                                                                                                                                           │
    │ mulpd xmm0, xmm1                                                                                                                                                                                                           │
    │ ; arg9                                                                                                                                                                                                                     │
    │ addpd xmm0, xmm2                                                                                                                                                                                                           │
    │ movupd xmmword [rdi], xmm3                                                                                                                                                                                                 │
    │ movupd xmmword [rdi + 0x10], xmm0                                                                                                                                                                                          │
    │ xor eax, eax                                                                                                                                                                                                               │
    │ ret                                                                                                                                                                                                                        │
    └────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
    
    ---------------CPUDispatcher(<function ovr_f at 0x7f2f854a5040>)----------------
    [0x08000040]>  # method.__main__.ovr_f_249_Array_double__1__C__mutable__aligned___Array_double__1__C__mutable__aligned___Literal_5bstr_5d_28full_29 (int64_t arg7, int64_t arg9, int64_t arg10, int64_t arg_8h, int64_t arg_40h);
    ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┐
    │  0x8000040                                                                                                                                                                                                           │
    │ ; [02] -r-x section size 984 named .text                                                                                                                                                                             │
    │   ;-- section..text:                                                                                                                                                                                                 │
    │   ;-- reloc.__main__::ovr_f$249(Array<double, 1, C, mutable, aligned>, Array<double, 1, C, mutable, aligned>, Literal$5bstr$5d$28full$29):                                                                           │
    │   ;-- .text:                                                                                                                                                                                                         │
    │   ;-- __main__::ovr_f$249(Array<double, 1, C, mutable, aligned>, Array<double, 1, C, mutable, aligned>, Literal$5bstr$5d$28full$29):                                                                                 │
    │   ;-- rip:                                                                                                                                                                                                           │
    │ 74: method.__main__.ovr_f_249_Array_double__1__C__mutable__aligned___Array_double__1__C__mutable__aligned___Literal_5bstr_5d_28full_29 (int64_t arg7, int64_t arg9, int64_t arg10, int64_t arg_8h, int64_t arg_40h); │
    │ ; arg int64_t arg_8h @ rsp+0x8                                                                                                                                                                                       │
    │ ; arg int64_t arg_40h @ rsp+0x40                                                                                                                                                                                     │
    │ ; arg int64_t arg7 @ xmm0                                                                                                                                                                                            │
    │ ; arg int64_t arg9 @ xmm2                                                                                                                                                                                            │
    │ ; arg int64_t arg10 @ xmm3                                                                                                                                                                                           │
    │ mov rax, qword [arg_8h]                                                                                                                                                                                              │
    │ mov rcx, qword [arg_40h]                                                                                                                                                                                             │
    │ movupd xmm0, xmmword [rcx]                                                                                                                                                                                           │
    │ movupd xmm1, xmmword [rcx + 0x10]                                                                                                                                                                                    │
    │ movddup xmm2, qword [rax]                                                                                                                                                                                            │
    │ mulpd xmm2, xmm0                                                                                                                                                                                                     │
    │ movddup xmm3, qword [rax + 8]                                                                                                                                                                                        │
    │ mulpd xmm3, xmm1                                                                                                                                                                                                     │
    │ ; arg10                                                                                                                                                                                                              │
    │ addpd xmm3, xmm2                                                                                                                                                                                                     │
    │ movddup xmm2, qword [rax + 0x10]                                                                                                                                                                                     │
    │ mulpd xmm2, xmm0                                                                                                                                                                                                     │
    │ movddup xmm0, qword [rax + 0x18]                                                                                                                                                                                     │
    │ mulpd xmm0, xmm1                                                                                                                                                                                                     │
    │ ; arg9                                                                                                                                                                                                               │
    │ addpd xmm0, xmm2                                                                                                                                                                                                     │
    │ movupd xmmword [rdi], xmm3                                                                                                                                                                                           │
    │ movupd xmmword [rdi + 0x10], xmm0                                                                                                                                                                                    │
    │ xor eax, eax                                                                                                                                                                                                         │
    │ ret                                                                                                                                                                                                                  │
    └──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────┘
    
    3. That the `generated_jit` and `overload` cases are recompiling on every call. I think this is caused by a bad interaction between `literally` and the in-memory compilation cache: probably the cache is keyed on the non-literal argument types, while the overload forces a literal type, so the cached version is never reused. I am fairly convinced this is the root cause of the performance difference. The slowdown does not come from the compiled function's execution time; it comes from new specializations being compiled on every call.

Essentially, this is a bug :slight_smile: It is tracked on GitHub as numba/numba issue #6956, "Excessive recompilation due to use of `literally` and potentially unaware in memory cache."

Thanks for the follow up @stuartarchibald ! :slight_smile: