Skip to content

Commit

Permalink
Remove explicit inverse K
Browse files Browse the repository at this point in the history
  • Loading branch information
henryw7 committed Jan 23, 2025
1 parent 69b5724 commit fc9a413
Showing 1 changed file with 42 additions and 29 deletions.
71 changes: 42 additions & 29 deletions gpu4pyscf/solvent/hessian/pcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,15 @@ def analytical_hess_qv(pcmobj, dm, verbose=None):
t1 = log.timer_debug1('solvent hessian d(dI/dx * q)/dx contribution', *t1)
return d2e

def einsum_ij_Adj_Adi_inverseK(K, Adj_term):
    """Apply the inverse of ``K`` along the last axis of ``Adj_term``.

    Produces the same values as
    ``cupy.einsum('ij,Adj->Adi', cupy.linalg.inv(K), Adj_term)`` but never
    builds the inverse: the two leading axes are flattened into a set of
    right-hand-side columns and handed to a single ``cupy.linalg.solve`` call.

    Args:
        K: coefficient matrix passed unchanged to ``cupy.linalg.solve``.
        Adj_term: 3-D array; its last axis is the one contracted with the
            inverse of ``K``.

    Returns:
        Array with the same shape as ``Adj_term``.
    """
    n_outer, n_inner, n_last = Adj_term.shape
    # Lay the (A, d) pairs out as columns so that one solve covers all of them.
    rhs_columns = Adj_term.reshape(n_outer * n_inner, n_last).T
    solved_columns = cupy.linalg.solve(K, rhs_columns)
    # Undo the column layout and restore the original three axes.
    return solved_columns.T.reshape(n_outer, n_inner, n_last)
def einsum_Adi_ij_Adj_inverseK(Adi_term, K):
    """Right-multiply the last axis of ``Adi_term`` by the inverse of ``K``.

    Produces the same values as
    ``cupy.einsum('Adi,ij->Adj', Adi_term, cupy.linalg.inv(K))`` but never
    builds the inverse. Multiplying by the inverse on the right equals
    solving against the transposed matrix, so the system is solved with
    ``K.T`` and every (A, d) pair is supplied as one right-hand-side column.

    Args:
        Adi_term: 3-D array; its last axis is the one contracted with the
            inverse of ``K``.
        K: coefficient matrix; its transpose is passed to
            ``cupy.linalg.solve``.

    Returns:
        Array with the same shape as ``Adi_term``.
    """
    n_outer, n_inner, n_last = Adi_term.shape
    # Lay the (A, d) pairs out as columns so that one solve covers all of them.
    rhs_columns = Adi_term.reshape(n_outer * n_inner, n_last).T
    solved_columns = cupy.linalg.solve(K.T, rhs_columns)
    # Undo the column layout and restore the original three axes.
    return solved_columns.T.reshape(n_outer, n_inner, n_last)

def get_dS_dot_q(dS, dSii, q, atmlst, gridslice):
output = cupy.einsum('diA,i->Adi', dSii[:,:,atmlst], q)
for i_atom in atmlst:
Expand Down Expand Up @@ -392,8 +401,7 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):

ngrids = q.shape[0]

inverse_K = cupy.linalg.inv(K)
vK_1 = inverse_K.T @ v_grids
vK_1 = cupy.linalg.solve(K.T, v_grids)

if pcmobj.method.upper() in ['C-PCM', 'CPCM', 'COSMO']:
_, dS = get_dD_dS(pcmobj.surface, with_D=False, with_S=True)
Expand All @@ -409,14 +417,14 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
# d(S-1 R) = - S-1 dS S-1 R
# d2(S-1 R) = (S-1 dS S-1 dS S-1 R) + (S-1 dS S-1 dS S-1 R) - (S-1 d2S S-1 R)
dSdx_dot_q = get_dS_dot_q(dS, dSii, q, atmlst, gridslice)
S_1_dSdx_dot_q = cupy.einsum('ij,Adj->Adi', inverse_K, dSdx_dot_q)
S_1_dSdx_dot_q = einsum_ij_Adj_Adi_inverseK(K, dSdx_dot_q)
VS_1_dot_dSdx = get_dST_dot_q(dS, dSii, vK_1, atmlst, gridslice)
d2e_from_d2KR = cupy.einsum('Adi,BDi->ABdD', VS_1_dot_dSdx, S_1_dSdx_dot_q) * 2

