Skip to content

Commit

Permalink
Improvement: remove abuse use of einsum, accelerate pcm and int1e 2nd…
Browse files Browse the repository at this point in the history
… derivative kernels
  • Loading branch information
henryw7 committed Jan 27, 2025
1 parent fc9a413 commit a55e390
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 67 deletions.
6 changes: 3 additions & 3 deletions gpu4pyscf/gto/int3c1e_ipip.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def get_int3c1e_ipip1_charge_contracted(mol, grids, charge_exponents, charges, i
ngrids = grids.shape[0]
# n_charge_sum_per_thread = 1 # means every thread processes one pair and one grid
# n_charge_sum_per_thread = ngrids # or larger number gaurantees one thread processes one pair and all grid points
n_charge_sum_per_thread = 10
n_charge_sum_per_thread = 100 # This number roughly optimize kernel performance on a large system

int1e_angular_slice = cp.zeros([3, 3, j1-j0, i1-i0], order='C')

Expand Down Expand Up @@ -145,7 +145,7 @@ def get_int3c1e_ipvip1_charge_contracted(mol, grids, charge_exponents, charges,
ngrids = grids.shape[0]
# n_charge_sum_per_thread = 1 # means every thread processes one pair and one grid
# n_charge_sum_per_thread = ngrids # or larger number gaurantees one thread processes one pair and all grid points
n_charge_sum_per_thread = 10
n_charge_sum_per_thread = 100 # This number roughly optimize kernel performance on a large system

int1e_angular_slice = cp.zeros([3, 3, j1-j0, i1-i0], order='C')

Expand Down Expand Up @@ -220,7 +220,7 @@ def get_int3c1e_ip1ip2_charge_contracted(mol, grids, charge_exponents, charges,
ngrids = grids.shape[0]
# n_charge_sum_per_thread = 1 # means every thread processes one pair and one grid
# n_charge_sum_per_thread = ngrids # or larger number gaurantees one thread processes one pair and all grid points
n_charge_sum_per_thread = 10
n_charge_sum_per_thread = 100 # This number roughly optimize kernel performance on a large system

int1e_angular_slice = cp.zeros([3, 3, j1-j0, i1-i0], order='C')

Expand Down
62 changes: 19 additions & 43 deletions gpu4pyscf/lib/solvent/pcm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ static void _pcm_d2D_d2S(double *matrix_d2D, double *matrix_d2S,
const double dy = coords[3*i+1] - coords[3*j+1];
const double dz = coords[3*i+2] - coords[3*j+2];
const double rij = norm3d(dx, dy, dz);
const double rij_1 = 1.0 / rij;
const double rij_1 = (i != j) ? (1.0 / rij) : 0.0; // This guarantees that if i == j, all matrix elements = 0
const double rij_2 = rij_1 * rij_1;
const double rij_3 = rij_2 * rij_1;
const double rij_4 = rij_2 * rij_2;
Expand All @@ -171,27 +171,15 @@ static void _pcm_d2D_d2S(double *matrix_d2D, double *matrix_d2S,
const double S_xyz_diagonal_prefactor = two_eij_over_sqrt_pi_exp_minus_eij2_rij2 * rij_2 - rij_3 * erf_eij_rij;

const int n2 = n * n;
if (i == j) {
matrix_d2S[i*n + j ] = 0.0;
matrix_d2S[i*n + j + n2 ] = 0.0;
matrix_d2S[i*n + j + n2 * 2] = 0.0;
matrix_d2S[i*n + j + n2 * 3] = 0.0;
matrix_d2S[i*n + j + n2 * 4] = 0.0;
matrix_d2S[i*n + j + n2 * 5] = 0.0;
matrix_d2S[i*n + j + n2 * 6] = 0.0;
matrix_d2S[i*n + j + n2 * 7] = 0.0;
matrix_d2S[i*n + j + n2 * 8] = 0.0;
} else {
matrix_d2S[i*n + j ] = dx * dx * S_direct_product_prefactor + S_xyz_diagonal_prefactor;
matrix_d2S[i*n + j + n2 ] = dx * dy * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 2] = dx * dz * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 3] = dy * dx * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 4] = dy * dy * S_direct_product_prefactor + S_xyz_diagonal_prefactor;
matrix_d2S[i*n + j + n2 * 5] = dy * dz * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 6] = dz * dx * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 7] = dz * dy * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 8] = dz * dz * S_direct_product_prefactor + S_xyz_diagonal_prefactor;
}
matrix_d2S[i*n + j ] = dx * dx * S_direct_product_prefactor + S_xyz_diagonal_prefactor;
matrix_d2S[i*n + j + n2 ] = dx * dy * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 2] = dx * dz * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 3] = dy * dx * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 4] = dy * dy * S_direct_product_prefactor + S_xyz_diagonal_prefactor;
matrix_d2S[i*n + j + n2 * 5] = dy * dz * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 6] = dz * dx * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 7] = dz * dy * S_direct_product_prefactor;
matrix_d2S[i*n + j + n2 * 8] = dz * dz * S_direct_product_prefactor + S_xyz_diagonal_prefactor;

