Volume 3D rotation

Hi, I am using Numba to accelerate volume rotation sampling + MIP (maximum intensity projection). I am happy with the results: it is almost 100x faster than torch affine_grid + grid_sample (on CPU) or a custom affine grid + scipy map_coordinates.

I am posting the code in case someone spots an obvious optimization.

import numba
import math
import numpy as np

@numba.jit(nopython=True, fastmath=True, parallel=True)
def affine_sample_mip(volume, mats):
    r = mats[:3,:3]
    t = mats[:3,3]
    d,h,w = volume.shape
    d2,h2,w2 = int(d/2), int(h/2), int(w/2)
    out = np.zeros((d,h), dtype=volume.dtype)
            
    for i in numba.prange(d):
        for j in range(h):
            v = 0
            for k in range(w):
                z = i * r[0,0] + j * r[0,1] + k * r[0,2] + t[0]
                y = i * r[1,0] + j * r[1,1] + k * r[1,2] + t[1]
                x = i * r[2,0] + j * r[2,1] + k * r[2,2] + t[2]

                # 4. a) nearest interpolation
#                 if z < 0 or z >= d or y < 0 or y >= h or x < 0 or x >= w:
#                     continue  
#                 v = max(v, volume[int(z), int(y), int(x)])
                
                # 4. b) linear interpolation (slower but looks nicer)
                avg = 0
                ws = 1e-6
                for ii in range(2):
                    for jj in range(2):
                        for kk in range(2):
                            z2 = int(z + ii)
                            y2 = int(y + jj)
                            x2 = int(x + kk)
                            if z2 < 0 or z2 >= d or y2 < 0 or y2 >= h or x2 < 0 or x2 >= w:
                                continue  
                                 
                            w1 = max(0, 1-abs(z-z2))
                            w2 = max(0, 1-abs(y-y2))
                            w3 = max(0, 1-abs(x-x2))
                            w4 = w1*w2*w2
                            ws += w4
                            avg += w4 * volume[z2,y2,x2]          
                v = max(v, avg / ws)
            out[i,j] = v
    return out
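
For checking the kernel's output, a vectorized NumPy reference for the nearest-neighbour branch (4. a) can be handy. The sketch below is only an assumption of what such a reference could look like: the name mip_nearest_reference is hypothetical, it materializes every sample coordinate in memory, and it is meant for small test volumes rather than for speed.

```python
import numpy as np

def mip_nearest_reference(volume, mat):
    """Vectorized NumPy version of the nearest-neighbour branch (4. a)."""
    d, h, w = volume.shape
    r = mat[:3, :3].astype(np.float64)
    t = mat[:3, 3].astype(np.float64)
    # All output indices (i, j, k) as a (3, d*h*w) array.
    idx = np.indices((d, h, w)).reshape(3, -1).astype(np.float64)
    z, y, x = r @ idx + t[:, None]  # source coordinates for every (i, j, k)
    inside = (z >= 0) & (z < d) & (y >= 0) & (y < h) & (x >= 0) & (x < w)
    samples = np.zeros(d * h * w, dtype=volume.dtype)  # out-of-bounds samples stay 0
    samples[inside] = volume[z[inside].astype(np.int64),
                             y[inside].astype(np.int64),
                             x[inside].astype(np.int64)]
    # Maximum along k; the kernel starts each ray at v = 0, so clamp at 0 too.
    return np.maximum(samples.reshape(d, h, w).max(axis=2), 0)
```

With an identity matrix this reduces to volume.max(axis=2) for non-negative data, which makes a quick sanity check.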

You can test it with any 3D volume using the following affine matrix code:

import numpy as np
from scipy.spatial.transform import Rotation as R


def get_trans_matrix(trans):
    b = len(trans)
    trans_mat = np.zeros((b,4,4), dtype=np.float32)
    trans_mat[:,[0,1,2],[0,1,2]] = 1 # identity on the 3x3 block
    trans_mat[:,3,3] = 1   # homogeneous coordinate
    trans_mat[:,:3,3] = trans # translation
    return trans_mat


def get_rot_matrix(rots):
    b = len(rots)
    rot = R.from_rotvec(rots)
    rot_mat = np.zeros((b,4,4), dtype=np.float32)
    rot_mat[:,3,3] = 1
    rot_mat[:,:3,:3] = rot.as_matrix()
    return rot_mat


def get_scale_matrix(scales):
    b = len(scales)
    scale_mat = np.zeros((b,4,4), dtype=np.float32)
    scale_mat[:,3,3] = 1
    scale_mat[:,[0,1,2],[0,1,2]] = scales
    return scale_mat


def get_affine_matrix(d, h, w, trans, scales, rots, order='bw'):
    """
    PreMultiply Matrices 4,4
    - offset with [d/2,h/2,w/2]
    - subtract translation vector [tz,ty,tx]
    - divide scale vector [sz,sy,sx]
    - rotation matrix R
    """
    b = len(trans)
    offset = np.array([d/2,h/2,w/2], dtype=np.float32)[None].repeat(b,0)

    if order == 'bw':
        offset_mat = get_trans_matrix(offset-trans)
        scale_mat = get_scale_matrix(1./scales) 
        rot_mat = get_rot_matrix(rots)
        offset_mat2 = get_trans_matrix(-offset)
    else:
        offset_mat = get_trans_matrix(offset-trans)
        scale_mat = get_scale_matrix(1./scales) 
        rot_mat = get_rot_matrix(rots).transpose(0,2,1)
        offset_mat2 = get_trans_matrix(-offset)

    aff_mat = offset_mat@rot_mat@scale_mat@offset_mat2
    return aff_mat
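
The composition above (move the volume centre to the origin, scale, rotate, move back) can be illustrated for a single matrix without SciPy. This is a minimal sketch under stated assumptions: centered_affine is a hypothetical helper, it uses zero translation, an isotropic scale, and a rotation about the first (z) axis only, so it is not a replacement for get_affine_matrix.

```python
import numpy as np

def centered_affine(shape, angle, scale=1.0):
    """Rotate/scale about the volume centre, in (z, y, x) homogeneous coordinates."""
    d, h, w = shape
    center = np.array([d / 2, h / 2, w / 2])
    to_origin = np.eye(4)
    to_origin[:3, 3] = -center          # shift the centre to the origin
    back = np.eye(4)
    back[:3, 3] = center                # shift back afterwards
    inv_scale = np.diag([1 / scale, 1 / scale, 1 / scale, 1.0])
    c, s = np.cos(angle), np.sin(angle)
    rot = np.eye(4)
    rot[1:3, 1:3] = [[c, -s], [s, c]]   # rotation in the (y, x) plane
    # Same order as offset_mat @ rot_mat @ scale_mat @ offset_mat2
    return back @ rot @ inv_scale @ to_origin

mat = centered_affine((128, 128, 128), np.pi / 4)
center_h = np.array([64.0, 64.0, 64.0, 1.0])
print(np.allclose(mat @ center_h, center_h))  # the centre should be a fixed point
```

Because the offset is applied on both sides, the volume centre maps to itself, so the rotation happens about the middle of the volume rather than about the corner at index (0, 0, 0).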

You can test this with:

x = np.random.randn(128,128,128) # or a more interesting volume
scales = np.ones((1,3), dtype=np.float32) * 1.0
trans = np.zeros((1,3), dtype=np.float32) + 0
rots = np.zeros((1,3), dtype=np.float32)
rots[:,0] = 0
rots[:,1] = np.pi/4
rots[:,2] = np.pi/4

d,h,w = x.shape[-3:]
mats = get_affine_matrix(d,h,w, trans, scales, rots, order='bw')

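import timeit

vol = x.squeeze().astype(np.uint8)
affine_sample_mip(vol, mats[0])  # warm-up call so that JIT compilation is not timed
# %timeit affine_sample_mip(vol, mats[0])
print('numba mip: ', timeit.timeit(lambda: affine_sample_mip(vol, mats[0]), number=20) / 20, ' s')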

Overall it seems pretty reasonable to me. I did not notice anything obviously suboptimal.

Dear @etienne87

I also think this looks well written. However, think carefully about the following condition:

if z2 < 0 or z2 >= d or y2 < 0 or y2 >= h or x2 < 0 or x2 >= w:
    continue

Once z2 >= d, you can break out of all three innermost loops: z2 can only get bigger as ii increases, so every remaining iteration would hit the continue anyway. The same goes for y2 and the two innermost loops. If x2 >= w, you want to break out of the innermost loop rather than continue it.
I haven’t tested whether this really improves performance, but it might be worth a try. If you try it, test it extensively to make sure there is no logical mistake.
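
To make the idea concrete, here is a rough, untested-under-Numba sketch of a single-point sampler in both styles. The function names are hypothetical, the code is plain Python so the two variants can be compared directly, and it assumes the weight is the product of all three axis weights (w1*w2*w3).

```python
import numpy as np

def trilinear_continue(volume, z, y, x):
    """Reference: one combined bounds check with `continue`."""
    d, h, w = volume.shape
    avg, ws = 0.0, 1e-6
    for ii in range(2):
        for jj in range(2):
            for kk in range(2):
                z2, y2, x2 = int(z + ii), int(y + jj), int(x + kk)
                if z2 < 0 or z2 >= d or y2 < 0 or y2 >= h or x2 < 0 or x2 >= w:
                    continue
                wt = (max(0.0, 1 - abs(z - z2)) * max(0.0, 1 - abs(y - y2))
                      * max(0.0, 1 - abs(x - x2)))
                ws += wt
                avg += wt * volume[z2, y2, x2]
    return avg / ws

def trilinear_break(volume, z, y, x):
    """Same result, but each axis is checked once and leaves its loop early."""
    d, h, w = volume.shape
    avg, ws = 0.0, 1e-6
    for ii in range(2):
        z2 = int(z + ii)
        if z2 >= d:
            break      # z2 only grows with ii: nothing left to sample
        if z2 < 0:
            continue   # the next ii may still be inside
        wz = max(0.0, 1 - abs(z - z2))
        for jj in range(2):
            y2 = int(y + jj)
            if y2 >= h:
                break
            if y2 < 0:
                continue
            wy = max(0.0, 1 - abs(y - y2))
            for kk in range(2):
                x2 = int(x + kk)
                if x2 >= w:
                    break
                if x2 < 0:
                    continue
                wt = wz * wy * max(0.0, 1 - abs(x - x2))
                ws += wt
                avg += wt * volume[z2, y2, x2]
    return avg / ws
```

Hoisting z2 and y2 out of the inner loops also avoids recomputing them for every corner. Comparing both functions on random points, including points outside the volume, is one way to do the extensive testing mentioned above.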

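You also have some dead code: d2 and h2 are never used, and w3 is computed but never read. The latter suggests that w4 = w1*w2*w2 was meant to be w1*w2*w3; as written, the x-axis weight is ignored. Numba will handle the dead code, but why not clean it up yourself.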

Hope this helps!

Edit: It may even be worth putting a condition in front of the three innermost loops to check whether you need to enter them at all, e.g.:

if (int(z + 1) < 0) or (int(y + 1) < 0) or (int(x + 1) < 0) or (int(z) >= d) or (int(y) >= h) or (int(x) >= w):
    continue 
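
As a small sketch of how that guard could be validated before putting it into the kernel (both helper names below are hypothetical), the condition can be wrapped in a function and compared against a brute-force scan of the 8 corners:

```python
def footprint_outside(z, y, x, d, h, w):
    """True when no corner (int(c), int(c + 1) per axis) can be inside the volume."""
    return (int(z + 1) < 0 or int(y + 1) < 0 or int(x + 1) < 0
            or int(z) >= d or int(y) >= h or int(x) >= w)

def any_corner_inside(z, y, x, d, h, w):
    """Brute-force scan of the 8 corners, used only to validate the shortcut."""
    for ii in range(2):
        for jj in range(2):
            for kk in range(2):
                z2, y2, x2 = int(z + ii), int(y + jj), int(x + kk)
                if 0 <= z2 < d and 0 <= y2 < h and 0 <= x2 < w:
                    return True
    return False
```

For volumes with at least one voxel per axis, the guard should be true exactly when the brute-force scan finds no corner inside, so skipping the sample does not change the result.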

You can also test if manually unrolling (parts of) the three innermost loops helps.
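
One possible shape for a fully unrolled version is sketched below, as plain Python that has not been benchmarked or compiled with Numba. The names axis_taps and trilinear_unrolled are hypothetical, and the sketch assumes the weight is the product of all three axis weights. Out-of-bounds taps get weight 0 and a clamped index, which removes the bounds branch from the 8 reads.

```python
def axis_taps(c, n):
    """Two (index, weight) taps along one axis of length n.

    An out-of-bounds tap gets weight 0 and a clamped index, so it can be
    read safely without contributing to the result.
    """
    i0, i1 = int(c), int(c + 1)
    w0 = max(0.0, 1.0 - abs(c - i0)) if 0 <= i0 < n else 0.0
    w1 = max(0.0, 1.0 - abs(c - i1)) if 0 <= i1 < n else 0.0
    return min(max(i0, 0), n - 1), w0, min(max(i1, 0), n - 1), w1

def trilinear_unrolled(volume, z, y, x):
    d, h, w = volume.shape
    z0, wz0, z1, wz1 = axis_taps(z, d)
    y0, wy0, y1, wy1 = axis_taps(y, h)
    x0, wx0, x1, wx1 = axis_taps(x, w)
    # Each corner weight is a product of axis weights, so the sum factorizes.
    ws = 1e-6 + (wz0 + wz1) * (wy0 + wy1) * (wx0 + wx1)
    avg = (wz0 * (wy0 * (wx0 * volume[z0, y0, x0] + wx1 * volume[z0, y0, x1])
                  + wy1 * (wx0 * volume[z0, y1, x0] + wx1 * volume[z0, y1, x1]))
           + wz1 * (wy0 * (wx0 * volume[z1, y0, x0] + wx1 * volume[z1, y0, x1])
                    + wy1 * (wx0 * volume[z1, y1, x0] + wx1 * volume[z1, y1, x1])))
    return avg / ws
```

The summation order differs from the loop version, so results may differ in the last floating-point digits; whether this is actually faster would need to be measured.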

The early stopping gives around a 5% boost!
I will try the unrolling now.