Create log message on Numba compilation / Find out if given arguments lead to compilation

Hi. I would like to print a log message whenever Numba JIT-compiles a function for a signature it has not seen before. I have looked into the Numba code a bit. I think I would need to create a custom decorator that checks the signature of the input and whether Numba will recompile for it. How can I create a Numba signature from a given input and check whether the CPUDispatcher will recompile for it? Is there an easier way to emit log messages for Numba?

hi! Do you want to know in advance whether a given input will trigger a new compilation, or would it be ok to know after the fact that a new compilation was performed?
If the latter works, Dispatcher objects have a signatures property, and len(signatures) tells you how many signatures the function has been compiled for. You could create a decorator that compares len(signatures) before and after the call.


You could do some hackish monkey patching of the dispatcher's compile method:

In [1]: import numba

In [2]: oldcomp = numba.core.registry.CPUDispatcher.compile

In [3]: def newcomp(*args, **kwargs):
   ...:     print("I am compiling")
   ...:     return oldcomp(*args, **kwargs)
   ...:

In [4]: numba.core.registry.CPUDispatcher.compile = newcomp

In [5]: @numba.njit
   ...: def f(x):
   ...:     return x
   ...:

In [6]: f(1)
I am compiling
Out[6]: 1

In [7]: f(1.0)
I am compiling
Out[7]: 1.0

In [8]: f(1.0)
Out[8]: 1.0

As suggested by @luk-f-a, finding out whether something got compiled is relatively easy, e.g.:

from numba import njit


def logging_jit(func):
    def inner(*args, **kwargs):
        origsigs = set(func.signatures)
        result = func(*args, **kwargs)
        newsigs = set(func.signatures)
        # Signatures present now but not before were compiled during this call
        for new in newsigs - origsigs:
            # PUT YOUR LOGGER HERE!
            print("Numba compiled for signature: {}".format(new))
        return result
    return inner

@logging_jit
@njit
def foo(a):
    return a + 1

print(foo(4))    # first call: compile and run for int
print(foo(5))    # second call: just a run, the int signature is already compiled
print(foo(6.7))  # third call: compile and run for float

Finding out whether something is going to be compiled is trickier, but there are ways of doing it: it involves producing Numba type signatures from the arguments and then comparing them to what is already in the dispatcher's cache.


Great, thanks for all three replies. It is perfectly fine for me to emit the message after compilation. It had not occurred to me to simply check the number of signatures. Thanks for these simple solutions and examples!

Timo