I am trying to use Numba to speed up a tensor contraction that exploits symmetry, e.g. C[i,k,l] = sum_j A[i,j] B[j,k,l], where B[j,k,l] = B[j,l,k]. My idea is to loop over the last two indices of B, compute only the k <= l half of C, and let Numba accelerate the loop. Here is the code:
```python
import numpy as np
import time
from numba import njit

# Tensor shapes
i_dim = 20
j_dim = 20
k_dim = 20
l_dim = 20

# Create random tensors
np.random.seed(0)  # Fix the random seed for reproducibility
A = np.random.rand(i_dim, j_dim)
# Create a random tensor B_raw
B_raw = np.random.rand(j_dim, k_dim, l_dim)
# Symmetrize B_raw along the k and l axes
B = np.zeros_like(B_raw)
for k in range(k_dim):
    for l in range(l_dim):
        B[:, k, l] = 0.5 * (B_raw[:, k, l] + B_raw[:, l, k])

# Direct einsum method
start_time = time.time()
C_einsum = np.einsum("ij,jkl->ikl", A, B)
end_time = time.time()
einsum_time = end_time - start_time
print("Direct einsum method:")
print(f"Time taken: {einsum_time:.5f} seconds")

@njit
def symm_einsum(A, B, C):
    i_dim, j_dim = A.shape
    j_dim, k_dim, l_dim = B.shape
    for l in range(l_dim):
        # Compute only the k <= l part of C[:, :, l]
        C[:, :l+1, l] = np.einsum("ij,jk->ik", A, B[:, :l+1, l])
        # Fill C[:, l, k] for k < l using the symmetry C[i, l, k] = C[i, k, l]
        for k in range(l):
            C[:, l, k] = C[:, k, l]
    return C

# Custom loop with symmetry constraint and Numba
C_symm = np.zeros((i_dim, k_dim, l_dim))
start_time = time.time()
C_symm = symm_einsum(A, B, C_symm)  # the first call also includes JIT compilation time
end_time = time.time()
symm_time = end_time - start_time
print("\nCustom loop with symmetry constraint and Numba:")
print(f"Time taken: {symm_time:.5f} seconds")

# Check if the results are equal (within a tolerance)
print("\nAre the results equal?", np.allclose(C_einsum, C_symm, rtol=1e-05, atol=1e-08))

# Compare timings
speedup_factor = einsum_time / symm_time
print(f"\nSpeed-up factor: {speedup_factor:.2f}")
```
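The symmetric loop relies on C inheriting the symmetry of B, which follows directly from the definition (assuming k_dim == l_dim, as in the example):

$$
C_{ikl} = \sum_j A_{ij} B_{jkl} = \sum_j A_{ij} B_{jlk} = C_{ilk}
$$

So only the pairs with k <= l need to be computed: K(K+1)/2 out of K^2 pairs. For K = 20 that is 210 of 400, so the symmetric version can save at most about 400/210 ≈ 1.9× in arithmetic compared with the full einsum.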
I got the following error (the code runs fine if I comment out `@njit`):

```
Use of unsupported NumPy function 'numpy.einsum' or unsupported use of the function.

File "slicing.py", line 37:
def symm_einsum(A, B, C):
    for l in range(l_dim):
        C[:, :l+1, l] = np.einsum("ij,jk->ik", A, B[:, :l+1, l])
        ^
```
ChatGPT told me that Numba has limitations in its slicing/array support. Is there any solution within Numba? If not, would it be possible to add this feature?
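One possible workaround, sketched under the assumption that hand-written loops are acceptable: `np.einsum` is not in Numba's list of supported NumPy functions (hence the error), but plain nested loops compile well under `@njit`. The sketch below computes only the k <= l half and mirrors it, as the original loop intended. The name `symm_contract` is hypothetical, and the `try/except` fallback is only there so the sketch still runs (slowly, in plain Python) when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: fall back to plain Python if Numba is missing,
    # so the sketch can still be checked for correctness.
    def njit(func):
        return func


@njit
def symm_contract(A, B):
    """C[i,k,l] = sum_j A[i,j] * B[j,k,l], assuming B[j,k,l] == B[j,l,k].

    Uses explicit loops instead of np.einsum (unsupported inside @njit)
    and evaluates only k <= l, copying each value to C[i,l,k].
    Requires B.shape[1] == B.shape[2].
    """
    i_dim, j_dim = A.shape
    _, k_dim, l_dim = B.shape
    C = np.zeros((i_dim, k_dim, l_dim))
    for l in range(l_dim):
        for k in range(l + 1):  # only the k <= l half
            for i in range(i_dim):
                acc = 0.0
                for j in range(j_dim):
                    acc += A[i, j] * B[j, k, l]
                C[i, k, l] = acc
                C[i, l, k] = acc  # mirrored entry from the symmetry of B
    return C


# Usage: same shapes as in the question
rng = np.random.default_rng(0)
A = rng.random((20, 20))
B_raw = rng.random((20, 20, 20))
B = 0.5 * (B_raw + B_raw.transpose(0, 2, 1))  # symmetrize over (k, l)

symm_contract(A, B)      # first call triggers JIT compilation
C = symm_contract(A, B)  # time this call, not the first one
print(np.allclose(C, np.einsum("ij,jkl->ikl", A, B)))  # True
```

Another option inside `@njit` would be replacing the einsum call with `A @ np.ascontiguousarray(B[:, :l+1, l])`; Numba supports `@`/`np.dot` on 2-D float arrays, but its linear-algebra support needs SciPy installed. In either case, call the jitted function once before timing it, otherwise the measured time includes compilation.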