Using Numba's JIT makes my function even slower.
import multiprocessing as mp
import os
import time

import numba
import numba.typed as nbt
import numpy as onp
def f_pyo(mats, exps, masks):
    """Raise each matrix to its exponent, or return the identity where masked off.

    For every group ``i``: if ``masks[i]`` is truthy the result is
    ``matrix_power(mats[i], sum(exps[i]))``; otherwise it is
    ``matrix_power(mats[i], 0)``, i.e. an identity matrix of the same size.

    The function runs unchanged as plain Python and under ``numba.njit``
    (serial or ``parallel=True``). All output arrays are preallocated in a
    typed list first. The ``numba.prange`` loop then only writes into those
    existing arrays and never appends to the list.

    Parameters
    ----------
    mats : sequence of square 2-D arrays (e.g. a ``numba.typed.List``).
    exps : sequence of integer arrays; the exponent used is the sum of each entry.
    masks : sequence of boolean-like values selecting which matrices are powered.

    Returns
    -------
    numba.typed.List
        One array per group, with the shape and dtype of ``mats[i]``.

    Raises
    ------
    ValueError
        If ``mats``, ``exps`` and ``masks`` do not all have the same length.
    """
    num_groups = len(mats)
    # Validate lengths up front. A mismatch would otherwise surface as an
    # IndexError partway through the loop.
    if len(exps) != num_groups or len(masks) != num_groups:
        raise ValueError("mats, exps and masks must have the same length")

    results = nbt.List()
    for i in range(num_groups):
        # Every element is overwritten below, so no zero-fill is needed.
        results.append(onp.empty_like(mats[i]))

    # NOTE(review): runtime is most likely dominated by the matrix products
    # inside matrix_power. NumPy and Numba both delegate these to BLAS, so
    # compiling this loop is not expected to help much, and parallel=True
    # may compete with BLAS's own threads. Profile to confirm.
    for i in numba.prange(num_groups):
        # Outside Numba, prange acts like range. Under parallel=True,
        # iterations may run concurrently, but each writes only results[i].
        if masks[i]:
            results[i][:] = onp.linalg.matrix_power(mats[i], onp.sum(exps[i]))
        else:
            results[i][:] = onp.linalg.matrix_power(mats[i], 0)
    return results
# Two compiled variants of the same function: a serial nopython build, and a
# parallel=True build in which the numba.prange loop may run on several threads.
f_jit1 = numba.njit(parallel=False)(f_pyo)
f_jit2 = numba.njit(parallel=True)(f_pyo)
# Build a reproducible workload of 30 (matrix, exponent, mask) triples.
# numba.typed.List is used so the same containers can be passed to both the
# interpreted and the njit-compiled versions of f_pyo.
nprng = onp.random.RandomState(47)  # fixed seed -> identical inputs every run
A = nbt.List()
exp = nbt.List()
mask = nbt.List()
for _ in range(30):
    # The draw order (size, exponent, mask, then matrix) is kept fixed so
    # the seeded stream always produces the same workload.
    size = nprng.randint(300, 500 + 1)  # randint's upper bound is exclusive
    p = nprng.randint(1, 16 + 1)
    m = nprng.randint(0, 1 + 1)
    # uniform() already returns float64, so an extra astype copy is not needed.
    A.append(nprng.uniform(0.0, 1.0, (size, size)))
    exp.append(onp.array(p, dtype=onp.int64))
    # onp.bool8 was a deprecated alias that NumPy 2.0 removed; onp.bool_ is
    # the canonical boolean scalar type on every NumPy version.
    mask.append(onp.array(m, dtype=onp.bool_))
def _best_time(fn, *args, repeat=5):
    """Return the fastest of ``repeat`` timed calls ``fn(*args)``, in seconds.

    ``fn`` is first called once without timing. For the Numba dispatchers,
    that first call compiles the function, and compilation should not be
    counted. The minimum over several runs is reported because a single
    wall-clock sample is easily skewed by other activity on the machine.
    """
    fn(*args)  # warm-up (triggers JIT compilation for f_jit1 / f_jit2)
    timings = []
    for _ in range(repeat):
        t0 = time.perf_counter()  # monotonic, high-resolution clock
        fn(*args)
        timings.append(time.perf_counter() - t0)
    return min(timings)


elapsed = _best_time(f_pyo, A, exp, mask)
print("PyObject:", elapsed)

elapsed = _best_time(f_jit1, A, exp, mask)
print("JIT :", elapsed)

# mp.cpu_count() reports all CPUs in the system, not the ones this process
# may actually use (e.g. the CPUs granted to a SLURM job when the scheduler
# pins tasks). The affinity mask reflects the real allocation where the OS
# exposes it.
try:
    available_cpus = len(os.sched_getaffinity(0))
except AttributeError:  # os.sched_getaffinity is not available on every platform
    available_cpus = mp.cpu_count()
if available_cpus <= 4:
    raise RuntimeError(
        f"expected more than 4 usable CPUs for the 4-thread run, got {available_cpus}"
    )
numba.set_num_threads(4)
elapsed = _best_time(f_jit2, A, exp, mask)
print("JIT-4 :", elapsed)
Here is the SLURM script I use to allocate resources:
#!/bin/bash
# SLURM batch script for the benchmark: requests one task on one node with
# 5 CPU cores and no GPUs, then runs benchmark.py.
#SBATCH --job-name=benchmark
#SBATCH --output=benchmark.stdout.txt
#SBATCH --error=benchmark.stderr.txt
#SBATCH --partition=...
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
# No GPUs are requested for this job.
#SBATCH --gres=gpu:0
# 5 cores; presumably chosen because the Python script above uses 4 Numba
# threads and checks for more than 4 CPUs (assuming benchmark.py is that script).
#SBATCH --cpus-per-task=5
#SBATCH --gpus-per-task=0
#
# NOTE(review): no BLAS/OpenMP thread counts are set here, so NumPy's BLAS
# and Numba's threading layer choose their own defaults. With parallel=True
# this may oversubscribe the 5 allocated cores. Consider exporting e.g.
# OMP_NUM_THREADS (TODO: confirm which BLAS backend NumPy uses).
python benchmark.py
I got the following timings (in seconds):
PyObject: 0.06118154525756836
JIT : 0.0636894702911377
JIT-4 : 0.09839320182800293
Clearly, the JIT version is slower, and the parallel JIT version is even slower.
I have run it several times, and the timings are similar each time.
However, I am not sure why this is happening.
(Possible culprits seem to be numba.typed.List,
the if ... else ... branching,
and numpy.linalg.matrix_power.)