Profiling with a decorator and @njit

I created a decorator @t that I apply before @njit. A problem arises when one @njit function calls another function that has both @t and @njit; I end up removing the decorator in those cases. Otherwise I get this error: Untyped global name 'example_function': cannot determine Numba type of <class 'function'>. I'd like to keep seeing the timing results for each function.

import time
import numpy as np

results = {}

def t(f):
    def wrapper(*args, **kwargs):
        start = time.time()
        g = f(*args, **kwargs)
        run_time = time.time() - start
        if f.__name__ in results:
            results[f.__name__] += [run_time]
        else:
            results[f.__name__] = [run_time]
        return g
    return wrapper

def profile_results():
    l = []
    for k in results:
        a = np.asarray(results[k])
        # skip the first timing (a[1:]): it includes JIT compilation
        l += [[k.ljust(13), np.sum(a[1:])]]
    l = sorted(l, key=lambda x: x[1])
    for name, total in l:
        print(name, "{:.6f}".format(total))
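For reference, here's a condensed pure-Python run of the decorator (no Numba involved; sleep_a is a made-up example) showing how results accumulates one list of timings per decorated function:

```python
import time

results = {}

def t(f):
    def wrapper(*args, **kwargs):
        start = time.time()
        g = f(*args, **kwargs)
        # append this call's elapsed time under the function's name
        results.setdefault(f.__name__, []).append(time.time() - start)
        return g
    return wrapper

@t
def sleep_a():
    time.sleep(0.01)

sleep_a()
sleep_a()

print(sorted(results))           # ['sleep_a']
print(len(results['sleep_a']))   # 2
```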

I imagine you’re doing something like:

@njit
@t
def f():
    pass

Does it work if you do:

@t
@njit
def f():
    pass

?

Nope, I was doing @t and then @njit.

I believe the pattern is compatible in general, as I've heard of other people combining decorators with @njit. If you could post an executable reproducer, I could look into this more easily.


A problem arises when one function using @njit calls another with both @njit and @t.

I think this is the key: an njit function cannot call a function wrapped by a plain, uncompiled Python decorator.


@randompast I’m guessing you’re doing something like this:

@t
@njit
def foo():
    pass

@njit
def bar():
    return foo()

You get an error because, while njit(foo) is a valid jitted function (a CPUDispatcher in Numba terminology), t(njit(foo)) is a normal Python function and cannot be called from bar, which is jitted.
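A rough pure-Python analogy of the problem (Dispatcher here is only a stand-in for Numba's CPUDispatcher, not the real class): wrapping changes the type of the object, and the type is exactly what Numba's type inference relies on.

```python
class Dispatcher:
    """Stand-in for the object @njit returns (not Numba's real CPUDispatcher)."""
    def __init__(self, py_func):
        self.py_func = py_func
    def __call__(self, *args):
        return self.py_func(*args)

def t(f):
    # a plain timing-style decorator returns an ordinary function
    def wrapper(*args):
        return f(*args)
    return wrapper

def foo():
    return 42

compiled = Dispatcher(foo)      # like @njit alone
wrapped = t(Dispatcher(foo))    # like @t stacked on top of @njit

print(isinstance(compiled, Dispatcher))  # True: Numba knows how to type this
print(isinstance(wrapped, Dispatcher))   # False: just a plain function
```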

I’m very interested in this problem, since there’s no other way to measure the per-function execution time. @gmarkall is there any way to have a global dictionary? Unfortunately typed dicts cannot be lowered as constants (https://github.com/numba/numba/issues/4062#issuecomment-490488298). I wonder if there’s any trick that could be used to get a reference to a global dictionary.


Got it, I missed that part - thanks!


Thanks folks!

@luk-f-a Yes, that’s exactly the issue.

Here’s code that demonstrates the issue in a mini use case.

import time
import numpy as np
from numba import njit, jit

results = {}

def t(f):
    def wrapper(*args, **kwargs):
        start = time.time()
        g = f(*args, **kwargs)
        run_time = time.time() - start
        if f.__name__ in results:
            results[f.__name__] += [run_time]
        else:
            results[f.__name__] = [run_time]
        return g
    return wrapper

# @t  # uncommenting this line causes the error above
@njit
def ahp(x, t, u, A):
    return -A * np.sum( np.exp( (x-t) / u ) )

@njit
def line_intercept(y1, y2, thresh):
    #this is really a ratio
    # print(y1, y2)
    return 1 - (thresh-y1)/(y2-y1)

@t
@njit
def get_spikes(c, threshold, u, A):
    st = []
    y = np.zeros(c.size)
    for xt in range(0, c.size):
        a = ahp(np.asarray(st), xt, u, A)
        y[xt] = c[xt] + a
        if threshold < y[xt]:
            li = line_intercept(y[xt-1], y[xt], threshold)
            st.append(xt-li)
            # st.append(xt)
    return (np.asarray(st), y)

@t
def test_sim():
    np.random.seed(0)
    x = np.random.uniform(-1,1,200)
    st, y = get_spikes(x, np.max(x)/2, 1.2, 100)

def profile_results():
    l = []
    for k in results:
        a = np.asarray(results[k])
        # skip the first timing (a[1:]): it includes JIT compilation
        l += [[k.ljust(13), np.sum(a[1:])]]
    l = sorted(l, key=lambda x: x[1])
    for name, total in l:
        print(name, "{:.6f}".format(total))

if __name__ == '__main__':
    test_sim()
    test_sim()
    profile_results()
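As a side note on the reproducer, line_intercept returns the fractional position within a step where the threshold is crossed. A tiny worked example with made-up values (y1=0, y2=2, thresh=1, so the crossing happens halfway through the step):

```python
def line_intercept(y1, y2, thresh):
    # fraction of the step remaining after the threshold crossing
    return 1 - (thresh - y1) / (y2 - y1)

print(line_intercept(0.0, 2.0, 1.0))  # 0.5
```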

Does this help?

import time
import numpy as np
from numba import njit, jit, objmode

results = {}

def jit_timer(f):
    jf = njit(f)
    @njit
    def wrapper(*args):
        with objmode(start='float64'):
            start = time.time()
        g = jf(*args)
        with objmode():
            end = time.time()
            run_time = end - start
            if f.__name__ in results:
                results[f.__name__] += [run_time]
            else:
                results[f.__name__] = [run_time]
        return g
    return wrapper

@njit
def pointless_delay(seconds):
    with objmode():
        s = time.time()
        e = 0
        while (e < seconds):
            e = time.time() - s

@jit_timer
def ahp(x, t, u, A):
    pointless_delay(1) # 1s delay
    # total delay is 1s

@jit_timer
def line_intercept(y1, y2, thresh):
    pointless_delay(1) # 1s delay
    # total delay is 1s

@jit_timer
def get_spikes(c, threshold, u, A):
    pointless_delay(2) # 2s delay
    ahp(None, None, None, None) # 1s delay
    line_intercept(None, None, None) # 1s delay
    # total delay is 4s

@jit_timer
def test_sim():
    pointless_delay(7) # 7s delay
    get_spikes(None, 2, 1.2, 100) # 4s delay
    # total delay is 11s

def profile_results():
    l = []
    for k in results:
        a = np.asarray(results[k])
        # skip the first timing (a[1:]): it includes JIT compilation
        l += [[k.ljust(13), np.sum(a[1:])]]
    l = sorted(l, key=lambda x: x[1])
    for name, total in l:
        print(name, "{:.6f}".format(total))

if __name__ == '__main__':
    test_sim()
    test_sim()
    profile_results()

gives me:

$ python di1.py 
line_intercept 1.000044
ahp           1.000045
get_spikes    4.000300
test_sim      11.000389
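One detail worth noting: profile_results sums a[1:], discarding the first recorded time for each function; that is presumably why test_sim is called twice, since the first call's timings include JIT compilation. A small sketch of the convention (the numbers are made up):

```python
import numpy as np

# Hypothetical timings for one function: the first run includes compilation.
timings = [0.5, 0.001, 0.001]
a = np.asarray(timings)

total_including_compile = np.sum(a)
total_excluding_compile = np.sum(a[1:])  # what profile_results reports

print("{:.6f}".format(total_excluding_compile))  # 0.002000
```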

Duh, I should have thought of that. Thanks @stuartarchibald!

I’m finding with objmode very useful lately.


@stuartarchibald, yes! That definitely helps, thanks so much! I didn't know about object mode; this is a cool way to use it. I just checked: it works, and now I can see the results for all of those other functions too. I like this @jit_timer!

I played with the timer a bit more and realized I could use very similar code to make a dependency graph with it too. Thought you folks might enjoy this.

results = {}
tree = {'stack':['main'], 'main':set()}

def wrapper_objm_start(f):
    start = time.time()
    tree[ tree['stack'][-1] ].add( f.__name__ )
    tree['stack'] += [ f.__name__ ]
    if f.__name__ not in results:
        tree[f.__name__] = set()
        # print(tree['stack'])
    return start

def wrapper_objm_end(f, start):
    run_time = time.time() - start
    if f.__name__ in results:
        results[f.__name__] += [run_time]
    else:
        results[f.__name__] = [run_time]
    tree['stack'] = tree['stack'][:-1]

def t(f):
    def wrapper(*args, **kwargs):
        start = wrapper_objm_start(f)
        g = f(*args, **kwargs)
        wrapper_objm_end(f, start)
        return g
    return wrapper

def jit_timer(f):
    jf = njit(f)
    @njit
    def wrapper(*args):
        with objmode(start='float64'):
            start = wrapper_objm_start(f)
        g = jf(*args)
        with objmode():
            wrapper_objm_end(f, start)
        return g
    return wrapper
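To see the call-tree bookkeeping in action with plain Python (no Numba needed; outer and inner are made-up names), decorating two nested functions with t records inner as a child of outer:

```python
import time

results = {}
tree = {'stack': ['main'], 'main': set()}

def wrapper_objm_start(f):
    start = time.time()
    tree[tree['stack'][-1]].add(f.__name__)  # record caller -> callee edge
    tree['stack'] += [f.__name__]            # push onto the call stack
    if f.__name__ not in results:
        tree[f.__name__] = set()
    return start

def wrapper_objm_end(f, start):
    run_time = time.time() - start
    results.setdefault(f.__name__, []).append(run_time)
    tree['stack'] = tree['stack'][:-1]       # pop the call stack

def t(f):
    def wrapper(*args, **kwargs):
        start = wrapper_objm_start(f)
        g = f(*args, **kwargs)
        wrapper_objm_end(f, start)
        return g
    return wrapper

@t
def inner():
    pass

@t
def outer():
    inner()

outer()

print(tree['main'])   # {'outer'}
print(tree['outer'])  # {'inner'}
```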

Here’s some output from a sample run of mine. You can see that I can run different implementations of the same functionality (conv_1d1o and nbconv_1d1o) and compare their run time.

0.003904  test_sim
0.003134 -|- sim_1D
0.001564 -|--|- gcsrm_1D
0.000029 -|--|--|- nbconv_1d1o
0.000047 -|--|--|- conv_1d1o
0.001256 -|--|--|- nbconv_1d2o
0.000079 -|--|--|- get_spikes_v2
0.000003 -|--|--|--|- line_intercept
0.000280 -|--|- partials
0.000154 -|--|--|- compute_dtdbd
0.000068 -|--|--|--|- td
0.000005 -|--|--|--|--|- outer
0.000002 -|--|--|--|--|- intk
0.000061 -|--|--|--|- intk_all
0.000002 -|--|--|--|--|- intk
0.000004 -|--|--|- dndu
0.000005 -|--|--|- dndt
0.000055 -|--|--|- get_d
0.000002 -|--|--|--|- intk
0.000006 -|--|--|--|- int2d1dk
0.000032 -|--|- genK_nD_rs

Printing the tree with the times is simple. It'd be neat to have a decorator like this built into Numba, especially since the current approach is limited by its reliance on global variables.

def print_tree(node, layer):
    for n in node:
        print('{:.6f}'.format( np.min(results[n]) ), '-|-'*layer, n)
        print_tree(tree[n], layer+1)
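With made-up data, calling it from the root looks like this (the 'main' node itself isn't printed, only its descendants):

```python
import numpy as np

# Fabricated results/tree standing in for a real profiled run.
results = {'test_sim': [0.004], 'sim_1D': [0.003]}
tree = {'main': {'test_sim'}, 'test_sim': {'sim_1D'}, 'sim_1D': set()}

def print_tree(node, layer):
    for n in node:
        # min over the recorded timings, indented by call depth
        print('{:.6f}'.format(np.min(results[n])), '-|-' * layer, n)
        print_tree(tree[n], layer + 1)

print_tree(tree['main'], 0)
# 0.004000  test_sim
# 0.003000 -|- sim_1D
```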

Cool stuff 🙂
I think this can definitely be helpful in certain situations, even just for following call chains alone.

I guess one has to keep in mind that for short-running (and repeatedly called) functions the decorator itself could dominate the runtime, especially due to the context switch back into the interpreter. Possibly this also blocks certain compiler optimisations, like inlining.

@randompast very nice 🙂 I expect others will find this, or something similar, useful!

@Hannes indeed, your assessment is correct: coarse-grained functions will work best with the approach above, and it will also likely have knock-on effects on optimisation which may skew results. For more fine-grained/general profiling, something like https://github.com/numba/numba/issues/5028 is needed. I just need a decent tranche of time to get it into production, as it's somewhat involved!

@stuartarchibald, would the line profiling in #5028 also provide function level profiling? Or would those be two different features?

@randompast the tree is a great addition! If I have time, I’ll try to add some information about the number of times a function is compiled (since excessive compilation can also be a performance problem). I think this decorator is a perfect example of the kind of things to put in the add-ons package (Addons / Contrib repo).

I'd also add that I made a simple modification to the tree print so it also shows each node's ratio of the overall runtime.

Thanks for the encouragement, additional insights, and reference to 5028!

This decorator has been increasingly causing me trouble. Since it uses global variables it can’t use the cached version. Each time the file is run, the njit functions need to be recompiled.


I think https://github.com/numba/numba/pull/5674 fixes this; it'll make it into 0.51.0.

@luk-f-a I believe so, it’d also track timings even when LLVM’s inliner aggressively inlines functions into each other for speed!


thanks for the information Stuart!

@everyone: does anyone know if the objmode call below could be avoided? Is there a way to get the clock time from jitted code?

        with objmode(start='float64'):
            start = time.time()