To implement some functions (e.g. einsum), the easiest way may be to generate the code at compile time. For this example I reimplemented np.sum (axis can only be literals or a tuple of literals).
- Are there better ways to generate code at compile time (eg. without using exec)?
- Why is Numba so slow when using the axis parameter?
Implementation
import numpy as np
import numba as nb
from numba.extending import overload
def gen_sum_impl(arr_dim, c_contiguous, f_contiguous, axis=None):
    """Generate the source of a specialized ``sum_impl(arr, axis=None)``.

    The returned string defines a function that sums an ``arr_dim``-dimensional
    array over the given (compile-time constant) axes using explicit nested
    loops, suitable for ``exec`` and subsequent Numba compilation.

    Parameters
    ----------
    arr_dim : int
        Number of dimensions of the target array.
    c_contiguous, f_contiguous : bool
        Memory layout of the target array.  F-ordered arrays are handled by
        transposing to a C-ordered view and mirroring the axis indices.
    axis : int, tuple of int, or None, optional
        Axes to reduce over; negative values are allowed.  ``None`` sums
        over all elements.

    Returns
    -------
    str
        Python source defining ``sum_impl``.

    Raises
    ------
    ValueError
        If any axis is out of bounds for an ``arr_dim``-dimensional array.
    """
    s = "import numpy as np\n"
    s += "def sum_impl(arr,axis=None):\n"
    last_axis_contract = False
    if axis is not None:
        # Normalize scalar or tuple axis to a 1-d index array and map
        # negative axes to their positive equivalents.
        axis = np.atleast_1d(np.asarray(axis))
        axis[axis < 0] += arr_dim
        # Validate BEFORE any layout remapping, so the check refers to the
        # caller's view of the array (the original code only printed a
        # warning, and only after the Fortran remap below).
        if np.any(axis < 0) or np.any(axis >= arr_dim):
            raise ValueError("Axis is higher than array dimension!")
    # Fortran-ordered input: operate on the transposed (C-ordered) view and
    # mirror the axis indices accordingly.
    if f_contiguous:
        s += " arr=arr.T\n"
        if axis is not None:
            axis = (arr_dim - 1) - axis
    # Axes that survive the reduction; they define the output shape.
    if axis is None:
        alloc_axis = np.arange(arr_dim)
    else:
        alloc_axis = np.setdiff1d(np.arange(arr_dim), axis)
        # Summing over the last (fastest-varying) axis gets a dedicated
        # scalar-accumulator kernel for performance.
        last_axis_contract = bool(np.any(axis == arr_dim - 1))
    # Preallocate the output.
    if axis is None:
        s += " res=0\n"
    else:
        s_alloc = "".join(f"arr.shape[{i}]," for i in alloc_axis)
        s += f" res=np.zeros(({s_alloc}),dtype=arr.dtype)\n"
    if axis is None and (c_contiguous or f_contiguous):
        # Full reduction over contiguous memory: a single flat loop avoids
        # nested-loop overhead entirely.
        s += " arr=arr.reshape(-1)\n"
        s += " for i_0 in range(arr.shape[0]):\n"
        s += "  res+=arr[i_0]\n"
    else:
        s_lhs = "".join(f"i_{i}," for i in alloc_axis)[:-1]
        s_rhs = "".join(f"i_{i}," for i in range(arr_dim))[:-1]
        if last_axis_contract:
            # Accumulate the innermost axis into a scalar first, then add
            # the partial sum to the output element once.
            for i in range(arr_dim - 1):
                s += " " * (i + 1) + f"for i_{i} in range(arr.shape[{i}]):\n"
            s += " " * arr_dim + "acc=0\n"
            s += " " * arr_dim + f"for i_{arr_dim - 1} in range(arr.shape[{arr_dim - 1}]):\n"
            s += " " * (arr_dim + 1) + f"acc += arr[{s_rhs}]\n"
            s += " " * arr_dim + f"res[{s_lhs}] += acc\n"
        else:
            for i in range(arr_dim):
                s += " " * (i + 1) + f"for i_{i} in range(arr.shape[{i}]):\n"
            s += " " * (arr_dim + 1) + f"res[{s_lhs}] += arr[{s_rhs}]\n"
    # Undo the transpose on the (non-scalar) result for F-ordered input.
    if f_contiguous and axis is not None:
        s += " return res.T"
    else:
        s += " return res"
    return s
