Profiling with a decorator and @njit

Thanks folks!

@luk-f-a Yes, that’s exactly the issue.

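Here’s a minimal example that demonstrates the issue: it appears when the `@t` line above `ahp` is uncommented, because `ahp` is called from inside the jitted `get_spikes`.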

import functools
import time

import numpy as np
from numba import njit, jit

# Per-function timings recorded by ``t``: maps the wrapped function's
# ``__name__`` to a list of run times in seconds, one entry per call.
results = {}


def t(f):
    """Time every call of ``f`` and append the duration to ``results``.

    Timings are keyed by ``f.__name__``, so two functions sharing a name
    share one list. The returned wrapper is a plain Python function, so it
    can time a jitted function called from Python, but jitted (nopython)
    code presumably cannot call the wrapper itself.

    Parameters
    ----------
    f : callable
        Function to time.

    Returns
    -------
    callable
        Wrapper accepting the same positional and keyword arguments as ``f``
        and returning ``f``'s result.
    """
    name = f.__name__

    # updated=() skips merging f's instance __dict__ into the wrapper; for a
    # jitted dispatcher that would presumably copy internal state rather
    # than function metadata.
    @functools.wraps(f, updated=())
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time,
        # which follows the (adjustable) wall clock.
        start = time.perf_counter()
        value = f(*args, **kwargs)  # forward keyword arguments as well
        results.setdefault(name, []).append(time.perf_counter() - start)
        return value

    return wrapper

# @t #This line causes an issue
@njit
def ahp(x, t, u, A):
    """Return the summed after-hyperpolarization term at time ``t``.

    Every entry of ``x`` contributes ``exp((x - t) / u)``; the sum is scaled
    by ``-A``, so a positive ``A`` yields a non-positive result.
    ``get_spikes`` passes the spike times found so far as ``x``.

    Applying ``@t`` to this function breaks ``get_spikes``, which calls it
    from jitted code. This is presumably because the timing wrapper is a
    plain Python function that nopython-mode code cannot call.
    """
    kernel = np.exp((x - t) / u)
    return -A * np.sum(kernel)

@njit
def line_intercept(y1, y2, thresh):
    """Return how far back from ``y2`` a linear ``thresh`` crossing lies.

    Interpolating linearly from ``y1`` to ``y2`` over one step, ``thresh``
    is reached a fraction ``(thresh - y1) / (y2 - y1)`` of the way along.
    This returns the complementary fraction ``1 - that``, which
    ``get_spikes`` subtracts from the current index to get the spike time.

    NOTE(review): ``y1 == y2`` divides by zero. Under Numba's default error
    model this presumably raises ZeroDivisionError; confirm.
    """
    step_fraction = (thresh - y1) / (y2 - y1)
    return 1 - step_fraction

@t
@njit
def get_spikes(c, threshold, u, A):
    """Run a threshold-spiking simulation driven by ``c``.

    At each step ``xt`` the output is ``y[xt] = c[xt] + ahp(spikes, xt, u, A)``,
    where ``spikes`` are the spike times recorded so far. Whenever
    ``y[xt]`` is strictly greater than ``threshold``, a spike time
    ``xt - line_intercept(y[xt-1], y[xt], threshold)`` is recorded. That
    time is linearly interpolated between the previous and current sample.

    Parameters
    ----------
    c : array
        Input drive, indexed as ``c[xt]`` for ``xt`` in ``range(c.size)``
        (presumably 1-D; confirm).
    threshold : float
        A spike is recorded when ``y[xt] > threshold``.
    u, A :
        Forwarded to ``ahp`` (``u`` divides the time difference, ``A``
        scales the summed term).

    Returns
    -------
    tuple
        ``(spike_times, y)``: the recorded spike times as an array, and the
        output trace as a float array of length ``c.size``.
    """
    st = []  # spike times found so far
    y = np.zeros(c.size)
    for xt in range(0, c.size):
        a = ahp(np.asarray(st), xt, u, A)
        y[xt] = c[xt] + a
        if threshold < y[xt]:
            # NOTE(review): at xt == 0, y[xt-1] wraps to y[-1]. For c.size > 1
            # that slot is still 0.0 from np.zeros, which effectively assumes
            # a previous sample of 0. For c.size == 1 it is y[0] itself, so
            # y1 == y2 and line_intercept divides by zero. Confirm intended.
            li = line_intercept(y[xt-1], y[xt], threshold)
            st.append(xt-li)
            # st.append(xt)  # un-interpolated alternative: spike at the integer step
    return (np.asarray(st), y)

@t
def test_sim():
    """Run one seeded simulation on 200 uniform samples drawn from [-1, 1).

    The spike threshold is half the largest sample. Nothing is returned;
    the call is timed by the ``@t`` decorator.
    """
    np.random.seed(0)  # reseed the global RNG so every run sees the same input
    signal = np.random.uniform(-1, 1, 200)
    half_peak = np.max(signal) / 2
    spike_times, trace = get_spikes(signal, half_peak, 1.2, 100)

def profile_results(stats=None, warmup=1, width=13):
    """Print each timed function's total run time, smallest total first.

    The first ``warmup`` timings of every function are excluded from its
    total, because with Numba the first call includes JIT compilation. A
    function with no more than ``warmup`` timings totals ``0.0``.

    Parameters
    ----------
    stats : dict or None
        Mapping of function name to a list of recorded run times. ``None``
        (the default) uses the module-level ``results`` filled in by ``t``.
    warmup : int
        Number of leading timings per function to skip (default 1).
    width : int
        Width the names are left-justified to (default 13). Longer names are
        printed in full.
    """
    if stats is None:
        stats = results
    rows = []
    for name, times in stats.items():
        total = np.sum(np.asarray(times)[warmup:])
        rows.append((name.ljust(width), total))
    rows.sort(key=lambda row: row[1])  # stable sort: ties keep insertion order
    for label, total in rows:
        print(label, "{:.6f}".format(total))

if __name__ == '__main__':
    # Two runs: the first includes Numba's lazy compilation of the jitted
    # functions, and profile_results skips each function's first timing.
    for _ in range(2):
        test_sim()
    profile_results()