diff --git a/qpth/solvers/pdipm/batch.py b/qpth/solvers/pdipm/batch.py index 4bd5970..b6ae86d 100644 --- a/qpth/solvers/pdipm/batch.py +++ b/qpth/solvers/pdipm/batch.py @@ -291,15 +291,15 @@ def factor_solve_kkt_reg(Q_tilde, D, G, A, rx, rs, rz, ry, eps): H_LU = lu_hack(H_) invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU) - invH_g_ = g_.lu_solve(*H_LU) + invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2) S_ = torch.bmm(A_, invH_A_) S_ -= eps * torch.eye(neq + nineq).type_as(Q_tilde).repeat(nBatch, 1, 1) S_LU = lu_hack(S_) t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_ - w_ = -t_.lu_solve(*S_LU) + w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2) t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze() - v_ = t_.lu_solve(*H_LU) + v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2) dx = v_[:, :nz] ds = v_[:, nz:] @@ -328,14 +328,14 @@ def factor_solve_kkt(Q, D, G, A, rx, rs, rz, ry): H_LU = lu_hack(H_) invH_A_ = A_.transpose(1, 2).lu_solve(*H_LU) - invH_g_ = g_.lu_solve(*H_LU) + invH_g_ = g_.unsqueeze(2).lu_solve(*H_LU).squeeze(2) S_ = torch.bmm(A_, invH_A_) S_LU = lu_hack(S_) t_ = torch.bmm(invH_g_.unsqueeze(1), A_.transpose(1, 2)).squeeze(1) - h_ - w_ = -t_.lu_solve(*S_LU) + w_ = -t_.unsqueeze(2).lu_solve(*S_LU).squeeze(2) t_ = -g_ - w_.unsqueeze(1).bmm(A_).squeeze() - v_ = t_.lu_solve(*H_LU) + v_ = t_.unsqueeze(2).lu_solve(*H_LU).squeeze(2) dx = v_[:, :nz] ds = v_[:, nz:] @@ -349,21 +349,21 @@ def solve_kkt(Q_LU, d, G, A, S_LU, rx, rs, rz, ry): """ Solve KKT equations for the affine step""" nineq, nz, neq, nBatch = get_sizes(G, A) - invQ_rx = rx.lu_solve(*Q_LU) + invQ_rx = rx.unsqueeze(2).lu_solve(*Q_LU).squeeze(2) if neq > 0: h = torch.cat((invQ_rx.unsqueeze(1).bmm(A.transpose(1, 2)).squeeze(1) - ry, invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz), 1) else: h = invQ_rx.unsqueeze(1).bmm(G.transpose(1, 2)).squeeze(1) + rs / d - rz - w = -(h.lu_solve(*S_LU)) + w = -(h.unsqueeze(2).lu_solve(*S_LU)).squeeze(2) g1 = -rx - w[:, neq:].unsqueeze(1).bmm(G).squeeze(1) if neq > 0: g1 -= w[:, :neq].unsqueeze(1).bmm(A).squeeze(1) g2 = -rs - w[:, neq:] - dx = g1.lu_solve(*Q_LU) + dx = g1.unsqueeze(2).lu_solve(*Q_LU).squeeze(2) ds = g2 / d dz = w[:, neq:] dy = w[:, :neq] if neq > 0 else None