if (matrix_d2D != NULL) {
const double nxj = norm_vec[3*j];
Expand All @@ -205,27 +193,15 @@ static void _pcm_d2D_d2S(double *matrix_d2D, double *matrix_d2S,

const double D_direct_product_prefactor = (-two_eij_over_sqrt_pi_exp_minus_eij2_rij2 * (15 * rij_6 + 10 * eij2 * rij_4 + 4 * eij4 * rij_2)
+ 15 * rij_7 * erf_eij_rij) * nj_rij;
if (i == j) {
matrix_d2D[i*n + j ] = 0.0;
matrix_d2D[i*n + j + n2 ] = 0.0;
matrix_d2D[i*n + j + n2 * 2] = 0.0;
matrix_d2D[i*n + j + n2 * 3] = 0.0;
matrix_d2D[i*n + j + n2 * 4] = 0.0;
matrix_d2D[i*n + j + n2 * 5] = 0.0;
matrix_d2D[i*n + j + n2 * 6] = 0.0;
matrix_d2D[i*n + j + n2 * 7] = 0.0;
matrix_d2D[i*n + j + n2 * 8] = 0.0;
} else {
matrix_d2D[i*n + j ] = D_direct_product_prefactor * dx * dx - S_direct_product_prefactor * (dx * nxj + dx * nxj + nj_rij);
matrix_d2D[i*n + j + n2 ] = D_direct_product_prefactor * dx * dy - S_direct_product_prefactor * (dy * nxj + dx * nyj);
matrix_d2D[i*n + j + n2 * 2] = D_direct_product_prefactor * dx * dz - S_direct_product_prefactor * (dz * nxj + dx * nzj);
matrix_d2D[i*n + j + n2 * 3] = D_direct_product_prefactor * dy * dx - S_direct_product_prefactor * (dx * nyj + dy * nxj);
matrix_d2D[i*n + j + n2 * 4] = D_direct_product_prefactor * dy * dy - S_direct_product_prefactor * (dy * nyj + dy * nyj + nj_rij);
matrix_d2D[i*n + j + n2 * 5] = D_direct_product_prefactor * dy * dz - S_direct_product_prefactor * (dz * nyj + dy * nzj);
matrix_d2D[i*n + j + n2 * 6] = D_direct_product_prefactor * dz * dx - S_direct_product_prefactor * (dx * nzj + dz * nxj);
matrix_d2D[i*n + j + n2 * 7] = D_direct_product_prefactor * dz * dy - S_direct_product_prefactor * (dy * nzj + dz * nyj);
matrix_d2D[i*n + j + n2 * 8] = D_direct_product_prefactor * dz * dz - S_direct_product_prefactor * (dz * nzj + dz * nzj + nj_rij);
}
matrix_d2D[i*n + j ] = D_direct_product_prefactor * dx * dx - S_direct_product_prefactor * (dx * nxj + dx * nxj + nj_rij);
matrix_d2D[i*n + j + n2 ] = D_direct_product_prefactor * dx * dy - S_direct_product_prefactor * (dy * nxj + dx * nyj);
matrix_d2D[i*n + j + n2 * 2] = D_direct_product_prefactor * dx * dz - S_direct_product_prefactor * (dz * nxj + dx * nzj);
matrix_d2D[i*n + j + n2 * 3] = D_direct_product_prefactor * dy * dx - S_direct_product_prefactor * (dx * nyj + dy * nxj);
matrix_d2D[i*n + j + n2 * 4] = D_direct_product_prefactor * dy * dy - S_direct_product_prefactor * (dy * nyj + dy * nyj + nj_rij);
matrix_d2D[i*n + j + n2 * 5] = D_direct_product_prefactor * dy * dz - S_direct_product_prefactor * (dz * nyj + dy * nzj);
matrix_d2D[i*n + j + n2 * 6] = D_direct_product_prefactor * dz * dx - S_direct_product_prefactor * (dx * nzj + dz * nxj);
matrix_d2D[i*n + j + n2 * 7] = D_direct_product_prefactor * dz * dy - S_direct_product_prefactor * (dy * nzj + dz * nyj);
matrix_d2D[i*n + j + n2 * 8] = D_direct_product_prefactor * dz * dz - S_direct_product_prefactor * (dz * nzj + dz * nzj + nj_rij);
}
}