d2e_from_d2KR -= get_v_dot_d2S_dot_q(d2S, d2Sii, vK_1, q, natom, gridslice)

dK_1Rv = -cupy.einsum("ij,Adj->Adi", inverse_K, dSdx_dot_q)
dvK_1R = -cupy.einsum("Adi,ij->Adj", VS_1_dot_dSdx, inverse_K @ R)
dK_1Rv = -S_1_dSdx_dot_q
dvK_1R = -einsum_Adi_ij_Adj_inverseK(VS_1_dot_dSdx, K) @ R

elif pcmobj.method.upper() in ['IEF-PCM', 'IEFPCM', 'SMD']:
dD, dS = get_dD_dS(pcmobj.surface, with_D=True, with_S=True)
Expand Down Expand Up @@ -450,7 +458,7 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
dDdx_dot_ASq = get_dD_dot_q(dD, AS @ q, atmlst, gridslice, ngrids)
dKdx_dot_q -= f_eps_over_2pi * dDdx_dot_ASq

K_1_dot_dKdx_dot_q = cupy.einsum('ij,Adj->Adi', inverse_K, dKdx_dot_q)
K_1_dot_dKdx_dot_q = einsum_ij_Adj_Adi_inverseK(K, dKdx_dot_q)

vK_1_dot_dSdx = get_dST_dot_q(dS, dSii, vK_1, atmlst, gridslice)
vK_1_dot_dKdx = vK_1_dot_dSdx
Expand Down Expand Up @@ -482,7 +490,7 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
dDdx_dot_AV = get_dD_dot_q(dD, A * v_grids, atmlst, gridslice, ngrids)
dRdx_dot_V = f_eps_over_2pi * (dDdx_dot_AV + cupy.einsum('ij,Adj->Adi', D, dAdx_dot_V))

K_1_dot_dRdx_dot_V = cupy.einsum('ij,Adj->Adi', inverse_K, dRdx_dot_V)
K_1_dot_dRdx_dot_V = einsum_ij_Adj_Adi_inverseK(K, dRdx_dot_V)

d2e_from_d2KR -= cupy.einsum('Adi,BDi->ABdD', vK_1_dot_dKdx, K_1_dot_dRdx_dot_V)
d2e_from_d2KR -= cupy.einsum('Adi,BDi->BADd', vK_1_dot_dKdx, K_1_dot_dRdx_dot_V)
Expand All @@ -501,7 +509,7 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
VK_1_dot_dDdx = get_dDT_dot_q(dD, vK_1, atmlst, gridslice, ngrids)
VK_1_dot_dRdx = f_eps_over_2pi * (VK_1D_dot_dAdx + VK_1_dot_dDdx * A)

dvK_1R = -cupy.einsum("Adi,ij->Adj", vK_1_dot_dKdx, inverse_K @ R) + VK_1_dot_dRdx
dvK_1R = -einsum_Adi_ij_Adj_inverseK(vK_1_dot_dKdx, K) @ R + VK_1_dot_dRdx

elif pcmobj.method.upper() in ['SS(V)PE']:
dD, dS = get_dD_dS(pcmobj.surface, with_D=True, with_S=True)
Expand Down Expand Up @@ -537,7 +545,7 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
dSdxT_dot_AT_DT_q = get_dS_dot_q(dS, dSii, DA.T @ q, atmlst, gridslice)
dKdx_dot_q -= f_eps_over_4pi * dSdxT_dot_AT_DT_q

