Intro
I recently implemented an FFT algorithm in Python so that I can take advantage of Numba. I have provided an end-to-end script below. I tested its performance on time series of length 2^2, 2^3, …, 2^27 and compared it with MATLAB's fft, using the same system and the same input. I think the performance is good. However, MATLAB's fft performs much better when the time series is short. I am looking for ways to optimize my Python code further and get a faster running time for short time series. In what follows, I share the code and show its performance. My questions are at the bottom of this post.
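For a length sweep like this, a timing harness along the following lines could be used. It is only a sketch under assumptions: best_time and benchmark are hypothetical helper names, np.fft.rfft stands in for the function under test, and the untimed first call is there so that, for a Numba function, JIT compilation is not counted. Taking the best of many repeats is meant to reduce noise at short lengths, where a single call takes only microseconds.

```python
import time

import numpy as np


def best_time(func, data, repeats=50):
    """Best wall-clock time (seconds) of func(data) over `repeats` calls."""
    func(data)  # untimed warm-up; for an @njit function this triggers compilation
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        func(data)
        best = min(best, time.perf_counter() - start)
    return best


def benchmark(func, min_log2=2, max_log2=12, seed=0):
    """Time func on random real series of length 2**p, p = min_log2 .. max_log2."""
    rng = np.random.default_rng(seed)
    results = {}
    for p in range(min_log2, max_log2 + 1):
        data = rng.standard_normal(2 ** p)
        # many repeats for short inputs, fewer for long ones
        repeats = max(3, 2 ** 16 // 2 ** p)
        results[p] = best_time(func, data, repeats)
    return results


if __name__ == "__main__":
    for p, t in benchmark(np.fft.rfft).items():
        print(f"2^{p:<2d} {t * 1e6:10.2f} us")
```

To time the Numba version, pass fft in place of np.fft.rfft.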
fft implementation
Note: the implemented function fft is equivalent to scipy.fft.rfft for a real input whose length is a power of two (at least 4).
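To see how a complex FFT yields the rfft result, the packing and post-processing done in _fft can be reproduced in plain NumPy. The sketch below is an illustration, not the posted implementation: rfft_via_half_complex is a hypothetical name, np.fft.fft replaces the Numba transform of length n/2, and the loop over k is vectorized. With the general formula X[k] = Z[k] - 0.5 (Z[k] - conj(Z[n/2 - k])) (1 + sin θ + i cos θ), θ = 2πk/n, the cases k = 0 and k = n/4 need no special handling, because the formula reduces to Re Z[0] + Im Z[0] and conj(Z[n/4]) there.

```python
import numpy as np


def rfft_via_half_complex(t):
    """rfft of a real array of even length n, via one complex FFT of length n/2."""
    t = np.asarray(t, dtype=np.float64)
    half_n = t.size // 2
    # pack even samples into the real part and odd samples into the imaginary part
    z = np.fft.fft(t[0::2] + 1j * t[1::2])
    k = np.arange(half_n)
    c = np.conj(z[(half_n - k) % half_n])  # conj(Z[n/2 - k]); index n/2 wraps to 0
    theta = np.pi * k / half_n             # 2*pi*k/n
    val = 0.5 * (z - c) * (1 + np.sin(theta) + 1j * np.cos(theta))
    y = np.empty(half_n + 1, dtype=np.complex128)
    y[:half_n] = z - val                   # X[k] for k = 0 .. n/2 - 1
    y[half_n] = z[0].real - z[0].imag      # X[n/2], the Nyquist term
    return y
```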
```python
import math

import numpy as np
from numba import njit, prange


@njit(fastmath=True)
def _fft0(n, s, eo, x, y):
    """
    A recursive function that is used as part of fft algorithm
    n : int, length of each sub-transform (a power of two)
    s : int, stride between elements of a sub-transform
    eo: bool, the result ends up in x if False and in y if True
    x : numpy.array 1D, input sequence
    y : numpy.array 1D, work area
    """
    if n == 2:
        if eo:
            z = y
        else:
            z = x
        for i in range(s):
            j = i + s
            a = x[i]
            b = x[j]
            z[i] = a + b
            z[j] = a - b
    elif n >= 4:
        m = n // 2
        sm = s * m
        theta = math.pi / m
        c = math.cos(theta) - 1j * math.sin(theta)  # exp(-2j*pi/n)
        twiddle_factor = 1.0
        for p in range(m):
            sp = s * p
            two_sp = 2 * sp
            for q in range(s):
                i = sp + q
                j = i + sm
                k = two_sp + q
                y[k] = x[i] + x[j]
                y[k + s] = (x[i] - x[j]) * twiddle_factor
            twiddle_factor = twiddle_factor * c
        _fft0(m, 2*s, not eo, y, x)
    else:
        pass


@njit(fastmath=True, parallel=True)
def _sixstep_fft(logtwo_n, x, y):
    # six-step FFT of length N = n * n (logtwo_n must be even); result ends up in x
    N = 2 ** logtwo_n
    n = 2 ** int(logtwo_n // 2)  # basically, np.sqrt(N)
    # step 1: transpose x, viewed as an n-by-n matrix
    for k in prange(n):
        for p in range(k + 1, n):
            i = k + p * n
            j = p + k * n
            x[i], x[j] = x[j], x[i]
    # step 2: length-n FFT of every row
    for p in prange(n):
        start = p * n
        _fft0(n, 1, False, x[start:], y[start:])
    # steps 3 and 4: multiply by the twiddle factors and transpose
    theta_init = 2 * math.pi / N
    n_plus_1 = n + 1
    for p in prange(n):
        theta0 = theta_init * p
        ppn = p * n_plus_1
        c = math.cos(theta0) - 1j * math.sin(theta0)
        w = math.cos(theta0 * p) - 1j * math.sin(theta0 * p)
        for alpha in range(0, n - p):
            i = ppn + alpha
            if alpha == 0:
                x[i] = x[i] * w
            else:
                j = ppn + alpha * n
                x[j], x[i] = x[i] * w, x[j] * w
            w = w * c
    # step 5: length-n FFT of every row
    for k in prange(n):
        start = k * n
        _fft0(n, 1, False, x[start:], y[start:])
    # step 6: transpose back
    for k in prange(n):
        kn = k * n
        for p in range(k + 1, n):
            i = k + p * n
            j = p + kn
            x[i], x[j] = x[j], x[i]


@njit(fastmath=True, parallel=True)
def _eightstep_fft(logtwo_n, x, y):
    # one radix-2 butterfly stage, then two six-step FFTs of length n/2; result ends up in x
    n = 2 ** logtwo_n
    m = int(n // 2)
    theta0 = 2 * math.pi / n
    for i in prange(m):
        theta = i * theta0
        wp = math.cos(theta) - 1j * math.sin(theta)
        j = i + m
        y[i] = x[i] + x[j]
        y[j] = (x[i] - x[j]) * wp
    _sixstep_fft(logtwo_n - 1, y, x)
    _sixstep_fft(logtwo_n - 1, y[m:], x[m:])
    # interleave: even outputs come from y[:m], odd outputs from y[m:]
    for p in prange(m):
        two_p = 2 * p
        x[two_p] = y[p]
        x[two_p + 1] = y[p + m]
    return


@njit(fastmath=True)
def _compute_fft(x, y):
    # only when len(x) is a power of two, and it is >= 2.
    n = len(x)
    logtwo_n = int(np.log2(n))
    if logtwo_n == 1:
        _fft0(n, 1, False, x, y)
    elif logtwo_n % 2 == 0:
        _sixstep_fft(logtwo_n, x, y)
    else:
        _eightstep_fft(logtwo_n, x, y)
    return


@njit(fastmath=True, parallel=True)
def _fft(T, y):
    n = len(T)
    half_n = int(n // 2)
    # pack the real input into a complex array of half the length
    x = T[::2] + 1j * T[1::2]
    _compute_fft(x, y[:half_n])  # y[:half_n] serves only as a work area here
    # post-process: recover the rfft of T from the complex FFT stored in x
    y[0] = x[0].real + x[0].imag
    y[n // 4] = x[n // 4].conjugate()
    y[half_n] = x[0].real - x[0].imag
    theta0 = math.pi / half_n
    for k in prange(1, n // 4):
        c = x[half_n - k].conjugate()
        theta = theta0 * k
        val = 0.5 * (x[k] - c) * (1 + math.sin(theta) + 1j * math.cos(theta))
        y[k] = x[k] - val
        y[half_n - k] = x[half_n - k] + val.conjugate()


def fft(T):
    # T: real 1D array whose length is a power of two (>= 4)
    n = len(T)
    n_rfft = n // 2 + 1
    y = np.empty(n_rfft, dtype=np.complex128)  # np.complex_ was removed in NumPy 2.0
    _fft(T, y)
    return y
```
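The data movement in _sixstep_fft may be easier to reason about on an explicit n-by-n matrix. The following NumPy sketch follows the same six stages, with np.fft.fft standing in for the row transforms done by _fft0. sixstep_fft is a hypothetical name, and unlike the Numba code it keeps steps 3 and 4 (twiddle multiply and transpose) separate and allocates new arrays instead of working in place.

```python
import math

import numpy as np


def sixstep_fft(x):
    """Length-N FFT (N = n * n) computed with the six-step method on an n-by-n matrix."""
    x = np.asarray(x, dtype=np.complex128)
    N = x.size
    n = math.isqrt(N)
    if n * n != N:
        raise ValueError("length must be a perfect square, e.g. an even power of two")
    a = x.reshape(n, n).T.copy()  # step 1: transpose; row p now holds x[p], x[p + n], ...
    a = np.fft.fft(a, axis=1)     # step 2: length-n FFT of every row
    p = np.arange(n)
    a = a * np.exp(-2j * np.pi * np.outer(p, p) / N)  # step 3: twiddle exp(-2*pi*i*p*k/N)
    a = a.T                       # step 4: transpose
    a = np.fft.fft(a, axis=1)     # step 5: length-n FFT of every row
    return a.T.reshape(N)         # step 6: transpose back; output index k1 + n*k2
```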
Performance Python vs MATLAB
As observed, the Python code with Numba is much, much slower than MATLAB for lengths 2^3 to 2^5.
In fact, if I plot the performance of each implementation (Python and MATLAB) individually, I see:
Two questions here:
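(1) Why do I see a sawtooth-like plot in Python? The peaks occur when the base-2 logarithm of the length is even. These are the cases that end up calling the function _eightstep_fft, because the complex FFT computed inside fft has half the input length, so its base-2 logarithm is odd.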
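For reference, the decomposition performed by _eightstep_fft can be written in NumPy as below (eightstep_fft is a hypothetical name, and np.fft.fft replaces the two _sixstep_fft calls). It makes the extra work on this path explicit: compared with a single six-step transform, it adds one butterfly pass and one interleaving copy over the whole array, and the Numba code runs the two half-length transforms one after the other. This sketch only shows the structure; it does not establish that these passes cause the peaks.

```python
import numpy as np


def eightstep_fft(x):
    """Length-N FFT (N even): one radix-2 decimation-in-frequency stage, then two N/2 FFTs."""
    x = np.asarray(x, dtype=np.complex128)
    N = x.size
    m = N // 2
    w = np.exp(-2j * np.pi * np.arange(m) / N)  # twiddle factors exp(-2*pi*i*j/N)
    top = x[:m] + x[m:]              # its FFT gives the even-indexed outputs X[2p]
    bottom = (x[:m] - x[m:]) * w     # its FFT gives the odd-indexed outputs X[2p + 1]
    out = np.empty(N, dtype=np.complex128)
    out[0::2] = np.fft.fft(top)      # interleave the two half-length results
    out[1::2] = np.fft.fft(bottom)
    return out
```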
(2) How can I find the parts of the Python fft code whose optimization would further improve its performance?
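One low-tech way to locate hot spots is to time each stage on its own, since a Python-level profiler such as cProfile would likely report a compiled Numba function as a single call. The sketch below shows the idea on NumPy stand-ins for the six stages; time_stages and sixstep_stages are hypothetical helpers, and applying this to the Numba code would assume each stage has been split out into its own @njit function.

```python
import time

import numpy as np


def time_stages(stages, data, repeats=20):
    """Best time (seconds) of each stage, each measured in isolation.

    stages: list of (name, func) pairs; each func takes the previous stage's output.
    """
    timings = {}
    current = data
    for name, func in stages:
        func(current.copy())  # untimed warm-up; for an @njit function this would compile it
        best = float("inf")
        for _ in range(repeats):
            arg = current.copy()  # fresh input, so in-place stages always see the same data
            start = time.perf_counter()
            func(arg)
            best = min(best, time.perf_counter() - start)
        timings[name] = best
        current = func(current.copy())  # this stage's output feeds the next stage
    return timings


def sixstep_stages(n):
    """NumPy stand-ins for the six stages of a length n*n FFT on an n-by-n array."""
    p = np.arange(n)
    twiddle = np.exp(-2j * np.pi * np.outer(p, p) / (n * n))
    return [
        ("transpose 1", lambda a: a.T.copy()),
        ("row FFTs 1", lambda a: np.fft.fft(a, axis=1)),
        ("twiddle", lambda a: a * twiddle),
        ("transpose 2", lambda a: a.T.copy()),
        ("row FFTs 2", lambda a: np.fft.fft(a, axis=1)),
        ("transpose 3", lambda a: a.T.copy()),
    ]


if __name__ == "__main__":
    n = 256
    rng = np.random.default_rng(0)
    x = rng.standard_normal(n * n) + 1j * rng.standard_normal(n * n)
    for name, t in time_stages(sixstep_stages(n), x.reshape(n, n)).items():
        print(f"{name:12s} {t * 1e6:9.1f} us")
```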