I am implementing a direct cross-correlation in Python using NumPy to compare its performance against an FFT-based cross-correlation. Since it involves nested loops, I decided to try Numba to see if I can get any speedup. My first concern is that the jitted function is slower than the raw implementation; I was expecting a speedup from loop lifting, but that did not seem to work out. Here is the minimal code:
import numba as nb
from numba import float32, int32
import numpy as np

def cross_correlation(image, kernel, stride=1, padding=0):
    # Compute the sizes of image and kernel
    image_size = image.shape[-1]
    kernel_size = kernel.shape[-1]
    output_size = (image_size + 2 * padding - kernel_size) // stride + 1
    image = image[:, None, ...]
    result = np.zeros((image.shape[0], kernel.shape[0], output_size, output_size), dtype=np.float32)
    kernel = kernel[None, ...]
    for i in range(0, output_size, stride):
        for j in range(0, output_size, stride):
            r = image[..., i:i+kernel.shape[-1], j:j+kernel.shape[-1]] * kernel
            #result[:,:,i,j] = r.sum(axis=(-3, -2, -1))
    return result

@nb.njit(float32[:,:,:,:](float32[:,:,:,:], float32[:,:,:,:], int32, int32))
def jitted_cross_correlation(image, kernel, stride=1, padding=0):
    # Compute the sizes of image and kernel
    image_size = image.shape[-1]
    kernel_size = kernel.shape[-1]
    output_size = (image_size + 2 * padding - kernel_size) // stride + 1
    image = image[:, None, ...]
    result = np.zeros((image.shape[0], kernel.shape[0], output_size, output_size), dtype=np.float32)
    kernel = kernel[None, ...]
    for i in range(0, output_size, stride):
        for j in range(0, output_size, stride):
            r = image[..., i:i+kernel.shape[-1], j:j+kernel.shape[-1]] * kernel
            #result[:,:,i,j] = r.sum(axis=(-3, -2, -1))
    return result

if __name__ == "__main__":
    import time

    x = np.random.randn(150, 1, 28, 28).astype(np.float32)
    k = np.random.randn(16, 1, 3, 3).astype(np.float32)

    # warmup
    jitted_cross_correlation(x, k, stride=1, padding=0)

    t = time.time()
    for _ in range(100):
        jitted_cross_correlation(x, k, stride=1, padding=0)
    print("jitted_cross_correlation : ", time.time() - t)  # approx 5.2 s

    t = time.time()
    for _ in range(100):
        cross_correlation(x, k, stride=1, padding=0)
    print("cross_correlation : ", time.time() - t)  # approx 4.7 s
Is this to be expected? My second concern is that the commented-out reduction, r.sum(axis=(-3, -2, -1)), is not supported by Numba: uncommenting it results in a compilation error even with nopython=False ("numba.core.errors.CompilerError: Failed in object mode pipeline (step: remove phis nodes). Illegal IR, del found at: del $120for_iter.1"). Why is that?

