Where to optimize?

Intro

I recently implemented an FFT algorithm in Python so that I can take advantage of Numba. I have provided an end-to-end script below. I tested its performance on time series of lengths 2^2, 2^3, …, 2^27 and compared the results with MATLAB's fft, using the same system and the same inputs. I think the performance is good overall. However, MATLAB's fft performs far better when the time series is short. I am looking for ways to optimize my Python code further and get a faster running time for short time series. In what follows, I share the code and then show its performance. You can see my questions at the bottom of this post.
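
The timing loop on the Python side looks roughly like this (a sketch; fft is the function shared below, and a warm-up call keeps compilation time out of the measurement):

import time
import numpy as np

timings = {}
for e in range(2, 28):              # lengths 2**2 ... 2**27
    T = np.random.rand(2 ** e)
    fft(T)                          # warm-up call: triggers JIT compilation
    start = time.perf_counter()
    fft(T)
    timings[e] = time.perf_counter() - start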

fft implementation

Note: the implemented function fft is intended to be equivalent to scipy.fft.rfft

import math
import numpy as np
from numba import prange, njit


@njit(fastmath=True)  
def _fft0(n, s, eo, x, y): 
    """
    A recursive function that is used as part of fft algorithm
    
    n : int
    s : int
    eo: bool
    x : numpy.array 1D
    y : numpy.array 1D
    """
    if n == 2:
        if eo:
            z = y
        else:
            z = x
        
        for i in range(s):
            j = i + s
            a = x[i]
            b = x[j]
            z[i] = a + b
            z[j] = a - b
            
    elif n >= 4:
        m = n // 2
        sm = s * m
    
        theta = math.pi / m
        c = math.cos(theta) - 1j * math.sin(theta)
        
        twiddle_factor = 1.0
        for p in range(m):
            sp = s * p
            two_sp = 2 * sp
            for q in range(s):
                i = sp + q
                j = i + sm
                
                k = two_sp + q
                y[k] = x[i] + x[j]
                y[k + s] = (x[i] - x[j]) * twiddle_factor
        
            twiddle_factor = twiddle_factor * c
        
        _fft0(m, 2*s, not eo, y, x)
        
    else:
        # n == 1: nothing to do
        pass
    

@njit(fastmath=True, parallel=True)
def _sixstep_fft(logtwo_n, x, y):
    N = 2 ** logtwo_n
    n = 2 ** (logtwo_n // 2)   # i.e., sqrt(N); logtwo_n is even when this is called
    
    # transpose x viewed as an n-by-n matrix, in place
    for k in prange(n):
        for p in range(k + 1, n):
            i = k + p * n
            j = p + k * n
            x[i], x[j] = x[j], x[i]
        
    # length-n FFT on each row
    for p in prange(n):
        start = p * n
        _fft0(n, 1, False, x[start:], y[start:])
    
    # multiply by the twiddle factors and transpose again, fused in place
    theta_init = 2 * math.pi / N
    n_plus_1 = n + 1
    for p in prange(n):
        theta0 = theta_init * p
        ppn = p * n_plus_1
        
        c = math.cos(theta0) - 1j * math.sin(theta0)
        w = math.cos(theta0 * p) - 1j * math.sin(theta0 * p)
        for alpha in range(0, n - p):
            i = ppn + alpha
    
            if alpha == 0:
                x[i] = x[i] * w
            else:
                j = ppn + alpha * n
                x[j], x[i] = x[i] * w, x[j] * w
                
            w = w * c
        
    # second pass of length-n FFTs on each row
    for k in prange(n):
        start = k * n
        _fft0(n, 1, False, x[start:], y[start:])
        
    # final transpose, in place
    for k in prange(n):
        kn = k * n
        for p in range(k + 1, n):
            i = k + p * n
            j = p + kn
            x[i], x[j] = x[j], x[i]
            

@njit(fastmath=True, parallel=True)
def _eightstep_fft(logtwo_n, x, y):
    n = 2 ** logtwo_n
    m = n // 2
    
    theta0 = 2 * math.pi / n
    # radix-2 split with twiddle factors
    for i in prange(m):
        theta = i * theta0
        wp = math.cos(theta) - 1j * math.sin(theta)
        
        j = i + m
        y[i] = x[i] + x[j]
        y[j] = (x[i] - x[j]) * wp

    _sixstep_fft(logtwo_n - 1, y, x)
    _sixstep_fft(logtwo_n - 1, y[m:], x[m:])

    # interleave the two half-length results back into x
    for p in prange(m):
        two_p = 2 * p
        x[two_p] = y[p]
        x[two_p + 1] = y[p + m]
        
    return
        

@njit(fastmath=True)  
def _compute_fft(x, y):
    # assumes len(x) is a power of two, with len(x) >= 2
    n = len(x)
    logtwo_n = int(np.log2(n))
    
    if logtwo_n == 1: 
        _fft0(n, 1, False, x, y)
    elif logtwo_n % 2 == 0:
        _sixstep_fft(logtwo_n, x, y)
    else:
        _eightstep_fft(logtwo_n, x, y)
        
    return


@njit(fastmath=True, parallel=True)  
def _fft(T, y):
    n = len(T)
    half_n = n // 2
    
    # pack the real input into a complex array of half the length
    x = T[::2] + 1j * T[1::2]
    _compute_fft(x, y[:half_n])
    
    # untangle the half-length complex FFT into the one-sided real spectrum
    y[0] = x[0].real + x[0].imag
    y[n // 4] = x[n // 4].conjugate()
    y[half_n] = x[0].real - x[0].imag
    
    theta0 = math.pi / half_n
    for k in prange(1, n // 4):
        c = x[half_n - k].conjugate()
        theta = theta0 * k
        val = 0.5 * (x[k] - c) * (1 + math.sin(theta) + 1j * math.cos(theta))
        y[k] = x[k] - val
        y[half_n - k] = x[half_n - k] + val.conjugate()


def fft(T):
    n = len(T)
    n_rfft = n // 2 + 1
    y = np.empty(n_rfft, dtype=np.complex128)
    
    _fft(T, y)
 
    return y
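
For reference, the function is called like this, together with a quick spot-check against scipy.fft.rfft (a minimal sketch; the input length must be a power of two):

import numpy as np
import scipy.fft

T = np.random.rand(2 ** 9)   # real-valued input, power-of-two length
spec = fft(T)                # one-sided spectrum, length n // 2 + 1

print(np.allclose(spec, scipy.fft.rfft(T)))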

Performance Python vs MATLAB

As observed, the Python code with Numba is much slower than MATLAB for short lengths, roughly 2^3 to 2^5.

In fact, if I plot the performance of each implementation, i.e. Python and MATLAB, individually, I see:

Two questions here:
(1) Why do I see a sawtooth-like plot in Python? The peaks occur when the log-two of the length is an even number. Since the real input is packed into a complex array of half the length, these are exactly the cases that end up calling the function _eightstep_fft.

(2) How can I find the part whose optimization would further improve the performance of the Python fft code?


Hi @NimaSarajpoor

Your code performs poorly for very small arrays because of the overhead introduced by parallelism.

I assume the sawtooth-like pattern you observe comes from the fact that your implementation is broken for cases where “log-two of length” is even. Run, for example, this:

import numpy as np
import matplotlib.pyplot as plt


def fft(T):
    ...  # the implementation from the post above


for i in range(5, 18):
    x = np.linspace(0, 10 * np.pi, num=2**i)
    y = np.sin(x)

    plt.figure()
    plt.title(str(i))
    plt.plot(fft(y).real)                  # this implementation
    plt.plot(np.fft.rfft(y).real, ls="--") # reference
    plt.xlim(0, 12)
    plt.show()

PS: There is support for scipy.fft.rfft in Numba compiled code: Rocket-FFT — a Numba extension supporting numpy.fft and scipy.fft
However, it may be a bit slower for 1D transforms since multi-threading is only supported over multiple dimensions.
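
For reference, basic usage looks like this (a minimal sketch, assuming the package has been installed with pip install rocket-fft, after which Numba picks it up automatically):

import numpy as np
import scipy.fft
from numba import njit


# With rocket-fft installed, scipy.fft.rfft can be called
# directly inside an njit-compiled function.
@njit
def jit_rfft(x):
    return scipy.fft.rfft(x)


print(jit_rfft(np.random.rand(1024))[:3])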

A possible add-on to @sschaer's excellent answer:
If you haven’t “warmed up” your code by calling each of your njit-ted functions before timing them, you might see slowdowns like these. Warming up pulls your initial compilation times out of your profiling totals.
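
For example, something along these lines makes the penalty visible (a minimal sketch; fft is the function from the original post):

import time
import numpy as np

T = np.random.rand(2 ** 12)

start = time.perf_counter()
fft(T)                        # first call includes JIT compilation
cold = time.perf_counter() - start

start = time.perf_counter()
fft(T)                        # warmed up: compiled code only
warm = time.perf_counter() - start

print(cold, warm)             # the first number is typically far larger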

You’re probably already doing this, but calling it out just in case. I’ve been bitten by the just-in-time compiling penalty before.

Brandon


Thank you both @sschaer and @Brohrer for sharing your input! I should have been clearer about what I had already noticed / tried. That was my bad! Newbie Alert :smiley: I am interested in knowing whether there is any way to change the code so that the compiled code becomes optimal. There might be some small details / tricks one can use to give the performance a boost.

@sschaer

Your code performs poorly for very small arrays because of the overhead introduced by parallelism.

Right! So, one way is to avoid parallel computing when the array is short, but I prefer not to try that for now, as there might be another way to optimize the compiled code without increasing its complexity. (What I mean, and am deferring, is a size-based dispatch like the sketch below.)
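
A rough sketch of that idea, where _fft_serial would be a hypothetical copy of _fft compiled with parallel=False, and the cutoff is an arbitrary placeholder rather than a tuned value:

import numpy as np

_SERIAL_CUTOFF = 2 ** 12  # placeholder threshold; would need tuning


def fft(T):
    n = len(T)
    y = np.empty(n // 2 + 1, dtype=np.complex128)
    if n <= _SERIAL_CUTOFF:
        _fft_serial(T, y)  # hypothetical twin of _fft with parallel=False
    else:
        _fft(T, y)
    return y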

I assume the sawtooth-like pattern you observe comes from the fact that your implementation is broken for cases where “log-two of length” is even.

You detected it correctly! I should have mentioned it in my post. I cannot understand “what” is broken there, though.

PS: There is support for scipy.fft.rfft in Numba compiled code: Rocket-FFT — a Numba extension supporting numpy.fft and scipy.fft
However, it may be a bit slower for 1D transforms since multi-threading is only supported over multiple dimensions.

First of all, kudos on the package!! I found it when I was searching for ways to leverage Numba. What I am looking for is to run the FFT of a 1D array in a parallel manner. I can see the parameter workers in scipy.fft.fft; but, as you said, it does not affect the performance of computing the FFT of a 1D array.

@Brohrer
Thanks, Brandon! I did a dummy run and threw it away before timing the code execution. (I should have mentioned that in my post :smiley: My bad!)

Btw, regarding:

I’ve been bitten by the just-in-time compiling penalty before.

You are not alone! It happened to me before too :slight_smile:


I definitely need to revise my question, making the code shorter and the title clearer.
