Hi, the numba function I trying to parallelize is below:
@njit(cache=True,parallel=True)
def freeoh_count_jit(coord=np.array([[]]),\
molInterfaceIndex=np.array([]),\
hNeighbourList=np.array([]),\
topol=np.array([[]]),\
cos_HAngle=0.0,\
cellsize=np.array([]),\
is_orig_def=False,is_new_def=True):
NAtomsMol=3 #No. of atoms in a molecule
_M=molInterfaceIndex.shape[0]
_N=hNeighbourList.shape[1]
mol1Coord=np.zeros((NAtomsMol,3),dtype=np.float64)
mol2Coord=np.zeros((_N*NAtomsMol,3),dtype=np.float64)
acceptorArray=np.zeros((_M,2),dtype=numba.int64)
donorArray=np.zeros((_M,2),dtype=numba.int64)
cosAngle=np.zeros(_M,dtype=np.float64)
labelArray=np.empty(_M, dtype="U10")
for i in numba.prange(_M): # loop over selected molecules
#for i in range(_M-1,-1,-1): # loop over selected molecules
mol1Coord[0]=coord[topol[molInterfaceIndex[i],0]] # extract center molecule
mol1Coord[1]=coord[topol[molInterfaceIndex[i],1]] # extract center molecule
mol1Coord[2]=coord[topol[molInterfaceIndex[i],2]] # extract center molecule
for indexJ,j in enumerate(hNeighbourList[i]):
mol2Coord[0+NAtomsMol*indexJ]=coord[topol[j,0]] # extract neighbors
mol2Coord[1+NAtomsMol*indexJ]=coord[topol[j,1]] # extract neighbors
mol2Coord[2+NAtomsMol*indexJ]=coord[topol[j,2]] # extract neighbors
neighAtoms = len(np.array([index for index in hNeighbourList[i] if index!=-1]))*NAtomsMol # get actual number of neighbor atoms
if is_orig_def:
acceptorArray[i],donorArray[i],cosAngle[i]=interface_hbonding_orig(mol1Coord,mol2Coord[:neighAtoms],cos_HAngle,cellsize)
labelArray[i]="D"*np.abs(2-np.sum(donorArray[i]))+"A"*np.clip(np.array([np.sum(acceptorArray[i] )]),1,2)[0]
elif is_new_def:
acceptorArray[i],donorArray[i],cosAngle[i]=interface_hbonding_new(mol1Coord,mol2Coord[:neighAtoms],cos_HAngle,cellsize)
labelArray[i]="D"*np.abs(2-np.sum(donorArray[i]))+"A"*np.sum(acceptorArray[i] )
del_index=np.argwhere(cosAngle > 1.0)
freeOHCos=np.delete(cosAngle, del_index.flatten())
return acceptor, donor, labelArray, freeOHCos
The function takes in a coordinates frame of the simulation molecules. The code selects a central molecule from an index list molInterfaceIndex
and extracts its neighbouring molecules coordinates from a pre-generated neighbour (generated from scipy.spatial.KDTree
hence cannot be called from a jitted function). The central molecule and its neighbour are send to a jitted function which then returns two [1,1] arrays (and a float value) which are used to label the central molecule.
The issue I have here is that when using numba.prange
in the outer loop I get incorrect array values for the initial two frames. By comparing with the non-jitted and serial functions it becomes clear that this has something to do this the loop parallelization. However,it is not clear to me immediately whether I have encountered a race condition or I am having some other weird data issues. If anyone has any suggestions for rewriting this function, it would be really helpful.
Edit
The MWE is:
import numpy as np
import math
import numba
from collections import Counter
from numba import njit
@njit(cache=True)
def pbc_r2(i,j,cellsize):
k=np.zeros((3),dtype=numba.float64)
inv_cellsize = 1.0 / cellsize
xdist = j[0]-i[0]
ydist = j[1]-i[1]
zdist = j[2]-i[2]
k[0] = xdist-cellsize[0]*np.rint(xdist*inv_cellsize[0])
k[1] = ydist-cellsize[1]*np.rint(ydist*inv_cellsize[1])
k[2] = zdist-cellsize[2]*np.rint(zdist*inv_cellsize[2])
return k, k[0]**2+k[1]**2+k[2]**2
@njit(cache=True,parallel=True)
def interface_hbonding_new(mol1Coord,mol2Coord,cos_HAngle,cellsize,is_angle=True):
O1coord=mol1Coord[0]
H11coord=mol1Coord[1]
H12coord=mol1Coord[2]
# Important lesson for the free-OH interfacial count:
# Acceptor needs to be a len-2 1-d array with integer counter (0: no bonding)
# Donor needs to a len-2 1-d array
# Donor array 0/1 counters needs to be opposite of Acceptor (for free-OH lifetime calculations)
donor=np.ones((2),dtype=numba.int64)
acceptor=np.zeros((2),dtype=numba.int64)
cosAngle=1.0
r2O2H21=0.0;r2O2H22=0.0;r2O1H11=0.0;r2O1H12=0.0
rIVec=np.zeros((5,3), dtype=numba.float64)
rJVec=np.zeros((5,3), dtype=numba.float64)
crossDistDict = numba.typed.Dict.empty(key_type=numba.types.unicode_type,value_type=numba.types.float64)
#O1-1H distances
_,r2O1H11=pbc_r2(O1coord,H11coord,cellsize)
_,r2O1H12=pbc_r2(O1coord,H12coord,cellsize)
for j in numba.prange(int(mol2Coord.shape[0]/3)):
O2coord=mol2Coord[3*j]
H21coord=mol2Coord[3*j+1]
H22coord=mol2Coord[3*j+2]
#O-O intermolecular
_,r2O1O2=pbc_r2(O1coord,O2coord,cellsize)
#O1-2H distances
_,r2O1H21=pbc_r2(O1coord,H21coord,cellsize)
_,r2O1H22=pbc_r2(O1coord,H22coord,cellsize)
#O2-1H distances
_,r2O2H11=pbc_r2(O2coord,H11coord,cellsize)
_,r2O2H12=pbc_r2(O2coord,H12coord,cellsize)
crossDistDict={'angleA1':r2O1H21,'angleA2':r2O1H22, 'angleD1':r2O2H11, 'angleD2':r2O2H12}
#for key,value in crossDistDict.items():
# if value < minvalue:
# minkey = key
# minvalue=value
minkey=sorted([(val, key) for key, val in crossDistDict.items()])[0][1]
#print(minkey)
if minkey == 'angleA1':
_,r2O2H21=pbc_r2(O2coord,H21coord,cellsize)
cosAngleA1=(-r2O1H21+r2O2H21+r2O1O2)/(2*math.sqrt(r2O2H21*r2O1O2))
if cosAngleA1 > cos_HAngle:
acceptor[0]+=1
if minkey == 'angleA2':
_,r2O2H22=pbc_r2(O2coord,H22coord,cellsize)
cosAngleA2=(-r2O1H22+r2O2H22+r2O1O2)/(2*math.sqrt(r2O2H22*r2O1O2))
if cosAngleA2 > cos_HAngle:
acceptor[1]+=1
if minkey == 'angleD1':
cosAngleD1=(-r2O2H11+r2O1H11+r2O1O2)/(2*math.sqrt(r2O1H11*r2O1O2))
if cosAngleD1 > cos_HAngle:
donor[0]=0
if minkey == 'angleD2':
cosAngleD2=(-r2O2H12+r2O1H12+r2O1O2)/(2*math.sqrt(r2O1H12*r2O1O2))
if cosAngleD2 > cos_HAngle:
donor[1]=0
if is_angle:
if donor[0] ==1 and donor [1] == 0:
rO1H11,_=pbc_r2(O1coord,H11coord,cellsize)
cosAngle=np.sign(O1coord[2])*(rO1H11[2]/math.sqrt(r2O1H11))
elif donor[0] ==0 and donor [1] == 1:
rO1H12,_=pbc_r2(O1coord,H12coord,cellsize)
cosAngle=np.sign(O1coord[2])*(rO1H12[2]/math.sqrt(r2O1H12))
else:
cosAngle = 100.0
return acceptor, donor, cosAngle
@njit(cache=True,parallel=True)
def freeoh_count_jit(coord=np.array([[]]),\
molInterfaceIndex=np.array([]),\
hNeighbourList=np.array([]),\
topol=np.array([[]]),\
cos_HAngle=0.0,\
cellsize=np.array([]),\
is_orig_def=False,is_new_def=True):
NAtomsMol=3 #No. of atoms in a molecule
_M=molInterfaceIndex.shape[0]
_N=hNeighbourList.shape[1]
acceptorArray=np.zeros((_M,2),dtype=numba.int64)
donorArray=np.zeros((_M,2),dtype=numba.int64)
cosAngle=np.zeros(_M,dtype=np.float64)
neighAtoms=np.zeros(_M,dtype=numba.int64)
#freeOHMask = np.zeros(_M, dtype=numba.int64) == 0
labelArray=np.empty(_M, dtype="U10")
for i in numba.prange(_M): # loop over selected molecules
mol1Coord=np.zeros((NAtomsMol,3),dtype=np.float64)
mol2Coord=np.zeros((_N*NAtomsMol,3),dtype=np.float64)
mol1Coord[0]=coord[topol[molInterfaceIndex[i],0]] # extract center molecule
mol1Coord[1]=coord[topol[molInterfaceIndex[i],1]] # extract center molecule
mol1Coord[2]=coord[topol[molInterfaceIndex[i],2]] # extract center molecule
for indexJ,j in enumerate(hNeighbourList[i]):
mol2Coord[0+NAtomsMol*indexJ]=coord[topol[j,0]] # extract neighbors
mol2Coord[1+NAtomsMol*indexJ]=coord[topol[j,1]] # extract neighbors
mol2Coord[2+NAtomsMol*indexJ]=coord[topol[j,2]] # extract neighbors
neighAtoms[i] = len(np.array([index for index in hNeighbourList[i] if index!=-1]))*NAtomsMol # get actual number of neighbor atoms
if is_orig_def:
pass
#acceptorArray[i],donorArray[i],cosAngle[i]=interface_hbonding_orig(mol1Coord,mol2Coord[:neighAtoms[i]],cos_HAngle,cellsize)
#labelArray[i]="D"*np.abs(2-np.sum(donorArray[i]))+"A"*np.clip(np.array([np.sum(acceptorArray[i])]),1,2)[0]
elif is_new_def:
acceptorArray[i],donorArray[i],cosAngle[i]=interface_hbonding_new(mol1Coord,mol2Coord[:neighAtoms[i]],cos_HAngle,cellsize)
labelArray[i]="D"*np.abs(2-np.sum(donorArray[i]))+"A"*np.sum(acceptorArray[i])
#freeOHMask[np.where(cosAngle > 1.0)] = False
del_index=np.argwhere(cosAngle > 1.0)
freeOHCos=np.delete(cosAngle, del_index.flatten())
return acceptorArray, donorArray, labelArray, freeOHCos
molInterfaceIndex=np.array([ 2, 7, 8, 12, 15, 17, 25, 26, 27, 31, 32, 38, 47, 48, 51, 56, 57, 62,
65, 66, 68, 69, 75, 80, 81, 98, 101, 106, 107, 110, 111, 113, 122, 124, 126, 129,
140, 141, 142, 143, 144, 152, 153, 154, 165, 166, 173, 175, 177, 184, 187, 188, 190, 195,
200, 201, 205, 208, 212, 214, 217, 220, 224, 225, 229, 230, 243, 247, 248, 254, 261, 262,
268, 270, 274, 278, 282, 283, 284, 285, 288, 289, 290, 292, 299, 300, 302, 303, 308, 311,
312, 324, 325, 332, 333, 337, 338, 345, 350, 361, 364, 370, 373, 377, 381, 383, 384, 389,
400, 404, 407, 411, 416, 422, 423, 424, 430, 432, 433, 435, 447, 453, 456, 458, 461, 463,
472, 478, 481, 482, 489, 492, 496, 497, 505, 513, 518, 522, 523, 526, 527, 534, 535, 542,
543, 544, 546, 547, 562, 563, 565, 568, 570, 572, 576, 577, 579, 587, 588, 593, 594, 595,
600, 609, 617, 622, 623, 626, 633, 636, 637, 639])
coord=np.loadtxt('coord-test.out', dtype=np.float64, delimiter=",")
hNeighbourList=np.loadtxt('hNeighbourList-test.out', delimiter=",").astype(np.int64)
cos_HAngle=np.cos(50.0*np.pi/180)
cellsize=np.array([26.40,26.40,70])
topol=np.zeros((640,3),dtype=int)
for i in range(640):
for j in range(3):
topol[i,j]=i*3+j
h_types=dict()
interfacialLabels = ['DA', 'DDA', 'DAA']
_,_,labelArray,freeOHCos=freeoh_count_jit(coord,molInterfaceIndex,hNeighbourList,topol,cos_HAngle,cellsize)
hbondDict=Counter(labelArray)
interfacialOH = {k: hbondDict[k] for k in interfacialLabels}
for key, value in interfacialOH.items():
if key in h_types:
h_types[key] += value
else:
h_types[key] = value
print(h_types)
The files coord-test.out
and hNeighbourList-test.out
are attached here as well.