3x slowdown in parent function when applying njit

Disclaimer: resolving this example is probably of limited personal use to me, since I can just pick the faster of the two versions. However, I thought it might be helpful for Numba core development, since njit is supposed to make things faster and cases where it doesn’t seem worth documenting. If not, feel free to delete.

Here are the results. Notice that the njitted version (get_d_2) is faster than get_d_1, but its parent function (partials_2) becomes significantly slower than partials_1.

 $  time python3 partials.py 
using numpy
 0.005436 0.758  partials_2
 0.001049 0.146 -|- get_d_2
 0.000590 0.082 -|--|- intk1d3o
 0.000154 0.022 -|--|- intk1d1o
 0.000286 0.040 -|--|- intk1d2o
 0.001445 0.201  partials_1
 0.001322 0.184 -|- get_d_1
 0.000590 0.082 -|--|- intk1d3o
 0.000154 0.022 -|--|- intk1d1o
 0.000286 0.040 -|--|- intk1d2o

real	0m6.967s
user	0m6.895s
sys	0m0.227s

Here’s all of the relevant code.

#partials.py
from cupy_or_numpy import xp
from decorators import get_jit_decorator, get_decorator, reset_globals, profile_results
from intk import intk1d1o, intk1d2o, intk1d3o

@get_decorator()
def get_d_1(K, x, dx, xt):
    d = K[0]
    if 0 < K[1].size:
        d += intk1d1o(K[1], dx, xt)
    if 0 < K[2].size:
        d += intk1d2o(K[2], x, dx, xt)
        d += intk1d2o(K[2], dx, x, xt)
    if 0 < K[3].size:
        d += intk1d3o(K[3], dx, x, x, xt)
        d += intk1d3o(K[3], x, dx, x, xt)
        d += intk1d3o(K[3], x, x, dx, xt)
    return d if d!=0 else 1

@get_jit_decorator()
def get_d_2(K0, K1, K2, K3, x, dx, xt):
    d = K0
    if 0 < K1.size:
        d += intk1d1o(K1, dx, xt)
    if 0 < K2.size:
        d += intk1d2o(K2, x, dx, xt)
        d += intk1d2o(K2, dx, x, xt)
    if 0 < K3.size:
        d += intk1d3o(K3, dx, x, x, xt)
        d += intk1d3o(K3, x, dx, x, xt)
        d += intk1d3o(K3, x, x, dx, xt)
    return d if d!=0 else 1

@get_decorator()
def partials_1(x, dx, K, st):
    xt = int(st[-1]+1)
    d = get_d_1(K, x, dx, xt)

@get_decorator()
def partials_2(x, dx, K, st):
    xt = int(st[-1]+1)
    d = get_d_2(K[0], K[1], K[2], K[3], x, dx, xt)

@get_decorator()
def main(n):
    x = xp.random.uniform(-1,1,400)
    dx = xp.zeros(x.size)
    dx[1:] = xp.diff(x)
    s = 10
    K = [1, xp.random.uniform(-1,1,s)
        , xp.random.uniform(-1,1,(s,s))
        , xp.random.uniform(-1,1,(s,s,s))
        ]
    st = xp.arange(n)*15 + xp.random.uniform(0,1,n)*10
    for i in range(n):
        partials_1(x, dx, K, st[:i+1])
        partials_2(x, dx, K, st[:i+1])

if __name__ == '__main__':
    main(10)
    reset_globals()
    main(20)
    profile_results()

####################################
####################################
#intk.py
from decorators import get_jit_decorator

@get_jit_decorator()
def intk1d1o(k, x, tl):
    s = 0
    for i in range(k.size):
        s += k[i] * x[tl-i]
    return s

@get_jit_decorator()
def intk1d2o(k, x1, x2, tl):
    s = 0
    for i in range(k.shape[0]):
        for j in range(k.shape[1]):
            s += k[i,j] * x1[tl-i] * x2[tl-j]
    return s

@get_jit_decorator()
def intk1d3o(k, x1, x2, x3, tl):
    s = 0
    for i in range(k.shape[0]):
        for j in range(k.shape[1]):
            for l in range(k.shape[2]):
                s += k[i,j,l] * x1[tl-i] * x2[tl-j] * x3[tl-l]
    return s

####################################
####################################
#decorators.py
import time
from numba import njit, jit, objmode
from cupy_or_numpy import xp

USE_TIMER = True
results = {}
tree = {'stack':['main'], 'main':set()}

def wrapper_objm_start(f):
    start = time.time()
    tree[ tree['stack'][-1] ].add( f.__name__ )
    tree['stack'] += [ f.__name__ ]
    if f.__name__ not in results:
        tree[f.__name__] = set()
        # print(tree['stack'])
    return start

def wrapper_objm_end(f, start):
    run_time = time.time() - start
    if f.__name__ in results:
        results[f.__name__] += [run_time]
    else:
        results[f.__name__] = [run_time]
    tree['stack'] = tree['stack'][:-1]

def timer(f):
    def wrapper(*args, **kwargs):
        start = wrapper_objm_start(f)
        g = f(*args, **kwargs)
        wrapper_objm_end(f, start)
        return g
    return wrapper

def timer_none(f):
    def wrapper(*args):
        return f(*args)
    return wrapper

def jit_timer(f):
    jf = njit(f)
    @njit(cache=False)
    def wrapper(*args):
        with objmode(start='float64'):
            start = wrapper_objm_start(f)
        g = jf(*args)
        # g = f(*args)
        with objmode():
            wrapper_objm_end(f, start)
        return g
    return wrapper

def get_jit_decorator():
    if USE_TIMER:
        # return timer
        return jit_timer
    else:
        return njit

def get_decorator():
    if USE_TIMER:
        return timer
    else:
        return timer_none

def reset_globals():
    global results
    results = {}
    global tree
    tree = {'stack':['main'], 'main':set()}

def print_tree(node, layer):
    for n in node:
        rt = xp.sum(results[n])
        rtr = rt / xp.sum(results['main'])
        print('{0:>9.6f} {1:.03f}'.format( rt, rtr ), '-|-'*layer, n)
        print_tree(tree[n], layer+1)

def profile_results():
    # print(results)
    # print(tree)
    l = []
    for k in results:
        a = xp.asarray(results[k])
        # l += [[k+' '*(17-len(k)), xp.sum(a[1:])]]
        l += [[k+' '*(17-len(k)), xp.sum(a)]]
    l = sorted(l, key=lambda x: x[1])
    # for i in range(len(l)):
    #     print(  '{:.6f}'.format( l[i][1] ), l[i][0] )
        # print( l[i][0], "{:.6f}".format( l[i][1] ) )
    print_tree(tree['main'], 0)

####################################
####################################
#cupy_or_numpy.py
try:
    import cupy as xp
    print("using cupy")
except ImportError:
    import numpy as xp
    print("using numpy")

Odd, it definitely compiles and runs for me. I can’t upload the files directly. I’m using python 3.6.9 and numba 0.50.0.

Well, I didn’t split it into separate files, and I got rid of all the decorator distraction (hopefully correctly). This is the code:

import numpy as xp
from numba import njit

def get_d_1(K, x, dx, xt):
    d = K[0]
    if 0 < K[1].size:
        d += intk1d1o(K[1], dx, xt)
    if 0 < K[2].size:
        d += intk1d2o(K[2], x, dx, xt)
        d += intk1d2o(K[2], dx, x, xt)
    if 0 < K[3].size:
        d += intk1d3o(K[3], dx, x, x, xt)
        d += intk1d3o(K[3], x, dx, x, xt)
        d += intk1d3o(K[3], x, x, dx, xt)
    return d if d!=0 else 1

@njit
def get_d_2(K0, K1, K2, K3, x, dx, xt):
    d = K0
    if 0 < K1.size:
        d += intk1d1o(K1, dx, xt)
    if 0 < K2.size:
        d += intk1d2o(K2, x, dx, xt)
        d += intk1d2o(K2, dx, x, xt)
    if 0 < K3.size:
        d += intk1d3o(K3, dx, x, x, xt)
        d += intk1d3o(K3, x, dx, x, xt)
        d += intk1d3o(K3, x, x, dx, xt)
    return d if d!=0 else 1


def partials_1(x, dx, K, st):
    xt = int(st[-1]+1)
    d = get_d_1(K, x, dx, xt)

@njit
def partials_2(x, dx, K, st):
    xt = int(st[-1]+1)
    d = get_d_2(K[0], K[1], K[2], K[3], x, dx, xt)


@njit
def intk1d1o(k, x, tl):
    s = 0
    for i in range(k.size):
        s += k[i] * x[tl-i]
    return s

@njit
def intk1d2o(k, x1, x2, tl):
    s = 0
    for i in range(k.shape[0]):
        for j in range(k.shape[1]):
            s += k[i,j] * x1[tl-i] * x2[tl-j]
    return s

@njit
def intk1d3o(k, x1, x2, x3, tl):
    s = 0
    for i in range(k.shape[0]):
        for j in range(k.shape[1]):
            for l in range(k.shape[2]):
                s += k[i,j,l] * x1[tl-i] * x2[tl-j] * x3[tl-l]
    return s

and tested as

%%timeit
n=10
x = xp.random.uniform(-1,1,400)
dx = xp.zeros(x.size)
dx[1:] = xp.diff(x)
s = 10
K = [1, xp.random.uniform(-1,1,s)
    , xp.random.uniform(-1,1,(s,s))
    , xp.random.uniform(-1,1,(s,s,s))
    ]
st = xp.arange(n)*15 + xp.random.uniform(0,1,n)*10
for i in range(n):
    partials_1(x, dx, K, st[:i+1])

#121 µs ± 307 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

and

%%timeit
n=10
x = xp.random.uniform(-1,1,400)
dx = xp.zeros(x.size)
dx[1:] = xp.diff(x)
s = 10
K = (1, xp.random.uniform(-1,1,s)
    , xp.random.uniform(-1,1,(s,s))
    , xp.random.uniform(-1,1,(s,s,s))
    )
st = xp.arange(n)*15 + xp.random.uniform(0,1,n)*10
for i in range(n):
    partials_2(x, dx, K, st[:i+1])

# 89.1 µs ± 137 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

I think there are several things going on. First, the profiler decorator is an approximation, not an accurate measurement. Second, and more generally, for a short-running function like this there is no guarantee that jitted code will be faster. There is a price to pay for getting machine code: compilation and type inference. The first is paid only once, so it can be eliminated from benchmarks. The second cannot: there is a cost every time you cross the interpreted-compiled boundary (i.e. call a jitted function from interpreted code). If you look at Gitter from last week, I mentioned a function of mine that spent 50% of its time doing type inference on its inputs. It was so fast, and called so many times per second, that the small per-call cost accumulated.
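As a rough, made-up sketch of that per-call cost (not the code from this thread): for a trivial function, the njit version is often no faster per call than plain Python, because every call from the interpreter pays dispatch and type-inference overhead.

import time
import numpy as np
from numba import njit

def py_first(x):
    return x[0]

@njit
def nb_first(x):
    return x[0]

x = np.random.uniform(-1, 1, 400)
nb_first(x)  # warm up so compilation is not included in the timing

def bench(f, n=200_000):
    # call f from interpreted code n times; each njit call crosses the boundary
    t0 = time.perf_counter()
    for _ in range(n):
        f(x)
    return time.perf_counter() - t0

print("python :", bench(py_first))
print("njit   :", bench(nb_first))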

So, my take is: we should compile things that take a second or more to run.

Thanks for the insights. And you’re right, there’s some serious overhead as these functions are called more frequently. This partials function is potentially called hundreds of thousands of times, so microseconds add up, which is why I figured it’d be worth trying to throw njit at it. The idea of crossing the interpreted-compiled boundary is interesting, since the real work is done in the other functions like intk1d3o. I see you don’t have the same issue with timing on your machine. When I ran it without the timer, using $ time python3 partials.py, the njit version was about 40% slower overall, since this part is only about 10% of the runtime of the overall simulation.

The best solution to this issue is to compile as “high up” in the call chain as reasonably possible. Type inference is only done when a jitted function is called from interpreted code; everything is static when calling jitted code from within other jitted code (afaik). So you can get significant speed-ups by compiling the parent function, even if it does no heavy lifting itself.
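A minimal sketch of that (made-up functions, not the code above): when a jitted function calls another jitted function, the call is resolved during compilation, so only the single outer call pays the interpreted-to-compiled dispatch cost.

import numpy as np
from numba import njit

@njit
def inner(k, x):
    s = 0.0
    for i in range(k.size):
        s += k[i] * x[i]
    return s

@njit
def outer(k, x, n):
    # inner() is called from jitted code: no per-call type inference here
    total = 0.0
    for _ in range(n):
        total += inner(k, x)
    return total

k = np.random.uniform(-1, 1, 10)
x = np.random.uniform(-1, 1, 10)
print(outer(k, x, 1000))  # one boundary crossing for the whole loop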

The cost of type inference on short-running functions is something I didn’t fully realize until last week. The problem is that going too high up the chain likely hits features that cannot be jitted, so there’s a balancing act. I would say that if a function is called in a loop more than 1000 times per second, the loop probably should be jitted, even if it means refactoring the parent function to split the loop from the non-jittable parts.
As @Hannes said, type inference is paid once per call, so grouping more code under one njit call reduces the impact.
Also important to note is that different types have different inference costs. In my case, I had nested tuples containing Dispatchers, which is one of the worst combinations. Arrays, ints, and floats have lower type-inference times.
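A hypothetical sketch of that kind of refactor (names and the toy kernel are made up, not the code from this thread): keep the hard-to-jit bookkeeping in the interpreted parent and move the hot loop into a single jitted helper, so the boundary is crossed once instead of once per iteration.

import numpy as np
from numba import njit

@njit
def hot_loop(k, x):
    # all the per-iteration numeric work lives under one njit call
    out = np.empty(x.size)
    for t in range(x.size):
        s = 0.0
        for i in range(min(k.size, t + 1)):
            s += k[i] * x[t - i]
        out[t] = s
    return out

def parent(k, x, label):
    results = {}                        # plain dict, kept in interpreted code
    results[label] = hot_loop(k, x)     # one interpreted->compiled crossing
    return results

k = np.random.uniform(-1, 1, 10)
x = np.random.uniform(-1, 1, 400)
print(parent(k, x, "run-1")["run-1"][:3])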

Yes the balancing act is real, I have made the painful encounter with non-jittable callers on a few occasions :’)

An added advantage of compiling higher up may be additional compiler optimisations (specifically inlining and maybe even loop fusion?)
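On the inlining point, one knob that exists is njit’s inline option (sketch with made-up names): it asks Numba to inline the callee at the Numba-IR level, which gives the compiler more room to optimise across the call.

import numpy as np
from numba import njit

@njit(inline='always')
def scale(v, a):
    return v * a

@njit
def scale_sum(x, a):
    s = 0.0
    for i in range(x.size):
        s += scale(x[i], a)   # body of scale() gets inlined into this loop
    return s

print(scale_sum(np.arange(5.0), 2.0))   # 20.0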

Yeah, and that weighs against the flexibility. I have a variable that represents an infinite series of arrays of increasing order, K = [1, (10,), (10,10), (10,10,10)]. I arbitrarily stop at 3rd order, but it’s nicer to use K as a single variable without going into every function and typing K0, K1, K2, K3. And, yeah, I’ve noticed njit to be most useful when there are loops. I’m somewhat tempted to go all the way up and try to njit the entire simulation, and maybe use cuda.jit too while I’m at it. But doing so requires sticking to an API that looks like (K0, K1, K2, K3, b0, b1, b2, b3) instead of (K, b), as you can see a bit of in this post.
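One middle ground, as in the simplified partials_2 above, is to pass K as a tuple instead of a list: Numba supports heterogeneous tuples, so the mixed-rank K can stay one argument as long as it is indexed with constant indices inside the jitted code. A sketch, with a toy body instead of the real intk kernels:

import numpy as np
from numba import njit

@njit
def get_d_tuple(K, x, dx, xt):
    d = K[0]                        # constant index into a heterogeneous tuple
    d += np.sum(K[1]) * dx[xt]
    d += np.sum(K[2]) * x[xt] * dx[xt]
    d += np.sum(K[3]) * x[xt] * x[xt] * dx[xt]
    return d

s = 10
K = (1.0,
     np.random.uniform(-1, 1, s),
     np.random.uniform(-1, 1, (s, s)),
     np.random.uniform(-1, 1, (s, s, s)))
x = np.random.uniform(-1, 1, 400)
dx = np.zeros(x.size)
dx[1:] = np.diff(x)
print(get_d_tuple(K, x, dx, 100))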

Have you tried Siu’s trick of setting the other dimensions to 1? I.e.,

K = [np.ones(0),
     np.random.uniform(-1,1,(1,1,10)),
     np.random.uniform(-1,1,(1,10,10)),
     np.random.uniform(-1,1,(10,10,10))
]

In that way you should be able to compile the whole loop.
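A rough sketch of that idea, under the assumption that every K[i] really can be padded to a 3-D array (axes of length 1): once the elements share a dtype and number of dimensions they can go into a numba.typed.List, and the whole loop over them can run inside jitted code.

import numpy as np
from numba import njit
from numba.typed import List

s = 10
K3d = List()
K3d.append(np.random.uniform(-1, 1, (1, 1, s)))
K3d.append(np.random.uniform(-1, 1, (1, s, s)))
K3d.append(np.random.uniform(-1, 1, (s, s, s)))

@njit
def get_d_list(K3d, x, dx, xt):
    d = 0.0
    for k in K3d:                       # one loop over same-typed 3-D arrays
        for i in range(k.shape[0]):
            for j in range(k.shape[1]):
                for l in range(k.shape[2]):
                    d += k[i, j, l] * x[xt - i] * x[xt - j] * dx[xt - l]
    return d

x = np.random.uniform(-1, 1, 400)
dx = np.zeros(x.size)
dx[1:] = np.diff(x)
print(get_d_list(K3d, x, dx, 100))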

I haven’t yet. It’s not as simple as (1,1,10), since it’s really the result of a spline function. That said, I used the profile tree to identify which functions were jitted and which weren’t. I essentially had everything jitted, so I’m crawling up the call chain and reworking the API to use K1, K2, K3.

Here’s the relevant code for the spline function in NumPy form. The njit version is significantly longer.

@get_decorator()
def genK_nD_rs(B,b): #rectangular solid
    n = len(B)
    s = [chr(i+97) for i in range(0,2*n,2)]
    f = [chr(i+98) for i in range(0,2*n,2)]
    e = ",".join([i[0]+i[1] for i in list(zip(s,f))])
    e += ","+"".join(s)+"->"+"".join(f)
    return xp.einsum(e, *B, b.reshape([i.shape[0] for i in B]))
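To make the einsum string concrete (this worked example is mine, not from the original post): for n = 3, s = ['a', 'c', 'e'] and f = ['b', 'd', 'f'], so e ends up as "ab,cd,ef,ace->bdf". Each basis matrix B[i] of shape (m_i, k_i) is contracted against the coefficient tensor b reshaped to (m_0, m_1, m_2), giving an output of shape (k_0, k_1, k_2).

import numpy as np

B = [np.random.uniform(-1, 1, (4, 10)),
     np.random.uniform(-1, 1, (5, 10)),
     np.random.uniform(-1, 1, (6, 10))]
b = np.random.uniform(-1, 1, 4 * 5 * 6)

# same string genK_nD_rs would build for these three basis matrices
out = np.einsum("ab,cd,ef,ace->bdf", *B, b.reshape(4, 5, 6))
print(out.shape)   # (10, 10, 10)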

Hm, looks like I can do a.reshape((1,1,10)) so what I said about that is moot.

Crawled up two more functions to minimize crossings of the interpreted-compiled boundary, and instead of being 3.5x slower, it’s now 18% faster.
