What is the best way to generate an implementation of a function at compile time?

To implement some functions (e.g. einsum), the easiest way may be to generate the code at compile time. As an example, I reimplemented np.sum (axis can only be a literal or a tuple of literals).

  • Are there better ways to generate code at compile time (eg. without using exec)?
  • Why is Numba so slow when using the axis parameter?

Implementation

import numpy as np
import numba as nb
from numba.extending import overload

def gen_sum_impl(arr_dim, c_contiguous, f_contiguous, axis=None):
    """Generate Python source for a specialized ``sum_impl(arr, axis=None)``.

    Parameters
    ----------
    arr_dim : int
        Number of dimensions of the target array.
    c_contiguous, f_contiguous : bool
        Memory layout of the target array; Fortran-ordered arrays are
        transposed to C order inside the generated code.
    axis : None, int or tuple of int, optional
        Axis or axes to sum over. Negative values are accepted and
        normalized (for scalars AND tuples).

    Returns
    -------
    str
        Source code defining ``sum_impl``.

    Raises
    ------
    ValueError
        If any axis is out of range for ``arr_dim``.
    """
    s = "import numpy as np\n"
    s += "def sum_impl(arr,axis=None):\n"

    last_axis_contract = False
    if axis is not None:
        # Normalize to a 1-d array and map negative axes to their
        # positive equivalents. (The original code only normalized
        # scalar axes, so tuples with negative entries were wrong.)
        axis = np.atleast_1d(np.asarray(axis))
        axis = np.where(axis < 0, axis + arr_dim, axis)
        # Validate BEFORE the Fortran-order remapping below, which would
        # turn an out-of-range axis into a negative, undetected one.
        if np.any(axis >= arr_dim) or np.any(axis < 0):
            raise ValueError("Axis is higher than array dimension!")

    # Handle Fortran-ordered arrays: transpose to C order and remap axes.
    if f_contiguous:
        s += "    arr=arr.T\n"
        if axis is not None:
            axis = (arr_dim - 1) - axis
    # Axes that remain in the output (None contracts everything).
    alloc_axis = np.setdiff1d(np.arange(arr_dim), axis)

    # Is the innermost (last) axis part of the contraction?  If so we can
    # accumulate it in a scalar for better performance.
    if axis is not None:
        axis = axis.tolist()
        last_axis_contract = (arr_dim - 1) in axis

    # Preallocate the output.
    if axis is None:
        s += "    res=0\n"
    else:
        s_alloc = "".join(f"arr.shape[{i}]," for i in alloc_axis)
        s += f"    res=np.zeros(({s_alloc}),dtype=arr.dtype)\n"

    # Full reduction on a contiguous array: a single flat loop suffices.
    if axis is None and (c_contiguous or f_contiguous):
        s += "    arr=arr.reshape(-1)\n"
        s += "    for i_0 in range(arr.shape[0]):\n"
        s += "        res+=arr[i_0]\n"

    # General case: nested loops over all dimensions.
    else:
        s_rhs = ",".join(f"i_{i}" for i in range(arr_dim))
        # "()" indexes a 0-d result array when every axis is contracted
        # (the original emitted the invalid "res[] += ..." here).
        s_lhs = ",".join(f"i_{i}" for i in alloc_axis) or "()"

        if axis is None:
            # Non-contiguous full reduction: accumulate into the scalar
            # res directly (the original wrongly indexed the scalar).
            for i in range(arr_dim):
                s += "    " * (i + 1) + f"for i_{i} in range(arr.shape[{i}]):\n"
            s += "    " * (arr_dim + 1) + f"res += arr[{s_rhs}]\n"
        elif last_axis_contract:
            # Special implementation when the last dim is summed:
            # accumulate the innermost loop in a scalar (performance).
            for i in range(arr_dim - 1):
                s += "    " * (i + 1) + f"for i_{i} in range(arr.shape[{i}]):\n"
            s += "    " * arr_dim + "acc=0\n"
            s += "    " * arr_dim + f"for i_{arr_dim - 1} in range(arr.shape[{arr_dim - 1}]):\n"
            s += "    " * (arr_dim + 1) + f"acc += arr[{s_rhs}]\n"
            s += "    " * arr_dim + f"res[{s_lhs}] += acc\n"
        else:
            for i in range(arr_dim):
                s += "    " * (i + 1) + f"for i_{i} in range(arr.shape[{i}]):\n"
            s += "    " * (arr_dim + 1) + f"res[{s_lhs}] += arr[{s_rhs}]\n"

    # Undo the transpose on the result of a Fortran-ordered input.
    if f_contiguous and axis is not None:
        s += "    return res.T"
    else:
        s += "    return res"
    return s

