Parallel shift/roll operations with JIT

Hi all,

First post here, so please bear with me on any style mistakes. I originally posted this topic to the Numba GitHub issues page, but I’m hoping people here might be able to help as well. You can see my original post here for a much more extensive example: jit/njit parallelization of np.roll() · Issue #7842 · numba/numba · GitHub

I’m using python/numba/jit to prototype a CFD code so I need to be able to do element-wise arithmetic on large arrays, including shifting to get neighboring values for estimating the spatial derivatives. When using np.roll() the parallel analysis shows that the shift operations are not recognized as parallel loops. E.g., an operation like sarr = np.roll(arr,-1) is not given a loop number. I redefined my own local version called roll() using prange and if I use that for shift operations it is recognized as parallel, but those can’t be fused with other operations so the code is not efficient.

Below is a short code for comparison. It contains two versions of the same function, with different implementations of roll().

import numpy as np
from numba import njit, prange 
import timeit

# version with local roll() definition
@njit(parallel=True)
def rs_roll_wrapper(arr):
    def roll(arr,shift):
        out = np.empty(arr.size, dtype=arr.dtype)
        for i in prange(arr.size):
            idx = (i - shift)%arr.size
            out[idx] = arr[i]
            return(out)

    temp1 = roll(arr,-1)
    temp1 *= arr
    temp2 = roll(arr,+1)
    temp2 = arr + temp1
    out = temp2 * temp2

    return(out)

# version with np.roll()
@njit(parallel=True)
def np_roll_wrapper(arr):

    roll = np.roll

    temp1 = roll(arr,-1)
    temp1 *= arr
    temp2 = roll(arr,+1)
    temp2 = arr + temp1
    out = temp2 * temp2

    return(out)

and here’s a script to test them

np_roll_wrapper(np.arange(int(1e2)))  # warm-up call to trigger compilation
stime = timeit.default_timer()
np_roll_wrapper(np.arange(int(1e6)))
print(timeit.default_timer() - stime)
rs_roll_wrapper(np.arange(int(1e2)))  # warm-up call to trigger compilation
stime = timeit.default_timer()
rs_roll_wrapper(np.arange(int(1e6)))
print(timeit.default_timer() - stime)

which outputs

0.013500430999556556
0.00415644699933182

Looking at the parallel diagnostics, the two optimized versions end up looking like this:

the numpy version

Parallel loop listing for  Function np_roll_wrapper, <ipython-input-84-57f1e14f0d7f> (18) 
-----------------------------|loop #ID
@njit(parallel=True)         | 
def np_roll_wrapper(arr):    | 
                             | 
    roll = np.roll           | 
                             | 
    temp1 = roll(arr,-1)     | 
    temp1 *= arr-------------| #1740
    temp2 = roll(arr,+1)     | 
    temp2 = arr + temp1------| #1738
    out = temp2 * temp2------| #1739
                             | 
    return(out)              | 
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...

Fused loop summary:
+--1738 has the following loops fused into it:
   +--1739 (fused)
Following the attempted fusion of parallel for-loops there are 2 parallel for-
loop(s) (originating from loops labelled: #1740, #1738).

my version

Parallel loop listing for  Function rs_roll_wrapper, <ipython-input-84-57f1e14f0d7f> (1) 
--------------------------------------------------------|loop #ID
@njit(parallel=True)                                    | 
def rs_roll_wrapper(arr):                               | 
    def roll(arr,shift):                                | 
        out = np.empty(arr.size, dtype=arr.dtype)       | 
        for i in prange(arr.size):    ------------------| #1744, 1743
            idx = (i - shift)%arr.size                  | 
            out[idx] = arr[i]                           | 
        return(out)                                     | 
                                                        | 
    temp1 = roll(arr,-1)                                | 
    temp1 *= arr----------------------------------------| #1745
    temp2 = roll(arr,+1)                                | 
    temp2 = arr + temp1---------------------------------| #1741
    out = temp2 * temp2---------------------------------| #1742
                                                        | 
    return(out)                                         | 
--------------------------------- Fusing loops ---------------------------------
Attempting fusion of parallel loops (combines loops with similar properties)...

Fused loop summary:
+--1741 has the following loops fused into it:
   +--1742 (fused)
Following the attempted fusion of parallel for-loops there are 4 parallel for-
loop(s) (originating from loops labelled: #1744, #1745, #1743, #1741).

So, the first key difference is that in my version each call to roll() implies a loop within the definition of the function, while the same is not true for np.roll(). Does that mean that np.roll() is not a parallel loop, or that it is not a loop at all? And if the numpy version has fewer loops (2 vs 4), why is it roughly three times slower?

The second issue is systemic to both versions – every call to roll(), whichever version is used, is unfusable with other loops, so even though the loops are all of the same size and use basic arithmetic at most, execution is slowed by them being broken up into separate loops. By comparison, if I explicitly redefine everything in terms of indices instead of using vector notation, I can do it all in a single loop that is approximately 4 times faster (being 1 loop instead of 4); a sketch of what I mean follows below.
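
This is a made-up centered-difference example rather than my actual code (the function name and stencil are just for illustration):

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def single_loop_update(u):
    # all shifts and arithmetic are handled per element in a single loop,
    # so there is nothing left for the fuser to break apart
    n = u.size
    out = np.empty_like(u)
    for i in prange(n):
        out[i] = u[(i + 1) % n] - u[(i - 1) % n]  # periodic centered difference
    return out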

So, that’s all to say: has anybody worked on or implemented a version of np.roll that is natively parallelizable and can be fused with other loops?

What version of Numba are you using for this? Your custom function already fails for me on the in-place integer multiplication when doing temp1 *= arr:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function imul>) found for signature:
 
 >>> imul(OptionalType(array(int32, 1d, C)), array(int32, 1d, C))

This is with Numba 0.55.1. It can be fixed by either setting parallel=False, or explicitly making it element-wise using temp1[:] *= arr[:].

You also return the result immediately after the first iteration; that probably shouldn’t happen. And to be compatible with np.roll you should probably add the shift rather than subtract it – the current implementation shifts in exactly the opposite direction. Alternatively, keep the index as it is and use out[i] = arr[idx].

And is the assumption that the input array is always 1D safe to make in this case? np.roll works for nD arrays as well – it first flattens them if no axis is specified.
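
For the 1D case, a np.roll-compatible replacement could look something like the sketch below (assuming a 1D input; roll1d is just a name I made up):

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def roll1d(arr, shift):
    # gather form of np.roll for 1D arrays: out[i] = arr[(i - shift) % n]
    n = arr.size
    out = np.empty_like(arr)
    for i in prange(n):
        out[i] = arr[(i - shift) % n]
    return out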

edit:
A more specialized version could be something like:

@njit(parallel=True)
def rs_roll_wrapper(arr):
    
    out = np.empty_like(arr)
    n = out.size
    
    for i in prange(n):
        tmp1 = arr[(i+1) % n]
        tmp1 *= arr[i]
        # tmp2 = arr[(i-1) % n] # ??
        tmp2 = arr[i] + tmp1
        out[i] = tmp2 * tmp2
        
    return out

That gives me about a 3x speedup over the NumPy version. I’m not sure it’s a fair comparison, though, because the first value you assign to temp2 is never used anywhere – Numba may or may not recognize that and skip it.

Parallel loop listing for  Function rs_roll_wrapper, ...\1394747790.py (68) 
----------------------------------------|loop #ID
@njit(parallel=True)                    | 
def rs_roll_wrapper(arr):               | 
                                        | 
    out = np.empty_like(arr)            | 
    n = out.size                        | 
                                        | 
    for i in prange(n):-----------------| #146
        tmp1 = arr[(i+1) % n]           | 
        tmp1 *= arr[i]                  | 
        # tmp2 = arr[(i-1) % n] # ??    | 
        tmp2 = arr[i] + tmp1            | 
        out[i] = tmp2 * tmp2            | 
                                        | 
    return out                          | 
------------------------------ After Optimisation ------------------------------
Parallel structure is already optimal.
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------

Posting to fix transcription typos in the definition of roll() – the return command was improperly tabbed, and the indexing direction now matches np.roll.

import numpy as np
from numba import njit, prange 
import timeit

# version with local roll() definition
@njit(parallel=True)
def rs_roll_wrapper(arr):
    def roll(arr,shift):
        out = np.empty(arr.size, dtype=arr.dtype)
        for i in prange(arr.size):
            idx = (i - shift)%arr.size
            out[i] = arr[idx]
        return(out)

    temp1 = roll(arr,-1)
    temp1 *= arr
    temp2 = roll(arr,+1)
    temp2 = arr + temp1
    out = temp2 * temp2

    return(out)

# version with np.roll()
@njit(parallel=True)
def np_roll_wrapper(arr):

    roll = np.roll

    temp1 = roll(arr,-1)
    temp1 *= arr
    temp2 = roll(arr,+1)
    temp2 = arr + temp1
    out = temp2 * temp2

    return(out)

You’re right, the return after every iteration was a transcription mistake – I’ve fixed that in the updated version above. That also fixes the TypingError when I just tested it. And good catch on the indexing mistake – it should indeed be out[i] = arr[idx], which I’ve folded into the repost as well. Still, I don’t see that having any effect on the parallelization of the code. I’m not too bothered about it being 1D-safe – the final code can have more bells and whistles to deal with higher dimensionality – at the moment I really just want to figure out how to make shift operations fusable with vector loops.

I’ve tried making the arithmetic element-wise with explicit indexing as you suggest, but this seems to break the loop fusion, i.e. temp1 *= arr is a parallel loop that can be fused, while temp1[:] *= arr[:] is a parallel loop that can’t be fused. I even tried defining an index array with

idx = (np.arange(arr.size) - shift) % arr.size
temp1 = arr[idx]

or, alternatively,

idx = (np.arange(arr.size) - shift) % arr.size
temp1[:] = arr[idx]

but again, the former is not recognized as a parallel loop, while the latter is parallel but unfusable.
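
For reference, the full wrapper I tested with the index-array approach looked roughly like this (a sketch from memory – the name fancy_roll_wrapper is made up, and I’ve dropped the unused temp2 assignment):

import numpy as np
from numba import njit

@njit(parallel=True)
def fancy_roll_wrapper(arr):
    n = arr.size
    idx = (np.arange(n) + 1) % n  # index array for shift = -1
    temp1 = arr[idx]              # gather: temp1[i] = arr[(i+1) % n]
    temp1 *= arr
    temp2 = arr + temp1
    out = temp2 * temp2
    return out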

Your specialized version that explicitly indexes all of the arithmetic and shift operations in a single loop is definitely the fastest way, and I’ve implemented a similar version myself, but it seems an inelegant solution to me – part of the beauty of using NumPy is that it makes the coding so much cleaner without explicit indexing, and I would much prefer not to have every array indexed as temp1[i,j,k] *= arr[i,j,k] when I move to 3D and have hundreds of lines to code that way…