Expand Down
42 changes: 21 additions & 21 deletions gpu4pyscf/solvent/hessian/pcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ def get_d2Sii(surface, dF, d2F):
switch_fun = surface['switch_fun']
dF = dF.transpose(2,0,1)
dF_dF = dF[:, cupy.newaxis, :, cupy.newaxis, :] * dF[cupy.newaxis, :, cupy.newaxis, :, :]
dF_dF_over_F3 = cupy.einsum('ABdDq,q->ABdDq', dF_dF, 1.0/(switch_fun**3))
d2F_over_F2 = cupy.einsum('ABdDq,q->ABdDq', d2F, 1.0/(switch_fun**2))
dF_dF_over_F3 = dF_dF * (1.0/(switch_fun**3))
d2F_over_F2 = d2F * (1.0/(switch_fun**2))
d2Sii = 2 * dF_dF_over_F3 - d2F_over_F2
d2Sii = (2.0/PI)**0.5 * cupy.einsum('ABdDq,q->ABdDq', d2Sii, charge_exp)
d2Sii = (2.0/PI)**0.5 * (d2Sii * charge_exp)
return d2Sii

def get_d2D_d2S(surface, with_S=True, with_D=False, stream=None):
Expand Down Expand Up @@ -193,17 +193,17 @@ def analytical_hess_nuc(pcmobj, dm, verbose=None):
# # This causes severe numerical problem in function int2c2e_ip1ip2, and make the main diagonal of hessian garbage.
# d2I_dA2 = gto.mole.intor_cross(int2c2e_ipip1, fakemol_nuc, fakemol)
d2I_dA2 = -gto.mole.intor_cross(int2c2e_ip1ip2, fakemol_nuc, fakemol)
d2I_dA2 = numpy.einsum('dAq,q->dA', d2I_dA2, q_sym)
d2I_dA2 = d2I_dA2 @ q_sym
d2I_dA2 = d2I_dA2.reshape(3, 3, mol.natm)
for i_atom in range(mol.natm):
d2e_from_d2I[i_atom, i_atom, :, :] += atom_charges[i_atom] * d2I_dA2[:, :, i_atom]

d2I_dC2 = gto.mole.intor_cross(int2c2e_ipip1, fakemol, fakemol_nuc)
d2I_dC2 = numpy.einsum('dqA,A->dq', d2I_dC2, atom_charges)
d2I_dC2 = d2I_dC2 @ atom_charges
d2I_dC2 = d2I_dC2.reshape(3, 3, ngrids)
for i_atom in range(mol.natm):
g0,g1 = gridslice[i_atom]
d2e_from_d2I[i_atom, i_atom, :, :] += numpy.einsum('dDq,q->dD', d2I_dC2[:, :, g0:g1], q_sym[g0:g1])
d2e_from_d2I[i_atom, i_atom, :, :] += d2I_dC2[:, :, g0:g1] @ q_sym[g0:g1]