def exec_func(func_str, function_name=None):
    """Execute generated source and return one function defined in it.

    Parameters
    ----------
    func_str : str
        Python source code.  WARNING: it is executed with ``exec`` --
        only ever pass trusted, locally generated code.
    function_name : str, optional
        Name of the function to return.  Defaults to the last name
        defined by the source (dicts preserve insertion order).

    Returns
    -------
    callable
        The requested function object.
    """
    namespace = {}
    exec(func_str, namespace)  # trusted, generated code only
    if function_name is None:
        # Skip the __builtins__ entry that exec injects into the globals.
        defined = [name for name in namespace if name != "__builtins__"]
        function_name = defined[-1]
    return namespace[function_name]

def sum_2(arr, axis=None):
    """Thin wrapper around ``np.sum``; the Numba overload below replaces it in njitted code."""
    return np.sum(arr, axis=axis)

@overload(sum_2, jit_options={'fastmath': True}, strict=True)
def sum(arr, axis=None):
    """
    axis can only be zero, a literal or a tuple of int literals
    """
    # Extract the compile-time axis value(s) from the Numba type objects.
    literal_axis = None
    axis_is_none = axis is None or isinstance(axis, nb.core.types.misc.NoneType)
    if not axis_is_none:
        if isinstance(axis, nb.core.types.containers.Tuple):
            literal_axis = [ax.literal_value for ax in axis]
        else:
            literal_axis = axis.literal_value
    # Build the specialized source for this ndim/layout/axis combination
    # and hand the compiled function back to Numba.
    source = gen_sum_impl(arr.ndim, arr.is_c_contig, arr.is_f_contig,
                          axis=literal_axis)
    return exec_func(source, function_name='sum_impl')

Timings

@nb.njit(fastmath=True)
def nb_sum(arr):
    # Timing baseline: Numba's built-in lowering of np.sum (full reduction).
    return np.sum(arr)

@nb.njit(fastmath=True)
def nb_sum_2(arr):
    # Uses the sum_2 overload defined above; the axis tuple must be a
    # compile-time literal for the code generator.
    # (Fixed: the original was missing the closing parenthesis.)
    return sum_2(arr, axis=(2, 3))

#arr=np.random.rand(80,80,80,80)
#axis=-1
%timeit np.sum(arr,axis=-1)
#21.6 ms ± 301 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit nb_sum(arr)
#367 ms ± 8.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit nb_sum_2(arr)
#16.9 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
np.allclose(np.sum(arr,axis=(-1)),nb_sum_2(arr))
#True

#arr=np.random.rand(80,80,80,80).T
#axis=-1
%timeit np.sum(arr,axis=-1)
#23.5 ms ± 435 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit nb_sum(arr)
#139 ms ± 333 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit nb_sum_2(arr)
#22 ms ± 232 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
np.allclose(np.sum(arr,axis=(-1)),nb_sum_2(arr))
#True

#arr=np.random.rand(80,80,80,80)
#axis=(2,3)
%timeit np.sum(arr,axis=(2,3))
#37.2 ms ± 303 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#%timeit nb_sum(arr)
#not supported
%timeit nb_sum_2(arr)
#15.3 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
np.allclose(np.sum(arr,axis=(2,3)),nb_sum_2(arr))
#True

Although in your case, since you're overloading, you're achieving roughly the same thing. I haven't 100% read through your code, but something I usually do in these situations is write an njitted function for any behavior that is shared between implementations and then return a dynamically generated function (built with Python, usually not strings) for each type.

Also, if you must generate your code by building strings (sometimes this is the only option, but there are usually others), I often find that f-strings (https://realpython.com/python-f-strings/) are much cleaner than concatenating each line.

Another performance tip: if you use np.empty() instead of np.zeros(), you can shave off some time. Just be careful that you don't leave any slots unfilled, or they could take unpredictable values.