# Optimization of Numba Implementation of an Alternating Projection Algorithm

Dear `numba` community members,

For one of my OSS projects, `PyFixest`, I have implemented an alternating projection algorithm via `numba`. Unfortunately, I cannot seem to beat the (fully vectorized) NumPy reference implementation in the `PyHDFE` package. I’d be very grateful for any hints on how to further optimize my code.

First, some context: `PyFixest` implements linear regression estimators for models with high-dimensional fixed effects. To circumvent the problem of having to invert a very large design matrix of dimension N x k, it instead demeans both the design matrix and the dependent variable, then fits the model on the demeaned data via ordinary least squares (based on the Frisch-Waugh-Lovell theorem). The demeaning occurs via an “alternating projection algorithm”, which I have attached below:

``````
import numpy as np
from numba import njit, prange

@njit(parallel = True, cache = False, fastmath = False)
def demean(cx, flist, weights, tol = 1e-10, maxiter = 2000):

    '''
    Demean a matrix cx by fixed effects in flist.
    The fixed effects are weighted by weights. Convergence is checked
    on the sum of absolute differences between iterations.
    Args:
        cx: Matrix to be demeaned
        flist: Matrix of fixed effects
        weights: Weights for fixed effects
        tol: Convergence tolerance. 1e-10 by default.
        maxiter: Maximum number of iterations. 2000 by default.
    Returns
        res: Demeaned matrix of dimension cx.shape
    '''
    N = cx.shape[0]
    fixef_vars = flist.shape[1]
    K = cx.shape[1]

    res = np.zeros((N,K))

    # loop over all variables to demean, in parallel
    for k in prange(K):

        cxk = cx[:,k]#.copy()
        oldxk = cxk - 1

        converged = False
        for _ in range(maxiter):

            # sweep over the fixed effects, removing each weighted group mean
            for i in range(fixef_vars):
                fmat = flist[:,i]
                weighted_ave = _ave3(cxk, fmat, weights)
                cxk -= weighted_ave

            if np.sum(np.abs(cxk - oldxk)) < tol:
                converged = True
                break

            # update
            oldxk = cxk.copy()

        res[:,k] = cxk

    return res

@njit
def _ave3(x, f, w):

    N = len(x)

    wx_dict = {}
    w_dict = {}

    # Compute weighted sums using a dictionary
    for i in prange(N):
        j = f[i]
        if j in wx_dict:
            wx_dict[j] += w[i] * x[i]
        else:
            wx_dict[j] = w[i] * x[i]

        if j in w_dict:
            w_dict[j] += w[i]
        else:
            w_dict[j] = w[i]

    # Convert the dictionaries to arrays
    wx = np.zeros_like(f, dtype=x.dtype)
    w = np.zeros_like(f, dtype=w.dtype)

    for i in range(N):
        j = f[i]
        wx[i] = wx_dict[j]
        w[i] = w_dict[j]

    # Compute the average
    wxw_long = wx / w

    return wxw_long

``````
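For readers who want to see what the weighted-average step computes without the dictionaries, here is a minimal NumPy sketch of the same alternating projection for a single column. It assumes the fixed-effect ids are non-negative integers so that `np.bincount` applies; `group_means` and `demean_numpy` are illustrative names, not part of `PyFixest` or `PyHDFE`.

```python
import numpy as np


def group_means(x, f, w):
    """Weighted mean of x within each group of f, broadcast back to length N.

    Assumption of this sketch: f holds non-negative integer group ids.
    """
    wx_sum = np.bincount(f, weights=w * x)  # per-group sum of w * x
    w_sum = np.bincount(f, weights=w)       # per-group sum of w
    # Ids that never occur have zero weight; leave their mean at 0.
    means = np.divide(wx_sum, w_sum, out=np.zeros_like(wx_sum), where=w_sum != 0)
    return means[f]                         # look up each row's group mean


def demean_numpy(x, flist, w, tol=1e-8, maxiter=2000):
    """Alternating projections for one column x; the input is not modified."""
    x = x.astype(np.float64, copy=True)
    for _ in range(maxiter):
        old = x.copy()
        for j in range(flist.shape[1]):
            x -= group_means(x, flist[:, j], w)
        if np.sum(np.abs(x - old)) < tol:
            break
    return x
```

After convergence, the weighted mean of the result within every group of every fixed effect should be close to zero, which is a quick sanity check for any implementation of the algorithm.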

To conclude, here is a benchmark against `PyHDFE`:

``````
%load_ext autoreload

import numpy as np
import time

np.random.seed(1238)
N = 1_000_000
x = np.random.normal(0, 1, 4*N).reshape((N,4))
f1 = np.random.choice(list(range(1000)), N).reshape((N,1))
f2 = np.random.choice(list(range(1000)), N).reshape((N,1))

flist = np.concatenate((f1, f2), axis = 1)
weights = np.ones(N)

import pyhdfe
start_time = time.time()
algorithm = pyhdfe.create(flist)
res_pyhdfe = algorithm.residualize(x)
end_time = time.time()
print(end_time - start_time)
# 2.168398380279541

start_time = time.time()
algorithm = pyhdfe.create(flist)
res_pyfixest = demean(x, flist, weights, tol = 1e-08)
# Calculate the execution time
end_time = time.time()
print(end_time - start_time)
# 3.359426736831665

np.allclose(res_pyhdfe , res_pyfixest)
# True
``````
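One caveat when reading these timings: Numba compiles a function on its first call, so with `cache = False` the first timed call to `demean` may include compilation time. A hedged sketch of a timing helper that runs an untimed warm-up call first (`time_call` is a hypothetical helper, not part of any library):

```python
import time


def time_call(fn, *args, warmup=1, repeats=3, **kwargs):
    """Return the best wall-clock time of fn(*args, **kwargs) over `repeats` runs.

    `warmup` untimed calls run first, so one-off costs such as JIT
    compilation are excluded from the measurement.
    """
    for _ in range(warmup):
        fn(*args, **kwargs)
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best
```

Because `demean` writes into its input, each call would need a fresh copy, for example `time_call(lambda: demean(x.copy(), flist, weights, tol = 1e-08))`; the copy is then part of the measured time.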

Any tips on how I could further optimize my implementation?

Best, Alex

Hey @s3alfisc,
you are overwriting your input matrix `cx`: after `demean` has run, `cx` equals `res`. Is that a feature or a bug?
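For reference, the effect can be reproduced in plain NumPy: a basic slice such as `cx[:, k]` is a view, so an in-place `-=` on it writes through to the parent array, while an explicit `.copy()` does not. A minimal sketch with a made-up 3 x 2 array:

```python
import numpy as np

cx = np.arange(6, dtype=np.float64).reshape(3, 2)

col_view = cx[:, 0]         # basic slicing returns a view onto cx
col_view -= 1.0             # in-place update also changes cx[:, 0]
shares = np.shares_memory(col_view, cx)

col_copy = cx[:, 1].copy()  # an explicit copy owns its own data
col_copy -= 1.0             # cx[:, 1] is left unchanged
```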
I could not find much to improve.
If you remove some redundant lines of code, you cut execution time by about 20%.
I hope you will find more gains.

``````
import numpy as np
from numba import njit, prange, types
from numba.typed import Dict

@njit(parallel=True, cache=False, fastmath=False)
def demean(cx, flist, weights, tol=1e-10, maxiter=2000):
    N, K = cx.shape
    fixef_vars = flist.shape[1]
    for k in prange(K):
        cxk = cx[:, k]  # view into cx: the column is demeaned in place
        oldxk = cxk - 1
        for _ in range(maxiter):
            for i in range(fixef_vars):
                cxk -= _ave3(cxk, flist[:, i], weights)
            if np.sum(np.abs(cxk - oldxk)) < tol:
                break
            oldxk[:] = cxk  # reuse the buffer instead of allocating a copy

@njit(parallel=False, cache=False, fastmath=False)
def _ave3(x, f, w):
    N = len(x)
    w_dict = Dict.empty(key_type=types.int64, value_type=types.float64,)
    wx_dict = Dict.empty(key_type=types.int64, value_type=types.float64,)
    # accumulate per-group sums of w and w * x in a single branch
    for i in prange(N):
        j = f[i]
        if j in wx_dict:
            wx_dict[j] += w[i] * x[i]
            w_dict[j] += w[i]
        else:
            wx_dict[j] = w[i] * x[i]
            w_dict[j] = w[i]
    # write each row's weighted group mean directly into the result
    res = np.empty_like(x)
    for i in prange(N):
        j = f[i]
        res[i] = wx_dict[j] / w_dict[j]
    return res

# original: 6.06292986869812