I am benchmarking accelerators for Python. In the example below, TensorFlow runs in about 56% of the Numba runtime:
import numpy as np
import tensorflow as tf
from numba import njit, prange

@tf.function
def compute_tf(m, n):
    x1 = tf.range(0, m-1, 1) ** 2
    x2 = tf.range(0, n-1, 1) ** 2
    return x1[:, None] + x2[None, :]

compute_tf(tf.constant(1), tf.constant(1))

m = 50000
n = 10000
%timeit compute_tf(m, n)
557 ms ± 30.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
@njit(parallel=True)
def compute_numba(m, n):
    x = np.empty((m, n))
    for i in prange(m):
        for j in prange(n):
            x[i, j] = i**2 + j**2
    return x

compute_numba(1, 1)
%timeit compute_numba(m, n)
995 ms ± 38.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
This is a very simple computation, so I don’t really see why the TF version would run any faster. Do you have any idea how I can make the Numba version run on par with TF?
Double-check that the programs are equivalent: the Numba one performs on the order of n*m pow(x, 2) operations, but the TensorFlow one performs only n + m.
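To make the operation-count difference concrete, here is a minimal pure-Python sketch that mirrors the two formulations and counts how often a square is computed. The function names and the `squarings` counter are made up for illustration and are not part of Numba or TensorFlow. It uses `range(m)` and `range(n)` for both versions; the posted TF code uses `tf.range(0, m-1)`, which yields m-1 elements, so its output is one row and one column smaller than the loop version's.

```python
def loop_version(m, n):
    """Mirror of the nested-loop kernel: squares are computed per element."""
    squarings = 0
    out = []
    for i in range(m):
        row = []
        for j in range(n):
            row.append(i**2 + j**2)
            squarings += 2  # i**2 and j**2 are both recomputed here
        out.append(row)
    return out, squarings


def broadcast_version(m, n):
    """Mirror of the broadcast kernel: square each axis once, then add."""
    x1 = [i**2 for i in range(m)]
    x2 = [j**2 for j in range(n)]
    squarings = len(x1) + len(x2)  # one square per axis element
    out = [[a + b for b in x2] for a in x1]
    return out, squarings


m_small, n_small = 4, 3
loop_out, loop_sq = loop_version(m_small, n_small)
bcast_out, bcast_sq = broadcast_version(m_small, n_small)

print("same values:", loop_out == bcast_out)
print("squarings (loop, broadcast):", loop_sq, bcast_sq)
```

Both versions produce the same values; only the amount of work spent on squaring differs. A compiler may hoist or simplify some of that work, so treat the count as an upper bound on what the loop version does.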
Good point, thanks @DannyWeitekamp. However, I rewrote the Numba version to be exactly equivalent, and the timing barely changes:
@njit(parallel=True)
def compute_numba(m, n):
    x1 = np.arange(0, m-1)**2
    x2 = np.arange(0, n-1)**2
    return x1.reshape(-1, 1) + x2.reshape(1, -1)

compute_numba(1, 1)
%timeit compute_numba(m, n)
1.11 s ± 9.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Are the dtypes the same throughout? Is TensorFlow producing an int32 from range in the case of integer inputs: https://www.tensorflow.org/api_docs/python/tf/range ?
Also, if you strength reduce x**2 to x * x, does that help?
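A quick way to inspect this kind of thing is to print the dtype of each intermediate. The sketch below is plain NumPy, run outside Numba, so it only shows how to check; a jitted function may apply different promotion rules, and this sketch does not reproduce them. The `describe` helper is a hypothetical name used only for illustration.

```python
import numpy as np


def describe(label, arr):
    """Print the dtype and per-element size of an array."""
    print(f"{label}: dtype={arr.dtype}, itemsize={arr.itemsize} bytes")


# Start from an explicitly 32-bit range, as TF does for integer inputs.
x = np.arange(5, dtype=np.int32)

squared_pow = x ** 2   # power operator
squared_mul = x * x    # strength-reduced form

describe("x", x)
describe("x ** 2", squared_pow)
describe("x * x", squared_mul)

# Both forms must agree on the values, whatever the dtypes are.
print("same values:", np.array_equal(squared_pow, squared_mul))
```

To compare against Numba, the same dtype checks can be applied to the array returned by the jitted function.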
That was it: TF was producing int32. Also, applying **2 to the arange with type np.int32 promoted the result to np.int64 in Numba, whereas the strength-reduced x * x preserved the type, which put Numba on par with TF. I guess it makes sense that Numba took about twice the time, since it was allocating twice the memory. Many thanks @stuartarchibald!
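As a rough back-of-the-envelope check of the memory argument, the sketch below computes the size of an m-by-n output for both dtypes without allocating anything. The `output_bytes` helper is a made-up name, and the sizes assume a full (m, n) array; the broadcast versions above produce (m-1, n-1), which is nearly the same.

```python
import numpy as np


def output_bytes(m, n, dtype):
    """Bytes needed for a dense (m, n) array of the given dtype."""
    return m * n * np.dtype(dtype).itemsize


m, n = 50000, 10000
bytes32 = output_bytes(m, n, np.int32)
bytes64 = output_bytes(m, n, np.int64)

print(f"int32 output: {bytes32 / 1e9:.1f} GB")
print(f"int64 output: {bytes64 / 1e9:.1f} GB")
print(f"ratio: {bytes64 / bytes32:.1f}x")
```

For a kernel this simple, writing the output likely dominates the runtime, so doubling the bytes written is consistent with roughly doubling the time.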
@bdch1234 No problem, glad it’s resolved
Also, going back to the first Numba implementation, now with an int32 output array, actually makes it slightly faster than TF:
@njit(parallel=True)
def compute_numba_2(m, n):
    x = np.empty((m, n), np.int32)
    # i walks the rows (m) and j the columns (n), matching x's (m, n) shape
    for i in prange(m):
        for j in prange(n):
            x[i, j] = j**2 + i**2
    return x

compute_numba_2(1, 1)
%timeit compute_numba_2(m, n)
426 ms ± 14.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Strength reducing the inner loop improves this further, so that Numba runs in about 25% of the TF time!
@njit(parallel=True)
def compute_numba_3(m, n):
    x = np.empty((m, n), np.int32)
    # i walks the rows (m) and j the columns (n), matching x's (m, n) shape
    for i in prange(m):
        for j in prange(n):
            x[i, j] = j*j + i*i
    return x

compute_numba_3(1, 1)
%timeit compute_numba_3(m, n)
142 ms ± 4.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)