Hi,
Thank you for replying to my query.
Yes, I was aware of the compilation time, so I know it is not that.
def kernel(zr, zi, cr, ci, lim, cutoff):
    """Count iterations of z -> z*z + c until z escapes or the cutoff is hit.

    The complex number z = zr + i*zi is repeatedly squared and offset by
    c = cr + i*ci, using explicit real/imaginary arithmetic.

    Args:
        zr: Real part of the starting point z.
        zi: Imaginary part of the starting point z.
        cr: Real part of the constant c.
        ci: Imaginary part of the constant c.
        lim: Escape radius; iteration stops once |z|**2 >= lim**2.
        cutoff: Upper bound on the number of iterations (int or float).

    Returns:
        int: Number of iterations performed before either stop condition
        became true (0 if the starting point is already outside the radius
        or ``cutoff`` is not positive).
    """
    # The squared escape radius never changes, so compute it once.
    lim_sq = lim * lim
    count = 0
    # Compare squared magnitudes to avoid a square root per iteration.
    while (zr * zr + zi * zi) < lim_sq and count < cutoff:
        # Tuple assignment evaluates both new parts from the OLD zr/zi.
        zr, zi = zr * zr - zi * zi + cr, 2 * zr * zi + ci
        count += 1
    return count
# Compiled counterpart of `kernel`, built by applying `njit()` with no options.
# NOTE(review): `njit` is presumably numba.njit -- its import is not shown here; confirm.
kernel_njit = njit()(kernel)
def plot_mandel(mandel):
    """Show ``mandel`` as an image with the axes hidden.

    Args:
        mandel: 2-D array of values to render (e.g. the first element of the
            tuple returned by the ``compute_mandel_*`` functions).

    NOTE(review): ``plt`` is presumably ``matplotlib.pyplot``; its import is
    not visible in this chunk -- confirm.
    """
    plt.imshow(mandel)
    plt.axis('off')
    plt.show()
def compute_mandel_py(cr, ci, N, bound=1.0, lim=1000.0, cutoff=1e6):
    """Fill an N x N grid with ``kernel`` iteration counts using Python loops.

    Every grid point (x, y) is used as the starting value z = x + i*y, and
    ``kernel`` is called on it with the constant c = cr + i*ci.

    Args:
        cr: Real part of the constant c passed to ``kernel``.
        ci: Imaginary part of the constant c passed to ``kernel``.
        N: Number of sample points along each axis.
        bound: Both axes are sampled over [-bound, bound].
        lim: Escape radius forwarded to ``kernel``.
        cutoff: Iteration cap forwarded to ``kernel``.

    Returns:
        tuple: ``(mandel, elapsed)`` where ``mandel`` is an (N, N) integer
        array of iteration counts and ``elapsed`` is the wall-clock time in
        seconds spent in the loops (array allocation is not timed).
    """
    mandel = np.empty((N, N), dtype=int)
    # The same 1-D sample points are reused for both axes.
    grid_x = np.linspace(-bound, bound, N)
    t0 = time.time()
    # First array index follows x, second follows y.
    for i, x in enumerate(grid_x):
        for j, y in enumerate(grid_x):
            mandel[i, j] = kernel(x, y, cr, ci, lim, cutoff)
    return mandel, time.time() - t0
def compute_mandel_njit(cr, ci, N, bound=1.0, lim=1000.0, cutoff=1e6):
    """Fill an N x N grid with iteration counts using the compiled kernel.

    Same procedure as ``compute_mandel_py`` but each point is evaluated with
    ``kernel_njit``. This function is also the target that the
    ``compute_mandel_njit_jit*`` wrappers compile, so its body is kept
    exactly as benchmarked.

    Args:
        cr: Real part of the constant c passed to ``kernel_njit``.
        ci: Imaginary part of the constant c passed to ``kernel_njit``.
        N: Number of sample points along each axis.
        bound: Both axes are sampled over [-bound, bound].
        lim: Escape radius forwarded to ``kernel_njit``.
        cutoff: Iteration cap forwarded to ``kernel_njit``.

    Returns:
        tuple: ``(mandel, elapsed)`` where ``mandel`` is an (N, N) array of
        iteration counts and ``elapsed`` is the wall-clock time in seconds
        spent in the loops (array allocation is not timed).
    """
    # No dtype is given, so the array uses NumPy's default float dtype
    # (compute_mandel_py, by contrast, requests dtype=int).
    mandel = np.empty((N, N))
    # The same 1-D sample points are reused for both axes.
    grid_x = np.linspace(-bound, bound, N)
    t0 = time.time()
    # First array index follows x, second follows y.
    for i, x in enumerate(grid_x):
        for j, y in enumerate(grid_x):
            mandel[i, j] = kernel_njit(x, y, cr, ci, lim, cutoff)
    return mandel, time.time() - t0
