Jitted cross-correlation slower than raw Python

I am implementing a direct cross-correlation in Python using NumPy to compare its performance against an FFT-based cross-correlation. Since it involves nested loops, I decided to try Numba and see if I could get any speedup. My first concern is that the jitted function is slower than the raw implementation. I was expecting a speedup from loop lifting, but that did not seem to work out. Here is the minimal code:

import numba as nb
from numba import float32, int32
import numpy as np


def cross_correlation(image, kernel, stride=1, padding=0):
	# Compute the sizes of image and kernel
	image_size = image.shape[-1]
	kernel_size = kernel.shape[-1]

	output_size = (image_size + 2 * padding -kernel_size) // stride + 1
	
	image = image[:,None,...]

	result = np.zeros((image.shape[0], kernel.shape[0], output_size, output_size), dtype=np.float32)
	kernel = kernel[None,...]
	for i in range(0,output_size, stride):
		for j in range(0,output_size, stride):
			r = image[...,i:i+kernel.shape[-1], j:j+kernel.shape[-1]] * kernel
			#result[:,:,i,j] = r.sum(axis=(-3, -2,-1))
	return result


@nb.njit(float32[:,:,:,:](float32[:,:,:,:], float32[:,:,:,:], int32, int32))
def jitted_cross_correlation(image, kernel, stride=1, padding=0):
	# Compute the sizes of image and kernel
	image_size = image.shape[-1]
	kernel_size = kernel.shape[-1]

	output_size = (image_size + 2 * padding -kernel_size) // stride + 1

	image = image[:,None,...]

	result = np.zeros((image.shape[0], kernel.shape[0], output_size, output_size), dtype=np.float32)
	kernel = kernel[None,...]
	for i in range(0,output_size, stride):
		for j in range(0,output_size, stride):
			r = image[...,i:i+kernel.shape[-1], j:j+kernel.shape[-1]] * kernel
			#result[:,:,i,j] = r.sum(axis=(-3, -2,-1))
	return result

if __name__ == "__main__":
	import time
	x = np.random.randn(150,1,28,28).astype(np.float32)
	k = np.random.randn(16,1,3,3).astype(np.float32)

	# warmup
	jitted_cross_correlation(x,k, stride=1, padding=0)

	t = time.time()
	for _ in range(100):
		jitted_cross_correlation(x,k, stride=1, padding=0)
	print("jitted_cross_correlation : ",time.time()-t) # approx 5.2 s

	t = time.time()
	for _ in range(100):
		cross_correlation(x,k, stride=1, padding=0)
	print("cross_correlation : ",time.time()-t) # approx 4.7 s

Is this to be expected? My second concern is that `numpy.sum` is not supported by Numba, resulting in a compilation error even with `nopython=False` (`numba.core.errors.CompilerError: Failed in object mode pipeline (step: remove phis nodes). Illegal IR, del found at: del $120for_iter.1`). Why is that?

@sabeauss I tried the example, but for me the jitted variant is faster?! Or did I miss something?

Note that you can use `py_func` on any `@njit`-decorated function to access the original Python variant. Perhaps that is a useful hint so you don't need to duplicate implementations when benchmarking. As for the sum: the error you are seeing stems from the lack of support for the `axis` keyword argument. Maybe you can implement the sum with simple for-loops and `@njit` that too? It would probably be an adequate workaround for the time being.

import numba as nb
from numba import float32, int32
import numpy as np

@nb.njit(float32[:,:,:,:](float32[:,:,:,:], float32[:,:,:,:], int32, int32))
def cross_correlation(image, kernel, stride=1, padding=0):
	# Compute the sizes of image and kernel
	image_size = image.shape[-1]
	kernel_size = kernel.shape[-1]

	output_size = (image_size + 2 * padding -kernel_size) // stride + 1

	image = image[:,None,...]

	result = np.zeros((image.shape[0], kernel.shape[0], output_size, output_size), dtype=np.float32)
	kernel = kernel[None,...]
	for i in range(0,output_size, stride):
		for j in range(0,output_size, stride):
			r = image[...,i:i+kernel.shape[-1], j:j+kernel.shape[-1]] * kernel
			#result[:,:,i,j] = r.sum(axis=(-3, -2,-1))  sum doesn't yet accept the axis kwarg
	return result
x = np.random.randn(150,1,28,28).astype(np.float32)
k = np.random.randn(16,1,3,3).astype(np.float32)
%timeit cross_correlation(x,k, stride=1, padding=0)
34.3 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit cross_correlation.py_func(x,k, stride=1, padding=0)
79.4 ms ± 129 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Note that by writing `float32[:,:,:,:]` you are losing some optimizations, since you're forcing Numba to assume the arrays may not be contiguous. So omitting the type annotations from the signature (and just making sure the inputs are of the right type) might help.

https://numba.readthedocs.io/en/stable/reference/types.html#arrays

@esc Thank you for taking the time to answer and for the tips. I must be the one missing something here, maybe in how to properly time executions. The following snippet still shows the jitted function to be slower

t1 = timeit.timeit("cross_correlation(x,k, stride=1, padding=0)", globals=globals(), number=100)

t2 = timeit.timeit("cross_correlation.py_func(x,k, stride=1, padding=0)", globals=globals(), number=100)

even with a warmup. What are your thoughts on that?

I tried to run the command you posted but still cannot reproduce this. The jitted variant is about twice as fast as the pure Python variant.

Maybe this is because we have different hardware. I am on a MacBook Pro with an M1 processor.

Ok, well, that's weird. I tried it both in Google Colab and on my computer (AMD Ryzen 9). timeit sometimes gives a slight advantage to the jitted version, sometimes the other way around. As you say, this is likely due to hardware differences, but it is a bit frustrating.

I guess there is no fixing this for now. Thank you for your time !

Yes, it does appear quite mysterious to me, I’m not sure what to make of it and it’s certainly not what I was expecting. I suppose it would be possible to do some detective work using:

https://numba.pydata.org/numba-doc/dev/reference/jit-compilation.html#Dispatcher.inspect_asm

and maybe:

to better understand where the bottleneck lies, but it’ll be an investigation.

Is any other reader of the forums perhaps able to try the above benchmarks also and report back some timings? Maybe a crowdsourced benchmark could give us some insights here?

The code's performance could be limited by memory bandwidth (a.k.a. memory bound), i.e. the CPU cores spend more time waiting for memory requests to be fulfilled than doing the actual computation. That would explain why there's no speedup.

To test whether this is true, make a roofline graph (see Roofline model - Wikipedia) by benchmarking the code's performance across different input sizes. For simplicity, the x-axis can be input size and the y-axis input_size/second (an estimate of FLOPS).

What you can expect is that from small to medium input sizes, the FLOPS increase very quickly. Then they plateau at the "roofline", indicating the code is memory bound.

To shift the roofline further to the right (bigger input sizes), you will need to modify the algorithm to better take advantage of the machine memory hierarchy (e.g. cache locality, prefetch, etc).
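A rough sketch of such a sweep, assuming a simple multiply-and-sum probe rather than the full cross-correlation (the helper name `flops_estimate` is illustrative):

```python
import time
import numpy as np


def flops_estimate(n, repeats=3):
    # Rough throughput probe for an elementwise multiply followed by a sum
    # over n float32 elements. Plotting the returned rate against n should
    # rise for small inputs and flatten once the working set outgrows the
    # caches -- the memory-bound "roofline". (A real roofline plot would
    # also need the machine's peak FLOPS and memory bandwidth.)
    a = np.random.randn(n).astype(np.float32)
    b = np.random.randn(n).astype(np.float32)
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        s = (a * b).sum()  # n multiplies + ~n adds
        best = min(best, time.perf_counter() - t0)
    return 2 * n / best  # approximate operations per second


# Sweep input sizes; the curve of rate vs. n is the roofline probe.
rates = {n: flops_estimate(n) for n in (2**k for k in range(10, 24, 2))}
```

Taking the best of several repeats reduces timer noise, which matters for the small sizes where a single run takes only microseconds.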