K_1_dot_dKdx_dot_q = cupy.einsum('ij,Adj->Adi', inverse_K, dKdx_dot_q)
K_1_dot_dKdx_dot_q = einsum_ij_Adj_Adi_inverseK(K, dKdx_dot_q)

vK_1_dot_dSdx = get_dST_dot_q(dS, dSii, vK_1, atmlst, gridslice)
vK_1_dot_dKdx = vK_1_dot_dSdx
Expand Down Expand Up @@ -584,7 +592,7 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
dDdx_dot_AV = get_dD_dot_q(dD, A * v_grids, atmlst, gridslice, ngrids)
dRdx_dot_V = f_eps_over_2pi * (dDdx_dot_AV + cupy.einsum('ij,Adj->Adi', D, dAdx_dot_V))

K_1_dot_dRdx_dot_V = cupy.einsum('ij,Adj->Adi', inverse_K, dRdx_dot_V)
K_1_dot_dRdx_dot_V = einsum_ij_Adj_Adi_inverseK(K, dRdx_dot_V)

d2e_from_d2KR -= cupy.einsum('Adi,BDi->ABdD', vK_1_dot_dKdx, K_1_dot_dRdx_dot_V)
d2e_from_d2KR -= cupy.einsum('Adi,BDi->BADd', vK_1_dot_dKdx, K_1_dot_dRdx_dot_V)
Expand All @@ -603,7 +611,7 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
VK_1_dot_dDdx = get_dDT_dot_q(dD, vK_1, atmlst, gridslice, ngrids)
VK_1_dot_dRdx = f_eps_over_2pi * (VK_1D_dot_dAdx + VK_1_dot_dDdx * A)

dvK_1R = -cupy.einsum("Adi,ij->Adj", vK_1_dot_dKdx, inverse_K @ R) + VK_1_dot_dRdx
dvK_1R = -einsum_Adi_ij_Adj_inverseK(vK_1_dot_dKdx, K) @ R + VK_1_dot_dRdx

else:
raise RuntimeError(f"Unknown implicit solvent model: {pcmobj.method}")
Expand All @@ -622,15 +630,17 @@ def analytical_hess_solver(pcmobj, dm, verbose=None):
t1 = log.timer_debug1('solvent hessian d(V * dK-1R/dx * V)/dx contribution', *t1)
return d2e

def get_dqsym_dx_fix_vgrids(pcmobj, atmlst, inverse_K):
def get_dqsym_dx_fix_vgrids(pcmobj, atmlst):
assert pcmobj._intermediates is not None

gridslice = pcmobj.surface['gslice_by_atom']
v_grids = pcmobj._intermediates['v_grids']
A = pcmobj._intermediates['A']
D = pcmobj._intermediates['D']
S = pcmobj._intermediates['S']
K = pcmobj._intermediates['K']
R = pcmobj._intermediates['R']
q = pcmobj._intermediates['q']
q_sym = pcmobj._intermediates['q_sym']
f_epsilon = pcmobj._intermediates['f_epsilon']

Expand All @@ -645,7 +655,7 @@ def get_dqsym_dx_fix_vgrids(pcmobj, atmlst, inverse_K):
# dR = 0, dK = dS
dSdx_dot_q = get_dS_dot_q(dS, dSii, q_sym, atmlst, gridslice)

dqdx_fix_Vq = cupy.einsum('ij,Adj->Adi', inverse_K, dSdx_dot_q)
dqdx_fix_Vq = einsum_ij_Adj_Adi_inverseK(K, dSdx_dot_q)

elif pcmobj.method.upper() in ['IEF-PCM', 'IEFPCM', 'SMD']:
dF, dA = get_dF_dA(pcmobj.surface)
Expand All @@ -658,7 +668,6 @@ def get_dqsym_dx_fix_vgrids(pcmobj, atmlst, inverse_K):
# dK = dS - f_eps/(2*pi) * (dD*A*S + D*dA*S + D*A*dS)
f_eps_over_2pi = f_epsilon/(2.0*PI)

q = inverse_K @ R @ v_grids
dSdx_dot_q = get_dS_dot_q(dS, dSii, q, atmlst, gridslice)