def exec_func(func_str, function_name=None):
    """Execute generated source and return one function defined in it.

    Parameters
    ----------
    func_str : str
        Python source code (e.g. produced by ``gen_sum_impl``).
    function_name : str, optional
        Name of the function to extract.  Defaults to the last top-level
        name bound by ``func_str``.

    Returns
    -------
    callable
        The requested function object.
    """
    namespace = {}
    # NOTE: exec is only safe here because the source is generated locally;
    # never feed this untrusted input.
    exec(func_str, namespace)
    if function_name is None:
        # exec inserts '__builtins__' into the namespace before running the
        # source, so the last key is the last binding made by func_str.
        function_name = list(namespace)[-1]
    return namespace[function_name]
def sum_2(arr, axis=None):
    """Pure-Python reference path; a jitted version is attached via @overload."""
    return np.sum(arr, axis)
@overload(sum_2,jit_options={'fastmath':True},strict=True)
def sum(arr, axis=None):
    """
    axis can only be zero, a literal or a tuple of int literals
    """
    # Pull the compile-time literal axis value(s) out of the Numba type
    # objects so the source can be specialized for exactly this call site.
    axis_TMP=None
    if axis is not None and not isinstance(axis,nb.core.types.misc.NoneType):
        if isinstance(axis,nb.core.types.containers.Tuple):
            # Tuple of literal types -> list of plain ints.
            axis_TMP=[ax.literal_value for ax in axis]
        else:
            # Single literal type -> plain int.
            axis_TMP=axis.literal_value
    # Generate source specialized for this ndim/layout/axis combination and
    # exec it; Numba compiles the returned implementation.
    func_str=gen_sum_impl(arr.ndim,arr.is_c_contig,arr.is_f_contig,axis=axis_TMP)
    func=exec_func(func_str,function_name='sum_impl')
    return func
Timings
@nb.njit(fastmath=True)
def nb_sum(arr):
    # Baseline for timing: Numba's built-in np.sum over the whole array.
    return np.sum(arr)
@nb.njit(fastmath=True)
def nb_sum_2(arr):
    # Uses the @overload-ed sum_2, specialized at compile time for axis=(2, 3).
    # Fix: the original line was missing the closing parenthesis, which is a
    # SyntaxError.
    return sum_2(arr, axis=(2, 3))
#arr=np.random.rand(80,80,80,80)
#axis=-1
%timeit np.sum(arr,axis=-1)
#21.6 ms ± 301 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit nb_sum(arr)
#367 ms ± 8.55 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit nb_sum_2(arr)
#16.9 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
np.allclose(np.sum(arr,axis=(-1)),nb_sum_2(arr))
#True
#arr=np.random.rand(80,80,80,80).T
#axis=-1
%timeit np.sum(arr,axis=-1)
#23.5 ms ± 435 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit nb_sum(arr)
#139 ms ± 333 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit nb_sum_2(arr)
#22 ms ± 232 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
np.allclose(np.sum(arr,axis=(-1)),nb_sum_2(arr))
#True
#arr=np.random.rand(80,80,80,80)
#axis=(2,3)
%timeit np.sum(arr,axis=(2,3))
#37.2 ms ± 303 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#%timeit nb_sum(arr)
#not supported
%timeit nb_sum_2(arr)
#15.3 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
np.allclose(np.sum(arr,axis=(2,3)),nb_sum_2(arr))
#True