Skip to content

Commit

Permalink
Update lu_solve calls.
Browse files Browse the repository at this point in the history
  • Loading branch information
bamos committed Sep 8, 2019
1 parent f4ad1d0 commit bb156fe
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions qpth/solvers/pdipm/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,15 +291,15 @@ def factor_solve_kkt_reg(Q_tilde, D, G, A, rx, rs, rz, ry, eps):
H_LU = lu_hack(H_)

invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)
invH_g_ = g_.lu_solve(*H_LU)
invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

S_ = torch.bmm(A_, invH_A_)
S_ -= eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(nBatch, 1, 1)
S_LU = lu_hack(S_)
t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
w_ = -t_.lu_solve(*S_LU)
w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2)
t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze()
v_ = t_.lu_solve(*H_LU)
v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

dx = v_[:, :nz]
ds = v_[:, nz:]
Expand Down Expand Up @@ -328,14 +328,14 @@ def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry):
H_LU = lu_hack(H_)

invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU)
invH_g_ = g_.lu_solve(*H_LU)
invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

S_ = torch.bmm(A_, invH_A_)
S_LU = lu_hack(S_)
t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_
w_ = -t_.lu_solve(*S_LU)
w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2)
t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze()
v_ = t_.lu_solve(*H_LU)
v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2)

dx = v_[:, :nz]
ds = v_[:, nz:]
Expand All @@ -349,21 +349,21 @@ def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry):
""" Solve KKT equations for the affine step"""
nineq, nz, neq, nBatch = get_sizes(G, A)

invQ_rx = rx.lu_solve(*Q_LU)
invQ_rx = rx.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)
if neq > 0:
h = torch.cat((invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry,
invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz), 1)
else:
h = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz

w = -(h.lu_solve(*S_LU))
w = -(h.unsqueeze(2).lu_solve(*S_LU)).squeeze(2)

g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1)
if neq > 0:
g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1)
g2 = -rs - w[:, neq:]

dx = g1.lu_solve(*Q_LU)
dx = g1.unsqueeze(2).lu_solve(*Q_LU).squeeze(2)
ds = g2 / d
dz = w[:, neq:]
dy = w[:, :neq] if neq > 0 else None
Expand Down

0 comments on commit bb156fe

Please sign in to comment.