Optimizing parallelized similarity function

Hi numba community,

I have a relatively simple piece of code that I am parallelizing. It works for smaller datasets as it is now, but I do not see the parallel scaling I would expect. I am new to numba, so my implementation is probably far from optimal. I hope someone here can point me in the right direction.

What I’m doing in brief:

I have datasets with 1k-12k rows (variables) and 400k-1.2m columns (samples) and I am calculating a similarity metric between all samples based on cosine similarity scaled by the number of non-zero overlaps. This generates a rather large square adjacency matrix which I preallocate and fill out using np.memmap() to avoid memory constraints.
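Concretely, for two sample columns q and r, with ol being the number of positions where both are non-zero, each entry of the matrix is roughly

sim(q, r) = (q · r) / ( ||q[r != 0]|| * ||r[q != 0]|| ) * (1 - 1/ol)^2

where ol and either norm are clamped to 1 whenever they come out as zero (see the code below).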

Code:

import numpy as np
import pandas as pd

from numba import njit, prange, set_num_threads
from numba_progress import ProgressBar

# -------------------------------------
# functions

@njit
def nz_overlap_count_func(array):
    """
    Returns the count of non-zero overlaps between two 1D arrays.
    """
    n = 0
    for i in array:
        if i != 0:
            n += 1
    return n

@njit(nogil=True, parallel=True)
def cosines_memmap(input_data, output_memmap, progress_proxy):
    """
    Calculates the cosine similarity between overlapping variables of 
    all columns in input array and updates an array stored on disk
    with the values.
    
    Input:
        input_data: A 2D np.array of np.float32
        output_memmap: A 2D np.array stored on disk using np.memmap()
        progress_proxy: A ProgressBar object to track progress 
    """
    
    # Loop over all samples, i
    for i in prange(input_data.shape[1]):
        
        query = input_data[:,i]
        
        # Loop over all samples j >= i
        for j in range(i, input_data.shape[1]):
            
            reference = input_data[:,j]
            
            # Get number of non-zero overlaps between query and reference
            ol = nz_overlap_count_func(query*reference)
            if ol == 0:
                ol = 1
            
            # Calculate cosine similarity from overlapping variables between samples i and j, 
            # scale it by the number of overlaps between samples i and j,
            # then update the (i,j)'th entry in the output array on disk
            q_denom = np.sqrt(np.sum(query[reference.nonzero()]**2))
            if q_denom == 0:
                q_denom = 1
            
            r_denom = np.sqrt(np.sum(reference[query.nonzero()]**2))
            if r_denom == 0:
                r_denom = 1
            
            output_memmap[i,j] = ((query / q_denom) @ (reference / r_denom)) * ((1-(1/ol))**2)
        
        progress_proxy.update(1)
        
# -------------------------------------
# Some test data
input_shape = (1000,5800)
input_data = np.random.default_rng(seed=12).normal(size=input_shape).astype(np.float32)
input_data[np.abs(input_data) < 2] = 0
output_arr = np.memmap('/path/to/output.file', dtype=np.float32, mode='w+', shape=(input_shape[1],input_shape[1]))

# -------------------------------------
# Run to compile
first_i = np.random.default_rng(seed=12).normal(size=(5,15)).astype(np.float32)
first_o = np.memmap('/path/to/output_f.file', dtype=np.float32, mode='w+', shape=(first_i.shape[1],first_i.shape[1]))
with ProgressBar(total=first_i.shape[1]) as progress:
    cosines_memmap(first_i, first_o, progress)

# Run actual
set_num_threads(8)
with ProgressBar(total=input_shape[1]) as progress:
    cosines_memmap(input_data, output_arr, progress)

As the adjacency matrix is mirrored across the diagonal, I only calculate the upper triangle (the j loop). Thus, when I run with parallel=False, I would expect iterations/sec to increase over the run, since fewer and fewer calculations are needed for each successive i. This is indeed what I observe, and I get a runtime of ~2 min 50 sec for the given example.

However, for parallel=True I no longer see this behaviour: in fact, iterations/sec decreases over the run. Perhaps this has to do with how numba splits the work between the threads, without taking into account that the workload shrinks along the i loop. Also, for the given example I see no further decrease in runtime past ~8 threads, at which point it finishes in ~45 seconds.

Eventually I get rid of the dense adjacency matrix by zeroing everything below a threshold and saving the result as a sparse matrix. Overall this feels wasteful; my initial approach was to append (i, j, value) to a list whenever the value was above some threshold, and from that construct a sparse COO matrix without ever holding the full dense adjacency matrix in memory. At that point I would no longer need the memory-mapped file on disk, which I assume would also speed the whole thing up. Unfortunately that approach cannot be parallelized, since I do not know the length of the list in advance, which is why I have ended up with the memory-mapped approach.
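In outline, that serial approach would have looked something like this (just a sketch; the threshold value is made up):

import numpy as np
from scipy import sparse

def cosines_to_coo(input_data, threshold=0.2):
    """Serial sketch: append (i, j, value) only when |value| passes a threshold,
    then build a sparse COO matrix without ever holding the dense result."""
    n = input_data.shape[1]
    rows, cols, vals = [], [], []
    for i in range(n):
        q = input_data[:, i]
        for j in range(i, n):
            r = input_data[:, j]
            ol = np.count_nonzero(q * r)
            if ol == 0:
                ol = 1
            q_denom = np.sqrt(np.sum(q[r != 0] ** 2))
            r_denom = np.sqrt(np.sum(r[q != 0] ** 2))
            if q_denom == 0:
                q_denom = 1.0
            if r_denom == 0:
                r_denom = 1.0
            value = ((q / q_denom) @ (r / r_denom)) * (1 - 1 / ol) ** 2
            if abs(value) >= threshold:
                rows.append(i)
                cols.append(j)
                vals.append(value)
    return sparse.coo_matrix((vals, (rows, cols)), shape=(n, n), dtype=np.float32)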

You can quite easily make this ~100x faster by changing the data layout from (variable, sample) to (sample, variable). This allows a memory access pattern over contiguous data slices, which limits cache misses (which are probably happening on every single read in your current example). When I transpose your example so that it generates a shape (5800, 1000) array instead of (1000, 5800), and then swap input_data[:,i] and input_data[:,j] for input_data[i] and input_data[j], it goes from taking 35 seconds to taking 429 milliseconds. You can eke out a little more speed by rewriting the numpy bits that build intermediate arrays as explicit loops. For instance, lines like q_denom = np.sqrt(np.sum(query[reference.nonzero()]**2)) can be replaced by the function below.

@njit
def sum_of_sqrs_other_nonzero(a1, a2):
    s = 0
    for i in range(len(a1)):
        if(a2[i] != 0):
            s += a1[i]**2
    return s

#usage
q_denom = np.sqrt(sum_of_sqrs_other_nonzero(query, reference))

Doing this takes it from 429 ms to 382 ms. You could probably do the same with nz_overlap_count_func(query*reference) and get a little more speed.

As for making the parallelization more efficient: there might be some way of reducing i and j to a single variable, let's call it k, which would be your parallelized iterator; you would then recover i and j from k. I don't know the equations for that off the top of my head for upper-triangular matrices... in fact, doing something like that might be kind of hard, so maybe it's not worth it. After all, it's already ~100x faster once you use a good memory access pattern, so maybe it doesn't matter.

Also if you absolutely must have your data layout be (variable, sample) you might check out using the fortran-contiguous layout for your data.
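Something like this, for example (untested sketch):

# keep the (variable, sample) orientation, but store it column-major so that
# input_data[:, i] slices are contiguous in memory
input_data = np.asfortranarray(input_data)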


Thank you very much for your reply. Your suggestions are great and exactly the kind of insight I was looking for. Much appreciated!

Writing out the subsetting operations as explicit loop functions did indeed give me a good boost, cutting the runtime from the initial 45 seconds to around 30 seconds.

As you suggested, I also swapped the shape of the example data, changed the indexing for reference and query, and updated the shape indices where applicable, and I see no change in runtime from that. The only way I can replicate such a speed increase is if I do not change the shape indices for the i and j loops. But in that case we have inverted the goal and are now looping over the 1000 variables rather than the 5800 samples, which finds the similarity between the variables, not the samples. Perhaps I have missed something?

The code after the changes:

import numpy as np
import pandas as pd

from numba import njit, prange, set_num_threads
from numba_progress import ProgressBar

# -------------------------------------
# functions

@njit
def nz_overlap_count_func(a, b):
    n = 0
    for i in range(len(a)):
        p = a[i] * b[i]
        if p != 0:
            n += 1
    return n

@njit
def sum_of_sqrs_other_nonzero(a1, a2):
    s = 0
    for i in range(len(a1)):
        if(a2[i] != 0):
            s += a1[i]**2
    return s

@njit(nogil=True, parallel=True)
def cosines_memmap(input_data, output_memmap, progress_proxy):
    """
    Calculates the cosine similarity between overlapping variables of 
    all columns in input array and updates an array stored on disk
    with the values.
    
    Input:
        input_data: A 2D np.array of np.float32
        output_memmap: A 2D np.array stored on disk using np.memmap()
        progress_proxy: A ProgressBar object to track progress 
    """
    
    for i in prange(input_data.shape[0]):
        
        query = input_data[i]
        
        for j in range(i, input_data.shape[0]):
            
            reference = input_data[j]
            
            ol = nz_overlap_count_func(query, reference)
            if ol == 0:
                ol = 1
            
            q_denom = np.sqrt(sum_of_sqrs_other_nonzero(query, reference))
            if q_denom == 0:
                q_denom = 1
            
            r_denom = np.sqrt(sum_of_sqrs_other_nonzero(reference, query))
            if r_denom == 0:
                r_denom = 1
            
            output_memmap[i,j] = ((query / q_denom) @ (reference / r_denom)) * ((1-(1/ol))**2)
        
        progress_proxy.update(1)
        
# -------------------------------------
# Some test data
input_shape = (5800,1000)
input_data = np.random.default_rng(seed=12).normal(size=input_shape).astype(np.float32)
input_data[np.abs(input_data) < 2] = 0
output_arr = np.memmap('/path/to/output.file', dtype=np.float32, mode='w+', shape=(input_shape[0],input_shape[0]))

# -------------------------------------
# Run to compile
first_i = np.random.default_rng(seed=12).normal(size=(15,5)).astype(np.float32)
first_o = np.memmap('/path/to/output_f.file', dtype=np.float32, mode='w+', shape=(first_i.shape[0],first_i.shape[0]))
with ProgressBar(total=first_i.shape[0]) as progress:
    cosines_memmap(first_i, first_o, progress)

# Run actual
set_num_threads(8)
with ProgressBar(total=input_shape[0]) as progress:
    cosines_memmap(input_data, output_arr, progress)

My apologies, it seems that in my reproduction I had forgotten to change input_data.shape[1] to input_data.shape[0], so the code wasn't running over the full breadth. I fooled myself. Fixing bad memory access patterns on these sorts of big numeric problems can regularly yield 10x speedups, but I should have been more skeptical of the 100x I was seeing. Regardless, what you have above is about 2.7x faster than the original on my machine (35 s → 13 s). It's possible that some OS- or hardware-specific quirks account for the difference (I'm running an Ubuntu laptop). I also took out your progress bar, if you want to try that. But glad to see it's at least 33% faster on your end.

The only other possible optimization I can think of is to calculate ol, q_denom and r_denom in one loop, which might lead to fewer passes through the data. Again, the benefit here would be avoiding potential cache misses from reading the source arrays more times than you need to.


As you suggested, I was able to change the code so that a single loop calculates all three of those values, which yielded another slight speed-up (down to ~22 seconds now). I also reimplemented the way I save my output, so that I only need to store the upper triangle as a 1D array.

However, at the end of the day I think I will have to reconsider my approach. The largest of my arrays would, until I filter it down to only the significant values, take up around 700 GB of space. While that is doable, it appears that as my output array grows the calculations slow to a crawl. Over the iterations I can see memory usage growing steadily, and I suspect that a large part of the array is being held in memory; as it gets larger and larger, a lot of time is probably spent shifting it around, which defeats the purpose in my case. Presumably the issue is that I cannot call memmap.flush() inside the numba function, so more and more is loaded into memory and not off-loaded to disk until the function finishes.
I reckon I will have to come up with a clever way to chunk the data where each sample is still being compared against all others.

Either way, thank you very much for your help @DannyWeitekamp , it has been highly educational!

Hi @kongcav,

Just two quick tips:

  1. Check the math: (query @ reference) / q_denom / r_denom * ((1-(1/ol))**2) instead of ((query / q_denom) @ (reference / r_denom)) * ((1-(1/ol))**2) saves you a ton of operations, since you divide a single scalar twice instead of dividing every element of both vectors.

  2. If this really is an issue, you can call memmap.flush from Numba using numba.objmode as explained here; a rough sketch follows below. This has a bit of overhead, but I guess you would not do it too often.
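A minimal sketch of what I mean (not your exact code; the helper name, flush interval and the tiny shape are made up, and I keep a module-level handle to the memmap so that the real np.memmap object, which has the .flush() method, is reachable from object mode):

import numpy as np
from numba import njit, objmode

# module-level handle so the actual np.memmap object (with .flush())
# can be reached from object mode
OUTPUT = np.memmap('/path/to/output.file', dtype=np.float32, mode='w+', shape=(100, 100))

def flush_output():
    OUTPUT.flush()

@njit(nogil=True)
def fill_rows(output_memmap, n):
    for i in range(n):
        output_memmap[i, i] = 1.0      # stand-in for the real per-row work
        if (i + 1) % 50 == 0:          # flush every 50 rows, not on every write
            with objmode():            # briefly drop back into object mode
                flush_output()         # push dirty pages out to disk

fill_rows(OUTPUT, 100)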

Hope I did not overlook something. In any case, check the first one carefully!

Edit: As a third tip, I would follow @DannyWeitekamp’s advice and fuse the loops for better load balancing. It is not exactly trivial, so here is the solution:

@njit(nogil=True, parallel=True)
def cosines_memmap(input_data, output_memmap, progress_proxy):
    n = input_data.shape[0]
    for k in prange(n*(n+1)//2):
        i = k // n
        j = k % n
        if i > j:
            i = n - i
            j = n - j - 1
        ...
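To convince yourself that the mapping hits every (i, j) with j >= i exactly once, a throwaway check:

n = 5
pairs = set()
for k in range(n * (n + 1) // 2):
    i, j = k // n, k % n
    if i > j:
        i, j = n - i, n - j - 1
    pairs.add((i, j))

# every upper-triangle pair appears exactly once
assert pairs == {(i, j) for i in range(n) for j in range(i, n)}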

Check especially this one very carefully!


Hi @sschaer,

Great catch on the math, that saved me another few seconds on the previous implementation. The fused loop is absolutely brilliant, but I have not yet implemented it as I have changed my script a fair bit in order to chunk my data. I looked into using numba.objmode as you suggested, but at the end of the day I was not quite sure how to implement it.
So, in order to deal with the memory problem, I found a way (using this excellent guide) to split my triangular output matrix into more or less evenly sized chunks, where I can control the size of each chunk. The difference is night and day.

The test script below runs in less than 10 seconds in total for me, using the same test input of dim (5800,1000). That is from start to finish, including data generation, disk I/O and compilation. For one of my real datasets of dim (580K,1000), which slowed to a crawl before these changes (after 12 hours of running it was still at 60% completion and taking up 500 GB of memory on a 1.5 TB shared HPC node), it now runs in around 2 hours on 64 threads, which is definitely fast enough for my needs :)

I will see if I can introduce the fused loop; I just need to come up with a formula for the number of values output in each chunk so I can set up k (the output per batch is rectangular, with only the “upper triangle” holding non-zero values).

Current code:

# ------------------------------------------------------------------------------------
# IMPORT LIBRARIES
import os
import math
import time
import datetime
import numpy as np
import pandas as pd

from tqdm import tqdm
from scipy import sparse
from numba import njit, prange, set_num_threads
from multiprocessing import cpu_count
# ------------------------------------------------------------------------------------
# DEFINE THE FUNCTIONS WE WILL USE
@njit
def n_overlaps_and_norms(a,b):
    n=0
    s1=0
    s2=0
    
    for i in range(len(a)):
        p = a[i] * b[i]
        
        if (p != 0):
            n += 1
        
        if (b[i] != 0):
            s1 += a[i]**2
        
        if (a[i] != 0):
            s2 += b[i]**2
            
    return n, np.sqrt(s1), np.sqrt(s2)

@njit(nogil=True, parallel=True, fastmath=True)
def cosine_function(query,reference):
    
    out = np.zeros((query.shape[0], reference.shape[0]), dtype=np.float32)
    
    for i in prange(query.shape[0]):
        
        q = query[i]
        for j in range(i, reference.shape[0]):
            
            r = reference[j]
            
            ol, q_denom, r_denom = n_overlaps_and_norms(q, r)
            if ol == 0:
                ol = 1
            if q_denom == 0:
                q_denom = 1
            if r_denom == 0:
                r_denom = 1
                
            out[i,j] = (q @ r) / q_denom / r_denom * ((1-(1/ol))**2)
    
    return out
# ------------------------------------------------------------------------------------
# MAKE SAMPLE DATA
input_shape = (5800,1000)
input_data = np.random.default_rng(seed=12).normal(size=input_shape).astype(np.float32)
input_data[np.abs(input_data) < 2] = 0
# ------------------------------------------------------------------------------------
# ALLOCATE NUMPY MEMMAP FILE FOR HOLDING DENSE OUTPUT
output_path = './output.file'
result = np.memmap(output_path, dtype=np.float32, mode='w+', shape=(input_shape[0],input_shape[0]))
del result
# ------------------------------------------------------------------------------------
# CHUNKING DATASET TO MANAGE MEMORY FOOTPRINT
n_chunks = 10
n = input_shape[0]

n_upper_tri = (n * (n+1))//2

per_chunk = int(math.ceil(n_upper_tri/(n_chunks)))

split_indices = [] # List containing the indices to split at
t = 0
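# row i of the upper triangle contributes n - i entries; start a new chunk
# once the running total reaches per_chunk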
for i in range(n + 1):
    t += n - i
    if t >= per_chunk:
        split_indices.append(i)
        t = 0

s_indices = [0] + split_indices + [n]
# ------------------------------------------------------------------------------------
# RUN COSINE CALCULATION
set_num_threads(8)
for i in tqdm(range(len(s_indices)-1)):

    result = np.memmap(output_path, 
                       dtype=np.float32, 
                       mode='r+', 
                       shape=(input_shape[0],input_shape[0]))
    start = s_indices[i]
    end = s_indices[i+1]
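    # rows [start, end) are compared against every row from `start` onward,
    # so only the upper-triangular part of the written block is non-zero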
    
    query = input_data[start:end]
    reference = input_data[start:]
    
    res = cosine_function(query,reference)
    
    result[start:end,start:] = res
    result.flush()
    del result
# ------------------------------------------------------------------------------------
# SPARSIFY OUTPUT
dense_result = np.memmap(output_path, 
                         dtype=np.float32, 
                         mode='r+', 
                         shape=(input_shape[0],input_shape[0]))

qt = 0.95
qt_result = np.quantile(np.abs(dense_result[dense_result != 0]), q=qt)

dense_result[np.abs(dense_result) < qt_result] = 0

sparse_result = sparse.csr_matrix(dense_result)
# ------------------------------------------------------------------------------------
# SAVE OUTPUT
save_path = os.path.split(output_path)
sparse_path = save_path[0] + '/' + save_path[1].split('.')[0] + '_SPARSE.npz'

sparse.save_npz(sparse_path, sparse_result)
# ------------------------------------------------------------------------------------
# CLEAN UP LARGE TEMP FILE
os.remove(output_path)

My deepest gratitude to both of you, generating these adjacency matrices has been a headache and a half for me!

Glad you made progress!

I have another tip that can save you time, depending on your data. Notice that the expression (q @ r) / q_denom / r_denom * ((1-(1/ol))**2) equals zero when ol equals one, which is what ol gets set to when it is calculated to be zero. So I suggest calculating ol first, and then calculating q_denom and r_denom only when needed. It works even better if you also calculate q @ r while calculating q_denom and r_denom. I tried this on your previous code (not the one you just posted) and it worked fine (see the code below). And the fused loop is definitely worth investigating: monitor the CPU usage for the double loop versus the fused loop. Your example is a great showcase of how effective this is.

@njit(nogil=True)
def sums_of_nonzero_sqrs_and_dot(a1, a2):
    s1, s2, s3 = 0, 0, 0
    for i in range(len(a1)):
        if (a2[i] != 0) and (a1[i] != 0):
            s1 += a1[i]**2
            s2 += a2[i]**2  
            s3 += a1[i]*a2[i]
    return s1, s2, s3

@njit(nogil=True, parallel=True, error_model="numpy")
def cosines_memmap_opt(input_data, output_memmap, progress_proxy):
    n = input_data.shape[0]
    for k in prange(n*(n+1)//2):
        i = k // n
        j = k % n
        if i > j:
            i = n - i
            j = n - j - 1

        query = input_data[i]
        reference = input_data[j]

        ol = nz_overlap_count_func(query, reference)
        
        if ol == 0:
            output_memmap[i,j] = 0
        else:
            q_denom, r_denom, query_dot_reference = sums_of_nonzero_sqrs_and_dot(query, reference)
            q_denom = 1 if q_denom == 0 else np.sqrt(q_denom)
            r_denom = 1 if r_denom == 0 else np.sqrt(r_denom)  
            output_memmap[i,j] = query_dot_reference / (q_denom * r_denom) * ((1-(1/ol))**2)
        
        progress_proxy.update(1)

I implemented the fused loop and it makes quite a difference: looking at CPU usage, the assigned CPUs are now basically pegged at 100% for the whole run, whereas before there was more variation, with temporary dips to 60-70% being common. This also shaves another few seconds off the synthetic example data, and roughly ½-1 hour off the real datasets.

I suppose what happens is that since the workload for each index i is very uneven (the first row of my output has n entries, the next row n-1, and so on), there are many instances of CPUs “finishing early”, leading to somewhat unbalanced task assignment, and the fused loop lets Numba distribute the workload across the CPUs more evenly?