I’ve been working on a 3D volume raytracer in a spherical coordinate system for a scientific application.
The raytracer is a little unusual in that I precompute the indices of the voxels along every ray instead of computing them on the fly. This lets me keep the implementation very simple, using only PyTorch array operations for everything, at the expense of memory consumption.
For example, this is how I raytrace a 3D volume `d` for a set of rays given by the voxel indices `r_ind`, `e_ind`, and `a_ind` and the lengths `lens` of the rays' intersections with those voxels:

```python
(d[r_ind, e_ind, a_ind] * lens).sum(axis=-1)
```
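Written out per ray, this expression is a gather followed by a length-weighted sum. The formula below assumes the index arrays have shape (rays_x, rays_y, N), as in the script further down. The notation is introduced here: $r, e, a$ stand for `r_ind`, `e_ind`, `a_ind`; $\ell$ stands for `lens`; $(x, y)$ indexes the ray; and $N$ is the number of intersection points per ray.

$$
\mathrm{result}[x, y] = \sum_{i=0}^{N-1} d\big[r_{x,y,i},\ e_{x,y,i},\ a_{x,y,i}\big]\,\ell_{x,y,i}
$$

This is a discretized line integral of $d$ along each ray. It is the same loop that the Numba kernel in the script unrolls.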
I am quite happy with how fast this is, but I would like to be able to handle more rays simultaneously without running out of memory (i.e. larger `r_ind`, `e_ind`, `a_ind`). An obvious choice would have been to change `r_ind`/`e_ind`/`a_ind` from int64 to int8, but unfortunately PyTorch only supports indexing by int64.
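One way to reduce index memory while still meeting PyTorch's int64 requirement is sketched below. I have not benchmarked it. The idea is to:

- pack the three per-axis indices into a single linear index into the flattened volume,
- store that index as int32,
- widen it to int64 for only one chunk of rays at a time.

With the sizes used in the script below (512 × 512 rays, 250 intersection points each), the three int64 index arrays take 65,536,000 × 24 B ≈ 1.57 GB. One int32 linear index takes 65,536,000 × 4 B ≈ 0.26 GB. Chunking also bounds the size of the temporary gathered and multiplied arrays.

The sketch uses NumPy so it runs without a GPU. `pack_indices`, `raytrace_packed` and `chunk` are hypothetical names. In PyTorch, the gather line would become `d.reshape(-1)[idx.long()]` or `torch.take(d, idx.long())`, with the packed index kept as an int32 tensor.

```python
import numpy as np


def pack_indices(r, e, a, shape, dtype=np.int32):
    """Hypothetical helper: pack per-axis voxel indices into one linear index.

    For a C-contiguous volume of the given shape, voxel (r, e, a) sits at
    linear position (r * shape[1] + e) * shape[2] + a, so
    d.ravel()[flat] == d[r, e, a].
    """
    n_voxels = shape[0] * shape[1] * shape[2]
    if n_voxels > int(np.iinfo(dtype).max) + 1:
        raise ValueError("volume too large for the chosen index dtype")
    # compute in the (int64) input dtype to avoid overflow, then narrow for storage
    flat = (r * shape[1] + e) * shape[2] + a
    return flat.astype(dtype)


def raytrace_packed(d, flat_ind, lens, chunk=64):
    """Hypothetical helper: same result as (d[r, e, a] * lens).sum(axis=-1).

    Processes `chunk` detector rows at a time, so only one chunk of indices
    is widened to int64 and the temporaries stay small.
    """
    flat_d = d.ravel()  # 1-D view of the volume (no copy for a contiguous array)
    out = np.empty(flat_ind.shape[:-1], dtype=d.dtype)
    for start in range(0, flat_ind.shape[0], chunk):
        stop = start + chunk
        # NumPy accepts int32 indices directly; the cast mirrors PyTorch's int64 requirement
        idx = flat_ind[start:stop].astype(np.int64)
        out[start:stop] = (flat_d[idx] * lens[start:stop]).sum(axis=-1)
    return out
```

With a 50³ volume, the largest linear index is 124,999, so int32 has plenty of headroom. Choosing the chunk size trades peak memory against the number of kernel launches.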
Here is an example script that computes the above function in PyTorch and compares it to an equivalent Numba implementation. Unfortunately, the Numba version is roughly 100x slower than the PyTorch one (0.085 s vs. 0.001 s in the output below). Any suggestions for speeding this up?
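One likely contributor to the gap, which I have not verified on this exact script, is the launch configuration. The kernel in the script below is launched as `raytrace[(num_pix, num_pix), (1, 1)]`, i.e. 512 × 512 blocks of a single thread each. NVIDIA GPUs execute threads in warps of 32, so a one-thread block leaves 31 of its warp's 32 lanes idle. The per-multiprocessor limit on resident blocks also caps how many rays run at once.

The usual pattern is one thread per ray with many threads per block. The helper below (`launch_config` is a hypothetical name) computes a grid that covers the detector. The 16 × 16 = 256 threads per block is a common starting point, not a tuned value.

```python
import math


def launch_config(n_x, n_y, threads_per_block=(16, 16)):
    """Return (blocks_per_grid, threads_per_block) covering an n_x-by-n_y grid of rays.

    Rounding up means edge blocks may contain threads past the detector; the
    kernel's `if x < r.shape[0] and y < r.shape[1]` check makes those idle.
    """
    tx, ty = threads_per_block
    blocks_per_grid = (math.ceil(n_x / tx), math.ceil(n_y / ty))
    return blocks_per_grid, threads_per_block
```

For `num_pix = 512`, this gives `((32, 32), (16, 16))`. After `blocks, threads = launch_config(num_pix, num_pix)`, the launch would read `raytrace[blocks, threads](d_c, r_c, e_c, a_c, lens_c, result_numba_c)`.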
```python
#!/usr/bin/env python3
import time
from contexttimer import Timer
from numba import cuda, void, int64, float32
import torch as t

t.manual_seed(0)
t.cuda.empty_cache()
spec = {'device': 'cuda'}

# volume being raytraced
shape = 50
d = t.rand((shape, shape, shape), **spec)
# width of detector
num_pix = 512
# maximum number of intersection points of each ray (this is specific to a spherical coord. system)
num_points = 2 * d.shape[0] + 2 * d.shape[1] + d.shape[2]
# voxel indices where rays intersect (placeholder)
ind = t.randint(shape, (num_pix, num_pix, num_points, 3), **spec)
# intersection lengths of rays with voxels (placeholder)
lens = t.rand(num_pix, num_pix, num_points, **spec)

# ----- PyTorch Raytracing -----

# index arrays with shape (num_pix, num_pix, num_points)
r, e, a = ind.moveaxis(-1, 0)

with Timer(prefix='PyTorch'):
    # raytracing inner-product
    # look up voxel indices for each ray and multiply by intersection length, then sum
    result_torch = (d[r, e, a] * lens).sum(axis=-1)

# ----- Numba Raytracing -----

@cuda.jit(void(float32[:, :, :], int64[:, :, :], int64[:, :, :], int64[:, :, :], float32[:, :, :], float32[:, :]))
def raytrace(d, r, e, a, lens, result):
    """Unrolled version of PyTorch inner-product"""
    x, y = cuda.grid(2)
    if x < r.shape[0] and y < r.shape[1]:
        inner_product = 0
        for i in range(r.shape[2]):
            r_ind = r[x, y, i]
            e_ind = e[x, y, i]
            a_ind = a[x, y, i]
            len_ = lens[x, y, i]
            inner_product += d[r_ind, e_ind, a_ind] * len_
        result[x, y] = inner_product

# copy arrays to GPU
d_c = cuda.to_device(d)
r_c = cuda.to_device(r)
e_c = cuda.to_device(e)
a_c = cuda.to_device(a)
lens_c = cuda.to_device(lens)
result_numba_c = cuda.to_device(t.empty((num_pix, num_pix)))

with Timer(prefix='Numba'):
    # use a single block for each ray
    raytrace[(num_pix, num_pix), (1, 1)](d_c, r_c, e_c, a_c, lens_c, result_numba_c)
    cuda.synchronize()
result_numba = result_numba_c.copy_to_host()

# ----- Compare Numerical Result -----
print('PyTorch result:', float(result_torch.sum()))
print('Numba result:', result_numba.sum())
```
Output:

```
PyTorch took 0.001 seconds
Numba took 0.085 seconds
PyTorch result: 16390173.0
Numba result: 16390176.0
```
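These timings need a caveat. CUDA work is queued asynchronously in both PyTorch and Numba, so a call can return before the GPU has finished. The Numba block calls `cuda.synchronize()` inside its timer, but the PyTorch block does not. Its 0.001 s therefore probably measures little more than the kernel launches.

A fairer comparison waits for the GPU in both cases and excludes one-time warm-up costs. Below is a minimal standard-library sketch. `time_gpu` is a hypothetical helper, and the `synchronize` argument is assumed to be `t.cuda.synchronize` or `cuda.synchronize`.

```python
import time


def time_gpu(fn, synchronize, repeats=10):
    """Return the mean wall-clock time of fn() in seconds.

    `synchronize` must block until all queued GPU work has finished
    (e.g. torch.cuda.synchronize or numba.cuda.synchronize).
    """
    fn()           # warm-up call: excludes one-time costs such as lazy initialization
    synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    synchronize()  # wait for asynchronously launched work before stopping the clock
    return (time.perf_counter() - start) / repeats
```

Usage:

- PyTorch: `time_gpu(lambda: (d[r, e, a] * lens).sum(axis=-1), t.cuda.synchronize)`
- Numba: `time_gpu(lambda: raytrace[(num_pix, num_pix), (1, 1)](d_c, r_c, e_c, a_c, lens_c, result_numba_c), cuda.synchronize)`

The small difference between the two printed sums (about 2e-7 relative) is consistent with a different floating-point summation order and precision. The kernel's accumulator starts as the integer `0`, so Numba likely types it as float64.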