Optimization of Numba Implementation of an Alternating Projection Algorithm

Dear numba community members,

For one of my OSS projects - PyFixest - I have implemented an alternating projection algorithm via numba. Unfortunately, I cannot seem to beat the (fully vectorized) numpy reference implementation in the PyHDFE package. I’d be very grateful for any hints on how to further optimize my code.

First, some context: PyFixest implements linear regression estimators for models with high-dimensional fixed effects. To circumvent the problem of having to invert a very large design matrix of dimension N x k, it instead demeans both the design matrix and the dependent variable prior to fitting the model on the demeaned data via ordinary least squares (based on the Frisch-Waugh-Lovell theorem). The demeaning is done via an “alternating projection algorithm”, which I have attached further below.
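To make the FWL logic concrete, here is a tiny numpy sanity check (an illustration only, separate from the package code): the slope on x from a regression that includes group dummies equals the slope from regressing group-demeaned y on group-demeaned x.

import numpy as np

rng = np.random.default_rng(0)
n, G = 1_000, 10
g = rng.integers(0, G, n)                # one fixed effect with G levels
x = rng.normal(size=n) + g               # regressor correlated with the groups
y = 2.0 * x + g + rng.normal(size=n)     # true slope on x is 2.0

# (1) slope on x from the full regression of y on [x, group dummies]
D = (g[:, None] == np.arange(G)).astype(float)
beta_full = np.linalg.lstsq(np.column_stack((x, D)), y, rcond=None)[0][0]

# (2) the same slope from demeaned data: subtract group means from x and y
counts = np.bincount(g, minlength=G)
x_dm = x - (np.bincount(g, weights=x, minlength=G) / counts)[g]
y_dm = y - (np.bincount(g, weights=y, minlength=G) / counts)[g]
beta_dm = (x_dm @ y_dm) / (x_dm @ x_dm)

print(np.isclose(beta_full, beta_dm))    # True

Now to my implementation: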

import numpy as np
from numba import njit, prange

@njit(parallel=True, cache=False, fastmath=False)
def demean(cx, flist, weights, tol=1e-10, maxiter=2000):
    '''
    Demean a matrix cx by the fixed effects in flist.
    Observations are weighted by weights. Convergence is measured
    as the sum of absolute differences between successive iterations.
    Args:
        cx: Matrix to be demeaned
        flist: Matrix of fixed effects
        weights: Weights for the observations
        tol: Convergence tolerance. 1e-10 by default.
        maxiter: Maximum number of iterations. 2000 by default.
    Returns:
        res: Demeaned matrix of dimension cx.shape
    '''
    N = cx.shape[0]
    fixef_vars = flist.shape[1]
    K = cx.shape[1]

    res = np.zeros((N,K))

    # loop over all variables to demean, in parallel
    for k in prange(K):

        cxk = cx[:, k]  # .copy()
        oldxk = cxk - 1

        converged = False
        for _ in range(maxiter):
           
            for i in range(fixef_vars):
                fmat = flist[:,i]
                weighted_ave = _ave3(cxk, fmat, weights)
                cxk -= weighted_ave

            if np.sum(np.abs(cxk - oldxk)) < tol:
                converged = True
                break

            # update
            oldxk = cxk.copy()

        res[:,k] = cxk

    return res

@njit
def _ave3(x, f, w):

    N = len(x)

    wx_dict = {}
    w_dict = {}

    # Compute weighted sums using a dictionary
    # (prange runs serially here, since _ave3 is compiled without parallel=True)
    for i in prange(N):
        j = f[i]
        if j in wx_dict:
            wx_dict[j] += w[i] * x[i]
        else:
            wx_dict[j] = w[i] * x[i]

        if j in w_dict:
            w_dict[j] += w[i]
        else:
            w_dict[j] = w[i]

    # Convert the dictionary sums to observation-length arrays
    # (wsum avoids rebinding the weights argument w)
    wx = np.zeros_like(f, dtype=x.dtype)
    wsum = np.zeros_like(f, dtype=w.dtype)

    for i in range(N):
        j = f[i]
        wx[i] = wx_dict[j]
        wsum[i] = w_dict[j]

    # Compute the weighted average
    wxw_long = wx / wsum

    return wxw_long

To conclude, here is a benchmark against PyHDFE:

%load_ext autoreload
%autoreload 2

import numpy as np
import time

np.random.seed(1238)
N = 1_000_000
x = np.random.normal(0, 1, 4*N).reshape((N,4))
f1 = np.random.choice(list(range(1000)), N).reshape((N,1))
f2 = np.random.choice(list(range(1000)), N).reshape((N,1))

flist = np.concatenate((f1, f2), axis = 1)
weights = np.ones(N)

import pyhdfe
start_time = time.time()
algorithm = pyhdfe.create(flist)
res_pyhdfe = algorithm.residualize(x)
end_time = time.time()
print(end_time - start_time)
# 2.168398380279541


start_time = time.time()
algorithm = pyhdfe.create(flist)
res_pyfixest = demean(x, flist, weights, tol = 1e-08)
# Calculate the execution time
end_time = time.time()
print(end_time - start_time)
# 3.359426736831665

np.allclose(res_pyhdfe, res_pyfixest)
# True

Any tips on how I could further optimize my implementation?

Best, Alex

Hey @s3alfisc,
you are overwriting your input matrix cx: after executing demean, cx equals res. Is that a feature or a bug?
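If it is intended, a caller can still keep the original matrix by passing a copy; a minimal sketch, using the variables from your benchmark:

# demean a copy so the caller's x survives
res = demean(x.copy(), flist, weights, tol=1e-08)
# x is unchanged here; without the .copy(), x would equal res afterwards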
I could not find much to improve.
If you remove some redundant lines of code, you gain about 20% in execution time.
I hope you will find more gains. Here is the adjusted version:

from numba import njit, prange, types
from numba.typed import Dict

@njit(parallel=True, cache=False, fastmath=False)
def demean(cx, flist, weights, tol=1e-10, maxiter=2000):
    N, K = cx.shape
    fixef_vars = flist.shape[1]
    # no return value: the demeaned columns are written back into cx in place
    for k in prange(K):
        cxk = cx[:, k]
        oldxk = cxk - 1
        for _ in range(maxiter):
            for i in range(fixef_vars):
                cxk -= _ave3(cxk, flist[:, i], weights)
            if np.sum(np.abs(cxk - oldxk)) < tol:
                break
            oldxk[:] = cxk

@njit(parallel=False, cache=False, fastmath=False)
def _ave3(x, f, w):
    N = len(x)
    w_dict = Dict.empty(key_type=types.int64, value_type=types.float64)
    wx_dict = Dict.empty(key_type=types.int64, value_type=types.float64)
    for i in prange(N):
        j = f[i]
        if j in wx_dict:
            wx_dict[j] += w[i] * x[i]
            w_dict[j] += w[i]
        else:
            wx_dict[j] = w[i] * x[i]
            w_dict[j] = w[i]
    res = np.empty_like(x)
    for i in prange(N):
        j = f[i]
        res[i] = wx_dict[j] / w_dict[j]
    return res

# original: 6.06292986869812
# adjusted: 4.800865411758423
# reduction: -21.0%
# Are results close? True

Hi @Oyibo, thanks so much for replying! As it happens, @sschaer reached out to me and even contributed to my package. His implementation led to around 4-5x speed gains =) Unfortunately I cannot post a link here, but you can find his implementation in github / s3alfisc / pyfixest / demean.py in case you are curious =)
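Not his exact code, but one general trick in this direction is to replace the typed dict in _ave3 with plain array indexing over integer group codes, so the hot loop does no hashing at all. A minimal sketch of that idea (my illustration only, not the code in the repo):

import numpy as np
from numba import njit

@njit
def _ave_arr(x, f, w, n_groups):
    # f must hold integer group codes in [0, n_groups),
    # e.g. from np.unique(..., return_inverse=True)
    wx_sum = np.zeros(n_groups, dtype=np.float64)  # sum of w * x per group
    w_sum = np.zeros(n_groups, dtype=np.float64)   # sum of w per group
    for i in range(len(x)):
        j = f[i]
        wx_sum[j] += w[i] * x[i]
        w_sum[j] += w[i]
    # broadcast each group's weighted mean back to observation length
    out = np.empty(len(x), dtype=np.float64)
    for i in range(len(x)):
        out[i] = wx_sum[f[i]] / w_sum[f[i]]
    return out

Since the codes index directly into dense arrays, this avoids the dictionary overhead entirely.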

Hey @s3alfisc ,
that’s awesome. @sschaer is a monster :wink:
Congratulations!!!