intopt_derivative = int3c1e.VHFOpt(mol)
intopt_derivative.build(cutoff = 1e-14, aosym = False)
Expand Down Expand Up @@ -291,7 +291,7 @@ def analytical_hess_qv(pcmobj, dm, verbose=None):
d2I_dC2 = int1e_grids_ipip2(mol, grid_coords, dm = dm, intopt = intopt_derivative, charge_exponents = charge_exp**2)
for i_atom in range(mol.natm):
g0,g1 = gridslice[i_atom]
d2e_from_d2I[i_atom, i_atom, :, :] += cupy.einsum('dDq,q->dD', d2I_dC2[:, :, g0:g1], q_sym[g0:g1])
d2e_from_d2I[i_atom, i_atom, :, :] += d2I_dC2[:, :, g0:g1] @ q_sym[g0:g1]

dqdx = get_dqsym_dx(pcmobj, dm, range(mol.natm), intopt_derivative)

Expand Down Expand Up @@ -320,8 +320,8 @@ def get_dS_dot_q(dS, dSii, q, atmlst, gridslice):
output = cupy.einsum('diA,i->Adi', dSii[:,:,atmlst], q)
for i_atom in atmlst:
g0,g1 = gridslice[i_atom]
output[i_atom, :, g0:g1] += cupy.einsum('dij,j->di', dS[:,g0:g1,:], q)
output[i_atom, :, :] -= cupy.einsum('dij,j->di', dS[:,:,g0:g1], q[g0:g1])
output[i_atom, :, g0:g1] += dS[:,g0:g1,:] @ q
output[i_atom, :, :] -= dS[:,:,g0:g1] @ q[g0:g1]
return output
def get_dST_dot_q(dS, dSii, q, atmlst, gridslice):
# S is symmetric
Expand All @@ -334,19 +334,19 @@ def get_dD_dot_q(dD, q, atmlst, gridslice, ngrids):
output = cupy.zeros([len(atmlst), 3, ngrids])
for i_atom in atmlst:
g0,g1 = gridslice[i_atom]
output[i_atom, :, g0:g1] += cupy.einsum('dij,j->di', dD[:,g0:g1,:], q)
output[i_atom, :, :] -= cupy.einsum('dij,j->di', dD[:,:,g0:g1], q[g0:g1])
output[i_atom, :, g0:g1] += dD[:,g0:g1,:] @ q
output[i_atom, :, :] -= dD[:,:,g0:g1] @ q[g0:g1]
return output
def get_dDT_dot_q(dD, q, atmlst, gridslice, ngrids):
return get_dD_dot_q(-dD.transpose(0,2,1), q, atmlst, gridslice, ngrids)

def get_v_dot_d2S_dot_q(d2S, d2Sii, v_left, q_right, natom, gridslice):
output = cupy.einsum('ABdDq,q->ABdD', d2Sii, v_left * q_right)
output = d2Sii @ (v_left * q_right)
for i_atom in range(natom):
gi0,gi1 = gridslice[i_atom]
for j_atom in range(natom):
gj0,gj1 = gridslice[j_atom]
d2S_atom_ij = cupy.einsum('q,dDqQ,Q->dD', v_left[gi0:gi1], d2S[:,:,gi0:gi1,gj0:gj1], q_right[gj0:gj1])
d2S_atom_ij = cupy.einsum('q,dDq->dD', v_left[gi0:gi1], d2S[:,:,gi0:gi1,gj0:gj1] @ q_right[gj0:gj1])
output[i_atom, i_atom, :, :] += d2S_atom_ij
output[j_atom, j_atom, :, :] += d2S_atom_ij
output[i_atom, j_atom, :, :] -= d2S_atom_ij
Expand All @@ -357,15 +357,15 @@ def get_v_dot_d2ST_dot_q(d2S, d2Sii, v_left, q_right, natom, gridslice):
return get_v_dot_d2S_dot_q(d2S, d2Sii, v_left, q_right, natom, gridslice)

def get_v_dot_d2A_dot_q(d2A, v_left, q_right):
return cupy.einsum('ABdDq,q->ABdD', d2A, v_left * q_right)
return d2A @ (v_left * q_right)