DA = D*A
Expand All @@ -671,16 +680,16 @@ def get_dqsym_dx_fix_vgrids(pcmobj, atmlst, inverse_K):
dDdx_dot_ASq = get_dD_dot_q(dD, AS @ q, atmlst, gridslice, ngrids)
dKdx_dot_q -= f_eps_over_2pi * dDdx_dot_ASq

dqdx_fix_Vq = -cupy.einsum('ij,Adj->Adi', inverse_K, dKdx_dot_q)
dqdx_fix_Vq = -einsum_ij_Adj_Adi_inverseK(K, dKdx_dot_q)

dAdx_dot_V = get_dA_dot_q(dA, v_grids, atmlst)

dDdx_dot_AV = get_dD_dot_q(dD, A * v_grids, atmlst, gridslice, ngrids)

dRdx_dot_V = f_eps_over_2pi * (dDdx_dot_AV + cupy.einsum('ij,Adj->Adi', D, dAdx_dot_V))
dqdx_fix_Vq += cupy.einsum('ij,Adj->Adi', inverse_K, dRdx_dot_V)
dqdx_fix_Vq += einsum_ij_Adj_Adi_inverseK(K, dRdx_dot_V)

invKT_V = inverse_K.T @ v_grids
invKT_V = cupy.linalg.solve(K.T, v_grids)
dDdxT_dot_invKT_V = get_dDT_dot_q(dD, invKT_V, atmlst, gridslice, ngrids)

DT_invKT_V = D.T @ invKT_V
Expand All @@ -695,8 +704,9 @@ def get_dqsym_dx_fix_vgrids(pcmobj, atmlst, inverse_K):

dSdxT_dot_AT_DT_invKT_V = get_dST_dot_q(dS, dSii, DA.T @ invKT_V, atmlst, gridslice)
dKdxT_dot_invKT_V -= f_eps_over_2pi * dSdxT_dot_AT_DT_invKT_V
invKT_dKdxT_dot_invKT_V = einsum_ij_Adj_Adi_inverseK(K.T, dKdxT_dot_invKT_V)

dqdx_fix_Vq += -cupy.einsum('ij,Adj->Adi', R.T @ inverse_K.T, dKdxT_dot_invKT_V)
dqdx_fix_Vq += -cupy.einsum('ij,Adj->Adi', R.T, invKT_dKdxT_dot_invKT_V)

dqdx_fix_Vq *= -0.5

Expand Down Expand Up @@ -735,26 +745,27 @@ def dK_dot_q(q):

f_eps_over_2pi = f_epsilon/(2.0*PI)

q = inverse_K @ R @ v_grids
dKdx_dot_q = dK_dot_q(q)
dqdx_fix_Vq = -cupy.einsum('ij,Adj->Adi', inverse_K, dKdx_dot_q)
dqdx_fix_Vq = -einsum_ij_Adj_Adi_inverseK(K, dKdx_dot_q)

dAdx_dot_V = get_dA_dot_q(dA, v_grids, atmlst)

dDdx_dot_AV = get_dD_dot_q(dD, A * v_grids, atmlst, gridslice, ngrids)

dRdx_dot_V = f_eps_over_2pi * (dDdx_dot_AV + cupy.einsum('ij,Adj->Adi', D, dAdx_dot_V))
dqdx_fix_Vq += cupy.einsum('ij,Adj->Adi', inverse_K, dRdx_dot_V)
dqdx_fix_Vq += einsum_ij_Adj_Adi_inverseK(K, dRdx_dot_V)

invKT_V = inverse_K.T @ v_grids
invKT_V = cupy.linalg.solve(K.T, v_grids)
dDdxT_dot_invKT_V = get_dDT_dot_q(dD, invKT_V, atmlst, gridslice, ngrids)

