cuda bugfix again

This commit is contained in:
alphataubio
2024-10-22 14:15:02 -04:00
parent 5bdd616bcd
commit 22a15c7cf8
2 changed files with 114 additions and 146 deletions

View File

@ -234,6 +234,7 @@ void FixCMAPKokkos<DeviceType>::post_force(int vflag)
ev_init(eflag,vflag);
copymode = 1;
nlocal = atomKK->nlocal;
Kokkos::parallel_reduce(Kokkos::RangePolicy<DeviceType, TagFixCmapPostForce>(0,ncrosstermlist),*this,ecmap);
copymode = 0;
atomKK->modified(execution_space,F_MASK);
@ -245,27 +246,6 @@ template<class DeviceType>
KOKKOS_INLINE_FUNCTION
void FixCMAPKokkos<DeviceType>::operator()(TagFixCmapPostForce, const int n, double &ecmapKK) const
{
int i1,i2,i3,i4,i5,type;
int li1, li2, mli1,mli2,mli11,mli21,t1,li3,li4,mli3,mli4,mli31,mli41;
// vectors needed to calculate the cross-term dihedral angles
double vb21x,vb21y,vb21z,vb32x,vb32y,vb32z,vb34x,vb34y,vb34z;
double vb23x,vb23y,vb23z;
double vb43x,vb43y,vb43z,vb45x,vb45y,vb45z,a1x,a1y,a1z,b1x,b1y,b1z;
double a2x,a2y,a2z,b2x,b2y,b2z,r32,a1sq,b1sq,a2sq,b2sq,dpr21r32,dpr34r32;
double dpr32r43,dpr45r43,r43,vb12x,vb12y,vb12z;
// cross-term dihedral angles
double phi,psi,phi1,psi1;
double f1[3],f2[3],f3[3],f4[3],f5[3];
double gs[4],d1gs[4],d2gs[4],d12gs[4];
// vectors needed for the gradient/force calculation
double dphidr1x,dphidr1y,dphidr1z,dphidr2x,dphidr2y,dphidr2z;
double dphidr3x,dphidr3y,dphidr3z,dphidr4x,dphidr4y,dphidr4z;
double dpsidr1x,dpsidr1y,dpsidr1z,dpsidr2x,dpsidr2y,dpsidr2z;
double dpsidr3x,dpsidr3y,dpsidr3z,dpsidr4x,dpsidr4y,dpsidr4z;
// Definition of cross-term dihedrals
// phi dihedral
@ -275,14 +255,12 @@ void FixCMAPKokkos<DeviceType>::operator()(TagFixCmapPostForce, const int n, dou
// |--------------------|
// psi dihedral
int nlocal = atomKK->nlocal;
i1 = d_crosstermlist(n,0);
i2 = d_crosstermlist(n,1);
i3 = d_crosstermlist(n,2);
i4 = d_crosstermlist(n,3);
i5 = d_crosstermlist(n,4);
type = d_crosstermlist(n,5);
int i1 = d_crosstermlist(n,0);
int i2 = d_crosstermlist(n,1);
int i3 = d_crosstermlist(n,2);
int i4 = d_crosstermlist(n,3);
int i5 = d_crosstermlist(n,4);
int type = d_crosstermlist(n,5);
if (type == 0) return;
// calculate bond vectors for both dihedrals
@ -290,99 +268,103 @@ void FixCMAPKokkos<DeviceType>::operator()(TagFixCmapPostForce, const int n, dou
// phi
// vb21 = r2 - r1
vb21x = d_x(i2,0) - d_x(i1,0);
vb21y = d_x(i2,1) - d_x(i1,1);
vb21z = d_x(i2,2) - d_x(i1,2);
vb12x = -1.0*vb21x;
vb12y = -1.0*vb21y;
vb12z = -1.0*vb21z;
vb32x = d_x(i3,0) - d_x(i2,0);
vb32y = d_x(i3,1) - d_x(i2,1);
vb32z = d_x(i3,2) - d_x(i2,2);
vb23x = -1.0*vb32x;
vb23y = -1.0*vb32y;
vb23z = -1.0*vb32z;
double vb21x = d_x(i2,0) - d_x(i1,0);
double vb21y = d_x(i2,1) - d_x(i1,1);
double vb21z = d_x(i2,2) - d_x(i1,2);
double vb12x = -1.0*vb21x;
double vb12y = -1.0*vb21y;
double vb12z = -1.0*vb21z;
double vb32x = d_x(i3,0) - d_x(i2,0);
double vb32y = d_x(i3,1) - d_x(i2,1);
double vb32z = d_x(i3,2) - d_x(i2,2);
double vb23x = -1.0*vb32x;
double vb23y = -1.0*vb32y;
double vb23z = -1.0*vb32z;
vb34x = d_x(i3,0) - d_x(i4,0);
vb34y = d_x(i3,1) - d_x(i4,1);
vb34z = d_x(i3,2) - d_x(i4,2);
double vb34x = d_x(i3,0) - d_x(i4,0);
double vb34y = d_x(i3,1) - d_x(i4,1);
double vb34z = d_x(i3,2) - d_x(i4,2);
// psi
// bond vectors same as for phi: vb32
vb43x = -1.0*vb34x;
vb43y = -1.0*vb34y;
vb43z = -1.0*vb34z;
double vb43x = -1.0*vb34x;
double vb43y = -1.0*vb34y;
double vb43z = -1.0*vb34z;
vb45x = d_x(i4,0) - d_x(i5,0);
vb45y = d_x(i4,1) - d_x(i5,1);
vb45z = d_x(i4,2) - d_x(i5,2);
double vb45x = d_x(i4,0) - d_x(i5,0);
double vb45y = d_x(i4,1) - d_x(i5,1);
double vb45z = d_x(i4,2) - d_x(i5,2);
// calculate normal vectors for planes that define the dihedral angles
a1x = vb12y*vb23z - vb12z*vb23y;
a1y = vb12z*vb23x - vb12x*vb23z;
a1z = vb12x*vb23y - vb12y*vb23x;
double a1x = vb12y*vb23z - vb12z*vb23y;
double a1y = vb12z*vb23x - vb12x*vb23z;
double a1z = vb12x*vb23y - vb12y*vb23x;
b1x = vb43y*vb23z - vb43z*vb23y;
b1y = vb43z*vb23x - vb43x*vb23z;
b1z = vb43x*vb23y - vb43y*vb23x;
double b1x = vb43y*vb23z - vb43z*vb23y;
double b1y = vb43z*vb23x - vb43x*vb23z;
double b1z = vb43x*vb23y - vb43y*vb23x;
a2x = vb23y*vb34z - vb23z*vb34y;
a2y = vb23z*vb34x - vb23x*vb34z;
a2z = vb23x*vb34y - vb23y*vb34x;
double a2x = vb23y*vb34z - vb23z*vb34y;
double a2y = vb23z*vb34x - vb23x*vb34z;
double a2z = vb23x*vb34y - vb23y*vb34x;
b2x = vb45y*vb43z - vb45z*vb43y;
b2y = vb45z*vb43x - vb45x*vb43z;
b2z = vb45x*vb43y - vb45y*vb43x;
double b2x = vb45y*vb43z - vb45z*vb43y;
double b2y = vb45z*vb43x - vb45x*vb43z;
double b2z = vb45x*vb43y - vb45y*vb43x;
// calculate terms used later in calculations
r32 = sqrt(vb32x*vb32x + vb32y*vb32y + vb32z*vb32z);
a1sq = a1x*a1x + a1y*a1y + a1z*a1z;
b1sq = b1x*b1x + b1y*b1y + b1z*b1z;
double r32 = sqrt(vb32x*vb32x + vb32y*vb32y + vb32z*vb32z);
double a1sq = a1x*a1x + a1y*a1y + a1z*a1z;
double b1sq = b1x*b1x + b1y*b1y + b1z*b1z;
r43 = sqrt(vb43x*vb43x + vb43y*vb43y + vb43z*vb43z);
a2sq = a2x*a2x + a2y*a2y + a2z*a2z;
b2sq = b2x*b2x + b2y*b2y + b2z*b2z;
double r43 = sqrt(vb43x*vb43x + vb43y*vb43y + vb43z*vb43z);
double a2sq = a2x*a2x + a2y*a2y + a2z*a2z;
double b2sq = b2x*b2x + b2y*b2y + b2z*b2z;
if (a1sq<0.0001 || b1sq<0.0001 || a2sq<0.0001 || b2sq<0.0001) return;
dpr21r32 = vb21x*vb32x + vb21y*vb32y + vb21z*vb32z;
dpr34r32 = vb34x*vb32x + vb34y*vb32y + vb34z*vb32z;
dpr32r43 = vb32x*vb43x + vb32y*vb43y + vb32z*vb43z;
dpr45r43 = vb45x*vb43x + vb45y*vb43y + vb45z*vb43z;
// vectors needed to calculate the cross-term dihedral angles
double dpr21r32 = vb21x*vb32x + vb21y*vb32y + vb21z*vb32z;
double dpr34r32 = vb34x*vb32x + vb34y*vb32y + vb34z*vb32z;
double dpr32r43 = vb32x*vb43x + vb32y*vb43y + vb32z*vb43z;
double dpr45r43 = vb45x*vb43x + vb45y*vb43y + vb45z*vb43z;
// cross-term dihedral angles
// calculate the backbone dihedral angles as VMD and GROMACS
phi = dihedral_angle_atan2(vb21x,vb21y,vb21z,a1x,a1y,a1z,b1x,b1y,b1z,r32);
psi = dihedral_angle_atan2(vb32x,vb32y,vb32z,a2x,a2y,a2z,b2x,b2y,b2z,r43);
double phi = dihedral_angle_atan2(vb21x,vb21y,vb21z,a1x,a1y,a1z,b1x,b1y,b1z,r32);
double psi = dihedral_angle_atan2(vb32x,vb32y,vb32z,a2x,a2y,a2z,b2x,b2y,b2z,r43);
if (phi == 180.0) phi= -180.0;
if (psi == 180.0) psi= -180.0;
phi1 = phi;
double phi1 = phi;
if (phi1 < 0.0) phi1 += 360.0;
psi1 = psi;
double psi1 = psi;
if (psi1 < 0.0) psi1 += 360.0;
// find the neighbor grid point index
li1 = int(((phi1+CMAPXMIN2)/CMAPDX)+((CMAPDIM*1.0)/2.0));
li2 = int(((psi1+CMAPXMIN2)/CMAPDX)+((CMAPDIM*1.0)/2.0));
li3 = int((phi-CMAPXMIN2)/CMAPDX);
li4 = int((psi-CMAPXMIN2)/CMAPDX);
mli3 = li3 % CMAPDIM;
mli4 = li4 % CMAPDIM;
mli31 = (li3+1) % CMAPDIM;
mli41 = (li4+1) %CMAPDIM;
mli1 = li1 % CMAPDIM;
mli2 = li2 % CMAPDIM;
mli11 = (li1+1) % CMAPDIM;
mli21 = (li2+1) %CMAPDIM;
t1 = type-1;
int li1 = int(((phi1+CMAPXMIN2)/CMAPDX)+((CMAPDIM*1.0)/2.0));
int li2 = int(((psi1+CMAPXMIN2)/CMAPDX)+((CMAPDIM*1.0)/2.0));
int li3 = int((phi-CMAPXMIN2)/CMAPDX);
int li4 = int((psi-CMAPXMIN2)/CMAPDX);
int mli3 = li3 % CMAPDIM;
int mli4 = li4 % CMAPDIM;
int mli31 = (li3+1) % CMAPDIM;
int mli41 = (li4+1) %CMAPDIM;
int mli1 = li1 % CMAPDIM;
int mli2 = li2 % CMAPDIM;
int mli11 = (li1+1) % CMAPDIM;
int mli21 = (li2+1) %CMAPDIM;
int t1 = type-1;
if (t1 < 0 || t1 > 5) Kokkos::abort("Invalid CMAP crossterm_type");
// determine the values and derivatives for the grid square points
double gs[4],d1gs[4],d2gs[4],d12gs[4];
gs[0] = d_cmapgrid(t1,mli3,mli4);
gs[1] = d_cmapgrid(t1,mli31,mli4);
gs[2] = d_cmapgrid(t1,mli31,mli41);
@ -417,84 +399,67 @@ void FixCMAPKokkos<DeviceType>::operator()(TagFixCmapPostForce, const int n, dou
// calculate the derivatives dphi/dr_i
dphidr1x = 1.0*r32/a1sq*a1x;
dphidr1y = 1.0*r32/a1sq*a1y;
dphidr1z = 1.0*r32/a1sq*a1z;
double dphidr1x = 1.0*r32/a1sq*a1x;
double dphidr1y = 1.0*r32/a1sq*a1y;
double dphidr1z = 1.0*r32/a1sq*a1z;
dphidr2x = -1.0*r32/a1sq*a1x - dpr21r32/a1sq/r32*a1x + dpr34r32/b1sq/r32*b1x;
dphidr2y = -1.0*r32/a1sq*a1y - dpr21r32/a1sq/r32*a1y + dpr34r32/b1sq/r32*b1y;
dphidr2z = -1.0*r32/a1sq*a1z - dpr21r32/a1sq/r32*a1z + dpr34r32/b1sq/r32*b1z;
double dphidr2x = -1.0*r32/a1sq*a1x - dpr21r32/a1sq/r32*a1x + dpr34r32/b1sq/r32*b1x;
double dphidr2y = -1.0*r32/a1sq*a1y - dpr21r32/a1sq/r32*a1y + dpr34r32/b1sq/r32*b1y;
double dphidr2z = -1.0*r32/a1sq*a1z - dpr21r32/a1sq/r32*a1z + dpr34r32/b1sq/r32*b1z;
dphidr3x = dpr34r32/b1sq/r32*b1x - dpr21r32/a1sq/r32*a1x - r32/b1sq*b1x;
dphidr3y = dpr34r32/b1sq/r32*b1y - dpr21r32/a1sq/r32*a1y - r32/b1sq*b1y;
dphidr3z = dpr34r32/b1sq/r32*b1z - dpr21r32/a1sq/r32*a1z - r32/b1sq*b1z;
double dphidr3x = dpr34r32/b1sq/r32*b1x - dpr21r32/a1sq/r32*a1x - r32/b1sq*b1x;
double dphidr3y = dpr34r32/b1sq/r32*b1y - dpr21r32/a1sq/r32*a1y - r32/b1sq*b1y;
double dphidr3z = dpr34r32/b1sq/r32*b1z - dpr21r32/a1sq/r32*a1z - r32/b1sq*b1z;
dphidr4x = r32/b1sq*b1x;
dphidr4y = r32/b1sq*b1y;
dphidr4z = r32/b1sq*b1z;
double dphidr4x = r32/b1sq*b1x;
double dphidr4y = r32/b1sq*b1y;
double dphidr4z = r32/b1sq*b1z;
// calculate the derivatives dpsi/dr_i
dpsidr1x = 1.0*r43/a2sq*a2x;
dpsidr1y = 1.0*r43/a2sq*a2y;
dpsidr1z = 1.0*r43/a2sq*a2z;
double dpsidr1x = 1.0*r43/a2sq*a2x;
double dpsidr1y = 1.0*r43/a2sq*a2y;
double dpsidr1z = 1.0*r43/a2sq*a2z;
dpsidr2x = r43/a2sq*a2x + dpr32r43/a2sq/r43*a2x - dpr45r43/b2sq/r43*b2x;
dpsidr2y = r43/a2sq*a2y + dpr32r43/a2sq/r43*a2y - dpr45r43/b2sq/r43*b2y;
dpsidr2z = r43/a2sq*a2z + dpr32r43/a2sq/r43*a2z - dpr45r43/b2sq/r43*b2z;
double dpsidr2x = r43/a2sq*a2x + dpr32r43/a2sq/r43*a2x - dpr45r43/b2sq/r43*b2x;
double dpsidr2y = r43/a2sq*a2y + dpr32r43/a2sq/r43*a2y - dpr45r43/b2sq/r43*b2y;
double dpsidr2z = r43/a2sq*a2z + dpr32r43/a2sq/r43*a2z - dpr45r43/b2sq/r43*b2z;
dpsidr3x = dpr45r43/b2sq/r43*b2x - dpr32r43/a2sq/r43*a2x - r43/b2sq*b2x;
dpsidr3y = dpr45r43/b2sq/r43*b2y - dpr32r43/a2sq/r43*a2y - r43/b2sq*b2y;
dpsidr3z = dpr45r43/b2sq/r43*b2z - dpr32r43/a2sq/r43*a2z - r43/b2sq*b2z;
double dpsidr3x = dpr45r43/b2sq/r43*b2x - dpr32r43/a2sq/r43*a2x - r43/b2sq*b2x;
double dpsidr3y = dpr45r43/b2sq/r43*b2y - dpr32r43/a2sq/r43*a2y - r43/b2sq*b2y;
double dpsidr3z = dpr45r43/b2sq/r43*b2z - dpr32r43/a2sq/r43*a2z - r43/b2sq*b2z;
dpsidr4x = r43/b2sq*b2x;
dpsidr4y = r43/b2sq*b2y;
dpsidr4z = r43/b2sq*b2z;
double dpsidr4x = r43/b2sq*b2x;
double dpsidr4y = r43/b2sq*b2y;
double dpsidr4z = r43/b2sq*b2z;
// calculate forces on cross-term atoms: F = -(dE/dPhi)*(dPhi/dr)
f1[0] = dEdPhi*dphidr1x;
f1[1] = dEdPhi*dphidr1y;
f1[2] = dEdPhi*dphidr1z;
f2[0] = dEdPhi*dphidr2x + dEdPsi*dpsidr1x;
f2[1] = dEdPhi*dphidr2y + dEdPsi*dpsidr1y;
f2[2] = dEdPhi*dphidr2z + dEdPsi*dpsidr1z;
f3[0] = -dEdPhi*dphidr3x - dEdPsi*dpsidr2x;
f3[1] = -dEdPhi*dphidr3y - dEdPsi*dpsidr2y;
f3[2] = -dEdPhi*dphidr3z - dEdPsi*dpsidr2z;
f4[0] = -dEdPhi*dphidr4x - dEdPsi*dpsidr3x;
f4[1] = -dEdPhi*dphidr4y - dEdPsi*dpsidr3y;
f4[2] = -dEdPhi*dphidr4z - dEdPsi*dpsidr3z;
f5[0] = -dEdPsi*dpsidr4x;
f5[1] = -dEdPsi*dpsidr4y;
f5[2] = -dEdPsi*dpsidr4z;
// apply force to each of the 5 atoms
if (i1 < nlocal) {
d_f(i1,0) += f1[0];
d_f(i1,1) += f1[1];
d_f(i1,2) += f1[2];
d_f(i1,0) += dEdPhi*dphidr1x;
d_f(i1,1) += dEdPhi*dphidr1y;
d_f(i1,2) += dEdPhi*dphidr1z;
}
if (i2 < nlocal) {
d_f(i2,0) += f2[0];
d_f(i2,1) += f2[1];
d_f(i2,2) += f2[2];
d_f(i2,0) += dEdPhi*dphidr2x + dEdPsi*dpsidr1x;
d_f(i2,1) += dEdPhi*dphidr2y + dEdPsi*dpsidr1y;
d_f(i2,2) += dEdPhi*dphidr2z + dEdPsi*dpsidr1z;
}
if (i3 < nlocal) {
d_f(i3,0) += f3[0];
d_f(i3,1) += f3[1];
d_f(i3,2) += f3[2];
d_f(i3,0) += (-dEdPhi*dphidr3x - dEdPsi*dpsidr2x);
d_f(i3,1) += (-dEdPhi*dphidr3y - dEdPsi*dpsidr2y);
d_f(i3,2) += (-dEdPhi*dphidr3z - dEdPsi*dpsidr2z);
}
if (i4 < nlocal) {
d_f(i4,0) += f4[0];
d_f(i4,1) += f4[1];
d_f(i4,2) += f4[2];
d_f(i4,0) += (-dEdPhi*dphidr4x - dEdPsi*dpsidr3x);
d_f(i4,1) += (-dEdPhi*dphidr4y - dEdPsi*dpsidr3y);
d_f(i4,2) += (-dEdPhi*dphidr4z - dEdPsi*dpsidr3z);
}
if (i5 < nlocal) {
d_f(i5,0) += f5[0];
d_f(i5,1) += f5[1];
d_f(i5,2) += f5[2];
d_f(i5,0) -= dEdPsi*dpsidr4x;
d_f(i5,1) -= dEdPsi*dpsidr4y;
d_f(i5,2) -= dEdPsi*dpsidr4z;
}
}

View File

@ -69,6 +69,9 @@ class FixCMAPKokkos : public FixCMAP, public KokkosBase {
ExecutionSpace space) override;
protected:
int nlocal;
typename AT::t_x_array d_x;
typename AT::t_f_array d_f;