Profiling with a decorator and @njit

A non-portable version, for Linux only:

from numba import njit, objmode
import numpy as np
import ctypes
import time

# Clock id handed to clock_gettime() as its first argument.
CLOCK_MONOTONIC = 0x1
# ctypes prototype used for the call: returns a C int and takes a C int plus
# a pointer to C long (the buffer that receives the time values).
clock_gettime_proto = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int,
                                       ctypes.POINTER(ctypes.c_long))
# CDLL(None) opens the running process itself, so symbols that are already
# loaded can be looked up by name.
pybind = ctypes.CDLL(None)
clock_gettime_addr = pybind.clock_gettime
# Wrap the resolved symbol with the prototype; this is the object that the
# jitted timenow() calls.
clock_gettime_fn_ptr = clock_gettime_proto(clock_gettime_addr)


@njit
def timenow():
    """Return the CLOCK_MONOTONIC reading as a float number of seconds.

    A two-slot int64 buffer is handed to ``clock_gettime``; slot 0 is read
    back as whole seconds and slot 1 as nanoseconds.
    """
    buf = np.zeros(2, dtype=np.int64)
    clock_gettime_fn_ptr(CLOCK_MONOTONIC, buf.ctypes)
    whole_seconds = np.float64(buf[0])
    nanoseconds = np.float64(buf[1])
    return whole_seconds + 1e-9 * nanoseconds


@njit
def pointless_delay(seconds):
    """Busy-wait for roughly ``seconds`` of wall-clock time.

    The loop runs inside an ``objmode`` block because it relies on the
    Python-level ``time.time()``.
    """
    with objmode():
        began = time.time()
        elapsed = 0
        while elapsed < seconds:
            elapsed = time.time() - began


@njit
def do_stuff(n):
    """Time a ``pointless_delay(n)`` call with ``timenow`` and print it."""
    began = timenow()
    pointless_delay(n)
    finished = timenow()
    print("Elapsed", finished - began)


# Exercise the timer with integer and float delays (in seconds); each call
# prints the elapsed time it measured.
do_stuff(1)
do_stuff(2)
do_stuff(3.21)

To do this portably, support would need to be added to Numba itself, probably in the form of a translation of the CPython internals for the time module.

3 Likes

I think the biggest problem is not that it needs to recompile, but that everything needs to be in one file so that I can pass the global object around. Can I apply the decorator to the import somehow? Would the following be the right thing to do?

from test_functions import do_stuff
# Rebind the imported name to its timed wrapper, i.e. apply the decorator
# after the import instead of at the definition site.
do_stuff = jit_timer(do_stuff) #, results, tree)  #maybe include the globals?

I’m trying to split the simulation into multiple files. I thought that I couldn’t do that because of the globals, but there might be a way. It’d be nice to be able to toggle just the decorator, but since the functions call each other, it doesn’t seem possible to do it this way. One neat approach might be to import a module in a way that applies the decorator to each function within it, e.g. `from dEdt import * using decorator`. Maybe https://github.com/numba/numba/issues/5028 would solve some of these issues?

#decorator_test.py
from numba import njit, jit, objmode
from decorators import timer, jit_timer, reset_globals, profile_results
from cupy_or_numpy import xp
from dEdt import DE_dEdt_njit, dEdt_njit

# Pick the decorator applied to the imported kernels: plain njit here, or
# swap in the commented-out line to use the timing variant instead.
#decorator = jit_timer
decorator = njit

# Rebind the names imported from dEdt to their decorated versions. Only the
# names in this module are rebound.
dEdt_njit = decorator(dEdt_njit)
DE_dEdt_njit = decorator(DE_dEdt_njit)

@timer
def test():
    """Build two jittered 10-element time grids and evaluate DE_dEdt_njit.

    Each grid is 0, 10, ..., 90 plus a uniform random offset in [0, 4);
    the kernel is called with T = 150.
    """
    jitter_o = xp.random.uniform(0, 4, 10)
    tO = xp.arange(10) * 10 + jitter_o
    jitter_d = xp.random.uniform(0, 4, 10)
    tD = xp.arange(10) * 10 + jitter_d
    result = DE_dEdt_njit(tD, tO, 150)
    return result

@timer
def main():
    """Run the test case once and print what it returns."""
    print(test())

if __name__ == '__main__':
    # Start from empty timing state, run the timed entry point, then print
    # the collected timings.
    reset_globals()
    main()
    profile_results()



#cupy_or_numpy.py
# Expose the array backend as ``xp``: CuPy when it can be imported, NumPy
# otherwise. Only ImportError triggers the fallback, so an interrupt or an
# unrelated failure is not silently swallowed the way a bare ``except:``
# would swallow it.
try:
    import cupy as xp
    print("using cupy")
except ImportError:
    import numpy as xp
    print("using numpy")



#decorators.py
import time
from numba import njit, jit, objmode
from cupy_or_numpy import xp