DT_invKT_V = D.T @ invKT_V
dAdxT_dot_DT_invKT_V = get_dA_dot_q(dA, DT_invKT_V, atmlst)
dqdx_fix_Vq += f_eps_over_2pi * (cupy.einsum('i,Adi->Adi', A, dDdxT_dot_invKT_V) + dAdxT_dot_DT_invKT_V)

dKdx_dot_invKT_V = dK_dot_q(invKT_V)
dqdx_fix_Vq += -cupy.einsum('ij,Adj->Adi', R.T @ inverse_K.T, dKdx_dot_invKT_V)
invKT_dKdx_dot_invKT_V = einsum_ij_Adj_Adi_inverseK(K.T, dKdx_dot_invKT_V)

dqdx_fix_Vq += -cupy.einsum('ij,Adj->Adi', R.T, invKT_dKdx_dot_invKT_V)

dqdx_fix_Vq *= -0.5

Expand Down Expand Up @@ -798,18 +809,20 @@ def get_dvgrids(pcmobj, dm, atmlst, intopt_derivative):

return dV_on_charge_dx

# NOTE(review): this span is a rendered diff hunk. It appears to interleave the
# pre-change lines (the signature taking `inverse_K` and the statements that use
# it) with their replacements -- confirm against the repository before treating
# it as runnable code.
#
# From the visible statements: contracts 0.5 * (K^-1 R + R^T K^-T) with
# dV_on_charge_dx over its last axis ('ij,Adj->Adi') and returns the result.
def get_dqsym_dx_fix_K_R(pcmobj, dm, atmlst, inverse_K, intopt_derivative):
def get_dqsym_dx_fix_K_R(pcmobj, dm, atmlst, intopt_derivative):
dV_on_charge_dx = get_dvgrids(pcmobj, dm, atmlst, intopt_derivative)
K = pcmobj._intermediates['K']
R = pcmobj._intermediates['R']
# Explicit-inverse form: build the symmetrized matrix first, then apply it.
KR_symmetrized = 0.5 * (inverse_K @ R + R.T @ inverse_K.T)
dqdx_fix_K_R = cupy.einsum('ij,Adj->Adi', KR_symmetrized, dV_on_charge_dx)
# Solve-based form of the same two terms, evaluated right-to-left:
# first term K^-1 (R dV), second term R^T (K^-T dV).
R_dVdx = cupy.einsum('ij,Adj->Adi', R, dV_on_charge_dx)
K_1_R_dVdx = einsum_ij_Adj_Adi_inverseK(K, R_dVdx)
# Passing K.T makes the helper apply the inverse of the transpose.
K_1T_dVdx = einsum_ij_Adj_Adi_inverseK(K.T, dV_on_charge_dx)
RT_K_1T_dVdx = cupy.einsum('ij,Adj->Adi', R.T, K_1T_dVdx)
dqdx_fix_K_R = 0.5 * (K_1_R_dVdx + RT_K_1T_dVdx)

return dqdx_fix_K_R

# Returns the sum of the two contributions computed by
# get_dqsym_dx_fix_vgrids and get_dqsym_dx_fix_K_R.
#
# NOTE(review): this span is a rendered diff hunk. The two `return` lines appear
# to be the pre-change call (passing an explicitly inverted K) and its
# replacement (no `inverse_K` argument) -- confirm against the repository.
def get_dqsym_dx(pcmobj, dm, atmlst, intopt_derivative):
K = pcmobj._intermediates['K']
# Explicit inverse of K; only the first `return` below consumes it.
inverse_K = cupy.linalg.inv(K)
return get_dqsym_dx_fix_vgrids(pcmobj, atmlst, inverse_K) + get_dqsym_dx_fix_K_R(pcmobj, dm, atmlst, inverse_K, intopt_derivative)
return get_dqsym_dx_fix_vgrids(pcmobj, atmlst) + get_dqsym_dx_fix_K_R(pcmobj, dm, atmlst, intopt_derivative)

def analytical_grad_vmat(pcmobj, dm, mo_coeff, mo_occ, atmlst=None, verbose=None):
'''
Expand Down

0 comments on commit fc9a413

Please sign in to comment.