# Two wrapped versions of compute_mandel_njit that differ only in the options
# given to `jit`: jit1 uses the defaults, jit2 passes forceobj=True and
# looplift=True.
# NOTE(review): `jit` is presumably numba.jit -- its import is not shown here; confirm.
compute_mandel_njit_jit1 = jit()(compute_mandel_njit)
compute_mandel_njit_jit2 = jit(forceobj=True, looplift=True)(compute_mandel_njit)
def python_run():
    """Run the pure-Python variant once on a 500 x 500 grid and print its time."""
    kwargs = dict(cr=0.285, ci=0.01,
                  N=500,
                  bound=1.0)
    print("Using pure Python")
    mandel_func = compute_mandel_py
    # The result array is kept only so the plot line below can be re-enabled.
    mandel_set, runtime = mandel_func(**kwargs)
    print("Mandelbrot set generated in {} seconds".format(runtime))
    #plot_mandel(mandel_set)
def njit_run():
    """Run the variant with only the kernel compiled and print its time.

    The outer grid loops (``compute_mandel_njit``) run as ordinary Python;
    only the per-point ``kernel_njit`` call is compiled.
    """
    kwargs = dict(cr=0.285, ci=0.01,
                  N=500,
                  bound=1.0)
    print("Using njitted kernel")
    mandel_func = compute_mandel_njit
    # The result array is kept only so the plot line below can be re-enabled.
    mandel_set, runtime = mandel_func(**kwargs)
    print("Mandelbrot set generated in {} seconds".format(runtime))
    #plot_mandel(mandel_set)
def njit_jit_run1():
    """Run the variant wrapped with default ``jit()`` and print its time.

    Uses ``compute_mandel_njit_jit1``, i.e. ``compute_mandel_njit`` wrapped
    with ``jit()`` and no options.
    """
    kwargs = dict(cr=0.285, ci=0.01,
                  N=500,
                  bound=1.0)
    print("Using njitted kernel and jitted compute function")
    mandel_func = compute_mandel_njit_jit1
    # The result array is kept only so the plot line below can be re-enabled.
    mandel_set, runtime = mandel_func(**kwargs)
    print("Mandelbrot set generated in {} seconds".format(runtime))
    #plot_mandel(mandel_set)
def njit_jit_run2():
    """Run the variant wrapped with forced object mode and print its time.

    Uses ``compute_mandel_njit_jit2``, i.e. ``compute_mandel_njit`` wrapped
    with ``jit(forceobj=True, looplift=True)``.
    """
    kwargs = dict(cr=0.285, ci=0.01,
                  N=500,
                  bound=1.0)
    print("Using njitted kernel and jitted compute function in object mode & looplift")
    mandel_func = compute_mandel_njit_jit2
    # The result array is kept only so the plot line below can be re-enabled.
    mandel_set, runtime = mandel_func(**kwargs)
    print("Mandelbrot set generated in {} seconds".format(runtime))
    #plot_mandel(mandel_set)
And then running
# Run the three numba-based variants in order (python_run is not invoked here).
njit_run()
njit_jit_run1()
njit_jit_run2()
After running these at least twice (to account for compilation), I get these times:
Using njitted kernel
Mandelbrot set generated in 0.15392279624938965 seconds
Using njitted kernel and jitted compute function
Mandelbrot set generated in 0.028262853622436523 seconds
Using njitted kernel and jitted compute function in object mode & looplift
Mandelbrot set generated in 0.15626192092895508 seconds
What I don’t understand is why
compute_mandel_njit_jit1 = jit()(compute_mandel_njit)
compute_mandel_njit_jit2 = jit(forceobj=True, looplift=True)(compute_mandel_njit)
Specifically, why is the first (plain `jit` with no options set) so much faster, when the warning says it is using object mode with loop lifting enabled? If that were true, both functions should give similar performance.
Is my question clearer now?
Thanks again for replying.
Fionnuala