# Function name -> list of measured run times (seconds), one per call.
results = {}
# Call-tree bookkeeping: 'stack' is the list of names currently executing
# (seeded with 'main'); every other key maps a function name to the set of
# names it called.
tree = {'stack':['main'], 'main':set()}

def wrapper_objm_start(f):
    """Record that ``f`` is starting and return the start timestamp.

    ``f.__name__`` is added as a child of the function currently on top of
    the call stack and then pushed onto that stack. If no timing has been
    stored for the name yet, its own child set is (re)initialised to empty.
    """
    began = time.time()
    name = f.__name__
    call_stack = tree['stack']
    tree[call_stack[-1]].add(name)
    call_stack.append(name)
    if name not in results:
        tree[name] = set()
    return began

def wrapper_objm_end(f, start):
    """Store the elapsed time for ``f`` and pop it off the call stack.

    The run time (now minus ``start``) is appended to the list kept under
    ``f.__name__`` in ``results``; the stack is replaced by a copy without
    its last entry.
    """
    elapsed = time.time() - start
    results.setdefault(f.__name__, []).append(elapsed)
    tree['stack'] = tree['stack'][:-1]

def timer(f):
    """Decorate a plain Python function so each call is timed.

    Every call is bracketed by ``wrapper_objm_start`` / ``wrapper_objm_end``,
    which record the run time under ``f.__name__`` and maintain the call
    tree.

    Args:
        f: The function to wrap.

    Returns:
        A wrapper that forwards positional and keyword arguments to ``f``
        and returns its result.
    """
    from functools import wraps

    @wraps(f)
    def wrapper(*args, **kwargs):
        start = wrapper_objm_start(f)
        try:
            # Keyword arguments are forwarded too; previously they were
            # accepted by the wrapper but never passed on.
            return f(*args, **kwargs)
        finally:
            # Always close the timing record so an exception in ``f`` does
            # not leave its name stuck on the call stack.
            wrapper_objm_end(f, start)
    return wrapper

def jit_timer(f):
    """Compile ``f`` with njit and return a jitted wrapper that times it.

    The wrapper calls the same bookkeeping helpers as ``timer``
    (``wrapper_objm_start`` / ``wrapper_objm_end``), but from inside
    ``objmode`` blocks because those helpers are ordinary Python functions.
    The wrapper takes positional arguments only.
    """
    jf = njit(f)
    # The wrapper is compiled with caching explicitly turned off.
    @njit(cache=False)
    def wrapper(*args):
        # 'start' is used after the block, so its type is declared here.
        with objmode(start='float64'):
            start = wrapper_objm_start(f)
        # Call the compiled version; the undecorated call is kept below for
        # reference.
        g = jf(*args)
        # g = f(*args)
        with objmode():
            wrapper_objm_end(f, start)
        return g
    return wrapper

def reset_globals():
    """Rebind the module-level ``results`` and ``tree`` to fresh state.

    ``results`` becomes an empty dict; ``tree`` gets a stack holding only
    'main' and an empty child set for 'main'.
    """
    global results, tree
    tree = {'stack': ['main'], 'main': set()}
    results = {}

def print_tree(node, layer):
    """Print one line per name in ``node``, then recurse into its children.

    Each line shows the summed run time of the name, that sum divided by
    the summed run time of 'main', a '-|-' marker repeated ``layer`` times,
    and the name itself.
    """
    for name in node:
        total = xp.sum(results[name])
        share = total / xp.sum(results['main'])
        marker = '-|-' * layer
        print('{0:>6.3f} {1:.03f}'.format(total, share), marker, name)
        children = tree[name]
        print_tree(children, layer + 1)

def profile_results():
    """Print the collected timings.

    Output is the raw ``results`` dict, the raw ``tree`` dict, and then the
    call tree rooted at the children of 'main' as rendered by
    ``print_tree``.
    """
    # A per-function totals list used to be built and sorted here, but it
    # was never printed or returned, so it has been dropped.
    print(results)
    print(tree)
    print_tree(tree['main'], 0)

Since DE_dEdt_njit calls dEdt_njit it seems that I can’t apply the decorator outside of this file.

#dEdt
from decorators import timer, jit_timer, reset_globals, profile_results
from cupy_or_numpy import xp

def dEdt_njit(tD, tO, i, T):
    """Return twice the difference of two exponentially weighted sums.

    For every element ``t`` of ``tO`` (first sum) and of ``tD`` (second
    sum), with ``s = t + tO[i]``, the term added is::

        t * ((t - tO[i]) - (tO[i] / T) * s) / s**3 * exp(-s / T)

    The result is ``2 * (sum over tO - sum over tD)``.
    """
    acc_o = 0.0
    for j in range(tO.size):
        t_j = tO[j]
        t_i = tO[i]
        pair = t_j + t_i
        numer = t_j * ((t_j - t_i) - (t_i / T) * pair)
        denom = pair ** 3.0
        weight = xp.exp(-pair / T)
        acc_o += (numer / denom) * weight

    acc_d = 0.0
    for j in range(tD.size):
        t_j = tD[j]
        t_i = tO[i]
        pair = t_j + t_i
        numer = t_j * ((t_j - t_i) - (t_i / T) * pair)
        denom = pair ** 3.0
        weight = xp.exp(-pair / T)
        acc_d += (numer / denom) * weight

    return 2 * (acc_o - acc_d)