def get_v_dot_d2D_dot_q(d2D, v_left, q_right, natom, gridslice):
output = cupy.zeros([natom, natom, 3, 3])
for i_atom in range(natom):
gi0,gi1 = gridslice[i_atom]
for j_atom in range(natom):
gj0,gj1 = gridslice[j_atom]
d2D_atom_ij = cupy.einsum('q,dDqQ,Q->dD', v_left[gi0:gi1], d2D[:,:,gi0:gi1,gj0:gj1], q_right[gj0:gj1])
d2D_atom_ij = cupy.einsum('q,dDq->dD', v_left[gi0:gi1], d2D[:,:,gi0:gi1,gj0:gj1] @ q_right[gj0:gj1])
output[i_atom, i_atom, :, :] += d2D_atom_ij
output[j_atom, j_atom, :, :] += d2D_atom_ij
output[i_atom, j_atom, :, :] -= d2D_atom_ij
Expand Down Expand Up @@ -476,10 +476,10 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
vK_1_d2K_q += get_v_dot_d2A_dot_q(d2A, (D.T @ vK_1).T, S @ q)
vK_1_d2K_q += get_v_dot_d2S_dot_q(d2S, d2Sii, (DA.T @ vK_1).T, q, natom, gridslice)
vK_1_d2K_q += cupy.einsum('Adi,BDi->ABdD', vK_1_dot_dDdx, dAdx_dot_Sq)
vK_1_d2K_q += cupy.einsum('Adi,i,BDi->ABdD', vK_1_dot_dDdx, A, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->ABdD', vK_1_dot_dDdx * A, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->ABdD', vK_1D_dot_dAdx, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->BADd', vK_1_dot_dDdx, dAdx_dot_Sq)
vK_1_d2K_q += cupy.einsum('Adi,i,BDi->BADd', vK_1_dot_dDdx, A, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->BADd', vK_1_dot_dDdx * A, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->BADd', vK_1D_dot_dAdx, dSdx_dot_q)
vK_1_d2K_q *= -f_eps_over_2pi
vK_1_d2K_q += get_v_dot_d2S_dot_q(d2S, d2Sii, vK_1, q, natom, gridslice)
Expand Down Expand Up @@ -572,16 +572,16 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
vK_1_d2K_q += get_v_dot_d2A_dot_q(d2A, (S @ vK_1).T, D.T @ q)
vK_1_d2K_q += get_v_dot_d2ST_dot_q(d2S, d2Sii, vK_1, DA.T @ q, natom, gridslice)
vK_1_d2K_q += cupy.einsum('Adi,BDi->ABdD', vK_1_dot_dDdx, dAdx_dot_Sq)
vK_1_d2K_q += cupy.einsum('Adi,i,BDi->ABdD', vK_1_dot_dDdx, A, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->ABdD', vK_1_dot_dDdx * A, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->ABdD', vK_1D_dot_dAdx, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->ABdD', vK_1_dot_dSdxT, dAdxT_dot_DT_q)
vK_1_d2K_q += cupy.einsum('Adi,i,BDi->ABdD', vK_1_dot_dSdxT, A, dDdxT_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->ABdD', vK_1_dot_dSdxT * A, dDdxT_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->ABdD', vK_1_ST_dot_dAdxT, dDdxT_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->BADd', vK_1_dot_dDdx, dAdx_dot_Sq)
vK_1_d2K_q += cupy.einsum('Adi,i,BDi->BADd', vK_1_dot_dDdx, A, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->BADd', vK_1_dot_dDdx * A, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->BADd', vK_1D_dot_dAdx, dSdx_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->BADd', vK_1_dot_dSdxT, dAdxT_dot_DT_q)
vK_1_d2K_q += cupy.einsum('Adi,i,BDi->BADd', vK_1_dot_dSdxT, A, dDdxT_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->BADd', vK_1_dot_dSdxT * A, dDdxT_dot_q)
vK_1_d2K_q += cupy.einsum('Adi,BDi->BADd', vK_1_ST_dot_dAdxT, dDdxT_dot_q)
vK_1_d2K_q *= -f_eps_over_4pi
vK_1_d2K_q += get_v_dot_d2S_dot_q(d2S, d2Sii, vK_1, q, natom, gridslice)
Expand Down

0 comments on commit a55e390

Please sign in to comment.