def DE_dEdt_njit(tD, tO, T):
    """Return an array holding ``-dEdt_njit(tD, tO, i, T)`` for each index i.

    The output has ``tO.size`` entries, one per index into ``tO``.
    """
    count = tO.size
    out = xp.zeros(count)
    for i in xp.arange(count):
        value = dEdt_njit(tD, tO, i, T)
        out[i] = -value
    return out

This is not exactly how I do it, but it is a simplified version that shows the main ideas:

# decorators.py
USE_TIMER = True

from numba import njit
results = {}

def jit_timer():
...

def get_decorator():
    """Return ``jit_timer`` when USE_TIMER is truthy, otherwise ``njit``."""
    return jit_timer if USE_TIMER else njit

# calculations.py
from decorators import get_decorator

@get_decorator()
def foo(...):
....
1 Like

Thanks for the tip! Yes, this seems like a reasonable way to do it. I think you meant @get_decorator()?

yes, you’re right. fixed!

This works great, but prevents jit-caching because of the ctypes. Is there some way to work around that?

NumbaWarning: Cannot cache compiled function “func1” as it uses dynamic globals (such as ctypes pointers and large global arrays)

Hi @nelson2005

This should do the job:

from numba import njit, objmode, extending, types
from numba.core.cgutils import get_or_insert_function
import numpy as np
import ctypes
import time
from llvmlite import ir

# Clock id handed to clock_gettime() as its first argument.
CLOCK_MONOTONIC = 0x1
# ctypes prototype: returns a C int and takes a C int plus a pointer to
# C long.
clock_gettime_proto = ctypes.CFUNCTYPE(ctypes.c_int, ctypes.c_int,
                                       ctypes.POINTER(ctypes.c_long))
# CDLL(None) opens the running process itself for symbol lookup.
pybind = ctypes.CDLL(None)

clock_gettime_addr = pybind.clock_gettime
# NOTE(review): clock_gettime_fn_ptr is not referenced by the functions that
# follow in this listing (they call the intrinsic by symbol name instead);
# confirm whether this ctypes setup is still needed here.
clock_gettime_fn_ptr = clock_gettime_proto(clock_gettime_addr)

@extending.intrinsic
def clock_gettime(typingctx, clockid_t , timespec):
    """Numba intrinsic that emits a direct call to the ``clock_gettime`` symbol.

    The function is referenced by name in the generated LLVM module rather
    than through a ctypes pointer held in a Python global.

    Returns the (signature, codegen) pair: the signature is int64 with the
    two argument types exactly as Numba typed them at the call site.
    """
    def codegen(context, builder, sig, args):
        # LLVM function type: 64-bit integer return, parameters lowered from
        # the Numba argument types.
        # NOTE(review): the ctypes prototype elsewhere in this listing
        # declares the return and first argument as C int; confirm that the
        # 64-bit declaration here is intended.
        fnty = ir.FunctionType(
            ir.IntType(64), 
            (
                context.get_value_type(sig.args[0]), 
                context.get_value_type(sig.args[1]),
            )
        )
        # Declare (or reuse) the external symbol in the module being built
        # and call it with the lowered arguments.
        fn = get_or_insert_function(builder.module, fnty, "clock_gettime")
        return builder.call(fn, args)

    sig = types.int64(clockid_t , timespec)
    return sig, codegen

@njit(cache=True)
def timenow():
    """Return the CLOCK_MONOTONIC reading as a float number of seconds.

    A two-slot int64 buffer is passed to the ``clock_gettime`` intrinsic;
    slot 0 is read back as whole seconds and slot 1 as nanoseconds.
    """
    buf = np.zeros(2, dtype=np.int64)
    clock_gettime(CLOCK_MONOTONIC, buf.ctypes)
    whole_seconds = np.float64(buf[0])
    nanoseconds = np.float64(buf[1])
    return whole_seconds + 1e-9 * nanoseconds

@njit(cache=True)
def pointless_delay(seconds):
    """Busy-wait for roughly ``seconds`` of wall-clock time.

    The loop runs inside an ``objmode`` block because it relies on the
    Python-level ``time.time()``.
    """
    with objmode():
        began = time.time()
        elapsed = 0
        while elapsed < seconds:
            elapsed = time.time() - began

@njit(cache=True)
def do_stuff(n):
    """Time a ``pointless_delay(n)`` call with ``timenow`` and print it."""
    began = timenow()
    pointless_delay(n)
    finished = timenow()
    print("Elapsed", finished - began)

# Exercise the timer with integer and float delays (in seconds); each call
# prints the elapsed time it measured.
do_stuff(1)
do_stuff(2)
do_stuff(3.21)
1 Like

Perfect, thanks! I think you maybe posted something along these lines recently but I didn’t find it after about a half hour of searching. :frowning: