From 315d5f9f6f8938e0ceade1cd63ffffe2d5f2bfb6 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Mon, 9 Sep 2019 08:02:54 -0700 Subject: [PATCH] Update to the new Function interface and update some byte->bool indexing --- qpth/qp.py | 332 ++++++++++++++++++------------------ qpth/solvers/pdipm/batch.py | 2 +- qpth/util.py | 2 +- test.py | 6 +- 4 files changed, 169 insertions(+), 173 deletions(-) diff --git a/qpth/qp.py b/qpth/qp.py index 76a4da4..794f63f 100755 --- a/qpth/qp.py +++ b/qpth/qp.py @@ -15,176 +15,172 @@ class QPSolvers(Enum): CVXPY = 2 -class QPFunction(Function): - def __init__(self, eps=1e-12, verbose=0, notImprovedLim=3, +def QPFunction(eps=1e-12, verbose=0, notImprovedLim=3, maxIter=20, solver=QPSolvers.PDIPM_BATCHED, check_Q_spd=True): - self.eps = eps - self.verbose = verbose - self.notImprovedLim = notImprovedLim - self.maxIter = maxIter - self.solver = solver - self.check_Q_spd = check_Q_spd - - def forward(self, Q_, p_, G_, h_, A_, b_): - """Solve a batch of QPs. - - This function solves a batch of QPs, each optimizing over - `nz` variables and having `nineq` inequality constraints - and `neq` equality constraints. - The optimization problem for each instance in the batch - (dropping indexing from the notation) is of the form - - \hat z = argmin_z 1/2 z^T Q z + p^T z - subject to Gz <= h - Az = b - - where Q \in S^{nz,nz}, - S^{nz,nz} is the set of all positive semi-definite matrices, - p \in R^{nz} - G \in R^{nineq,nz} - h \in R^{nineq} - A \in R^{neq,nz} - b \in R^{neq} - - These parameters should all be passed to this function as - Variable- or Parameter-wrapped Tensors. - (See torch.autograd.Variable and torch.nn.parameter.Parameter) - - If you want to solve a batch of QPs where `nz`, `nineq` and `neq` - are the same, but some of the contents differ across the - minibatch, you can pass in tensors in the standard way - where the first dimension indicates the batch example. - This can be done with some or all of the coefficients. - - You do not need to add an extra dimension to coefficients - that will not change across all of the minibatch examples. - This function is able to infer such cases. - - If you don't want to use any equality or inequality constraints, - you can set the appropriate values to: - - e = Variable(torch.Tensor()) - - Parameters: - Q: A (nBatch, nz, nz) or (nz, nz) Tensor. - p: A (nBatch, nz) or (nz) Tensor. - G: A (nBatch, nineq, nz) or (nineq, nz) Tensor. - h: A (nBatch, nineq) or (nineq) Tensor. - A: A (nBatch, neq, nz) or (neq, nz) Tensor. - b: A (nBatch, neq) or (neq) Tensor. - - Returns: \hat z: a (nBatch, nz) Tensor. - """ - nBatch = extract_nBatch(Q_, p_, G_, h_, A_, b_) - Q, _ = expandParam(Q_, nBatch, 3) - p, _ = expandParam(p_, nBatch, 2) - G, _ = expandParam(G_, nBatch, 3) - h, _ = expandParam(h_, nBatch, 2) - A, _ = expandParam(A_, nBatch, 3) - b, _ = expandParam(b_, nBatch, 2) - - if self.check_Q_spd: - for i in range(nBatch): - e, _ = torch.eig(Q[i]) - if not torch.all(e[:,0] > 0): - raise RuntimeError('Q is not SPD.') - - _, nineq, nz = G.size() - neq = A.size(1) if A.nelement() > 0 else 0 - assert(neq > 0 or nineq > 0) - self.neq, self.nineq, self.nz = neq, nineq, nz - - if self.solver == QPSolvers.PDIPM_BATCHED: - self.Q_LU, self.S_LU, self.R = pdipm_b.pre_factor_kkt(Q, G, A) - zhats, self.nus, self.lams, self.slacks = pdipm_b.forward( - Q, p, G, h, A, b, self.Q_LU, self.S_LU, self.R, - self.eps, self.verbose, self.notImprovedLim, self.maxIter) - elif self.solver == QPSolvers.CVXPY: - vals = torch.Tensor(nBatch).type_as(Q) - zhats = torch.Tensor(nBatch, self.nz).type_as(Q) - lams = torch.Tensor(nBatch, self.nineq).type_as(Q) - nus = torch.Tensor(nBatch, self.neq).type_as(Q) \ - if self.neq > 0 else torch.Tensor() - slacks = torch.Tensor(nBatch, self.nineq).type_as(Q) - for i in range(nBatch): - Ai, bi = (A[i], b[i]) if neq > 0 else (None, None) - vals[i], zhati, nui, lami, si = solvers.cvxpy.forward_single_np( - *[x.cpu().numpy() if x is not None else None - for x in (Q[i], p[i], G[i], h[i], Ai, bi)]) - # if zhati[0] is None: - # import IPython, sys; IPython.embed(); sys.exit(-1) - zhats[i] = torch.Tensor(zhati) - lams[i] = torch.Tensor(lami) - slacks[i] = torch.Tensor(si) - if neq > 0: - nus[i] = torch.Tensor(nui) - - self.vals = vals - self.lams = lams - self.nus = nus - self.slacks = slacks - else: - assert False - - self.save_for_backward(zhats, Q_, p_, G_, h_, A_, b_) - return zhats - - def backward(self, dl_dzhat): - zhats, Q, p, G, h, A, b = self.saved_tensors - nBatch = extract_nBatch(Q, p, G, h, A, b) - Q, Q_e = expandParam(Q, nBatch, 3) - p, p_e = expandParam(p, nBatch, 2) - G, G_e = expandParam(G, nBatch, 3) - h, h_e = expandParam(h, nBatch, 2) - A, A_e = expandParam(A, nBatch, 3) - b, b_e = expandParam(b, nBatch, 2) - - # neq, nineq, nz = self.neq, self.nineq, self.nz - neq, nineq = self.neq, self.nineq - - - if self.solver == QPSolvers.CVXPY: - self.Q_LU, self.S_LU, self.R = pdipm_b.pre_factor_kkt(Q, G, A) - - # Clamp here to avoid issues coming up when the slacks are too small. - # TODO: A better fix would be to get lams and slacks from the - # solver that don't have this issue. - d = torch.clamp(self.lams, min=1e-8) / torch.clamp(self.slacks, min=1e-8) - - pdipm_b.factor_kkt(self.S_LU, self.R, d) - dx, _, dlam, dnu = pdipm_b.solve_kkt( - self.Q_LU, d, G, A, self.S_LU, - dl_dzhat, torch.zeros(nBatch, nineq).type_as(G), - torch.zeros(nBatch, nineq).type_as(G), - torch.zeros(nBatch, neq).type_as(G) if neq > 0 else torch.Tensor()) - - dps = dx - dGs = bger(dlam, zhats) + bger(self.lams, dx) - if G_e: - dGs = dGs.mean(0) - dhs = -dlam - if h_e: - dhs = dhs.mean(0) - if neq > 0: - dAs = bger(dnu, zhats) + bger(self.nus, dx) - dbs = -dnu - if A_e: - dAs = dAs.mean(0) - if b_e: - dbs = dbs.mean(0) - else: - dAs, dbs = None, None - dQs = 0.5 * (bger(dx, zhats) + bger(zhats, dx)) - if Q_e: - dQs = dQs.mean(0) - if p_e: - dps = dps.mean(0) - - - grads = (dQs, dps, dGs, dhs, dAs, dbs) - - return grads + class QPFunctionFn(Function): + @staticmethod + def forward(ctx, Q_, p_, G_, h_, A_, b_): + """Solve a batch of QPs. + + This function solves a batch of QPs, each optimizing over + `nz` variables and having `nineq` inequality constraints + and `neq` equality constraints. + The optimization problem for each instance in the batch + (dropping indexing from the notation) is of the form + + \hat z = argmin_z 1/2 z^T Q z + p^T z + subject to Gz <= h + Az = b + + where Q \in S^{nz,nz}, + S^{nz,nz} is the set of all positive semi-definite matrices, + p \in R^{nz} + G \in R^{nineq,nz} + h \in R^{nineq} + A \in R^{neq,nz} + b \in R^{neq} + + These parameters should all be passed to this function as + Variable- or Parameter-wrapped Tensors. + (See torch.autograd.Variable and torch.nn.parameter.Parameter) + + If you want to solve a batch of QPs where `nz`, `nineq` and `neq` + are the same, but some of the contents differ across the + minibatch, you can pass in tensors in the standard way + where the first dimension indicates the batch example. + This can be done with some or all of the coefficients. + + You do not need to add an extra dimension to coefficients + that will not change across all of the minibatch examples. + This function is able to infer such cases. + + If you don't want to use any equality or inequality constraints, + you can set the appropriate values to: + + e = Variable(torch.Tensor()) + + Parameters: + Q: A (nBatch, nz, nz) or (nz, nz) Tensor. + p: A (nBatch, nz) or (nz) Tensor. + G: A (nBatch, nineq, nz) or (nineq, nz) Tensor. + h: A (nBatch, nineq) or (nineq) Tensor. + A: A (nBatch, neq, nz) or (neq, nz) Tensor. + b: A (nBatch, neq) or (neq) Tensor. + + Returns: \hat z: a (nBatch, nz) Tensor. + """ + nBatch = extract_nBatch(Q_, p_, G_, h_, A_, b_) + Q, _ = expandParam(Q_, nBatch, 3) + p, _ = expandParam(p_, nBatch, 2) + G, _ = expandParam(G_, nBatch, 3) + h, _ = expandParam(h_, nBatch, 2) + A, _ = expandParam(A_, nBatch, 3) + b, _ = expandParam(b_, nBatch, 2) + + if check_Q_spd: + for i in range(nBatch): + e, _ = torch.eig(Q[i]) + if not torch.all(e[:,0] > 0): + raise RuntimeError('Q is not SPD.') + + _, nineq, nz = G.size() + neq = A.size(1) if A.nelement() > 0 else 0 + assert(neq > 0 or nineq > 0) + ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz + + if solver == QPSolvers.PDIPM_BATCHED: + ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A) + zhats, ctx.nus, ctx.lams, ctx.slacks = pdipm_b.forward( + Q, p, G, h, A, b, ctx.Q_LU, ctx.S_LU, ctx.R, + eps, verbose, notImprovedLim, maxIter) + elif solver == QPSolvers.CVXPY: + vals = torch.Tensor(nBatch).type_as(Q) + zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q) + lams = torch.Tensor(nBatch, ctx.nineq).type_as(Q) + nus = torch.Tensor(nBatch, ctx.neq).type_as(Q) \ + if ctx.neq > 0 else torch.Tensor() + slacks = torch.Tensor(nBatch, ctx.nineq).type_as(Q) + for i in range(nBatch): + Ai, bi = (A[i], b[i]) if neq > 0 else (None, None) + vals[i], zhati, nui, lami, si = solvers.cvxpy.forward_single_np( + *[x.cpu().numpy() if x is not None else None + for x in (Q[i], p[i], G[i], h[i], Ai, bi)]) + # if zhati[0] is None: + # import IPython, sys; IPython.embed(); sys.exit(-1) + zhats[i] = torch.Tensor(zhati) + lams[i] = torch.Tensor(lami) + slacks[i] = torch.Tensor(si) + if neq > 0: + nus[i] = torch.Tensor(nui) + + ctx.vals = vals + ctx.lams = lams + ctx.nus = nus + ctx.slacks = slacks + else: + assert False + + ctx.save_for_backward(zhats, Q_, p_, G_, h_, A_, b_) + return zhats + + @staticmethod + def backward(ctx, dl_dzhat): + zhats, Q, p, G, h, A, b = ctx.saved_tensors + nBatch = extract_nBatch(Q, p, G, h, A, b) + Q, Q_e = expandParam(Q, nBatch, 3) + p, p_e = expandParam(p, nBatch, 2) + G, G_e = expandParam(G, nBatch, 3) + h, h_e = expandParam(h, nBatch, 2) + A, A_e = expandParam(A, nBatch, 3) + b, b_e = expandParam(b, nBatch, 2) + + # neq, nineq, nz = ctx.neq, ctx.nineq, ctx.nz + neq, nineq = ctx.neq, ctx.nineq + + + if solver == QPSolvers.CVXPY: + ctx.Q_LU, ctx.S_LU, ctx.R = pdipm_b.pre_factor_kkt(Q, G, A) + + # Clamp here to avoid issues coming up when the slacks are too small. + # TODO: A better fix would be to get lams and slacks from the + # solver that don't have this issue. + d = torch.clamp(ctx.lams, min=1e-8) / torch.clamp(ctx.slacks, min=1e-8) + + pdipm_b.factor_kkt(ctx.S_LU, ctx.R, d) + dx, _, dlam, dnu = pdipm_b.solve_kkt( + ctx.Q_LU, d, G, A, ctx.S_LU, + dl_dzhat, torch.zeros(nBatch, nineq).type_as(G), + torch.zeros(nBatch, nineq).type_as(G), + torch.zeros(nBatch, neq).type_as(G) if neq > 0 else torch.Tensor()) + + dps = dx + dGs = bger(dlam, zhats) + bger(ctx.lams, dx) + if G_e: + dGs = dGs.mean(0) + dhs = -dlam + if h_e: + dhs = dhs.mean(0) + if neq > 0: + dAs = bger(dnu, zhats) + bger(ctx.nus, dx) + dbs = -dnu + if A_e: + dAs = dAs.mean(0) + if b_e: + dbs = dbs.mean(0) + else: + dAs, dbs = None, None + dQs = 0.5 * (bger(dx, zhats) + bger(zhats, dx)) + if Q_e: + dQs = dQs.mean(0) + if p_e: + dps = dps.mean(0) + + + grads = (dQs, dps, dGs, dhs, dAs, dbs) + + return grads + return QPFunctionFn.apply class SpQPFunction(Function): diff --git a/qpth/solvers/pdipm/batch.py b/qpth/solvers/pdipm/batch.py index b6ae86d..23abde0 100644 --- a/qpth/solvers/pdipm/batch.py +++ b/qpth/solvers/pdipm/batch.py @@ -437,7 +437,7 @@ def factor_kkt(S_LU, R, d): if factor_kkt_eye is None or factor_kkt_eye.size() != d.size(): # print('Updating batchedEye size.') factor_kkt_eye = torch.eye(nineq).repeat( - nBatch, 1, 1).type_as(R).byte() + nBatch, 1, 1).type_as(R).bool() T = R.clone() T[factor_kkt_eye] += (1. / d).squeeze().view(-1) diff --git a/qpth/util.py b/qpth/util.py index e020248..80050b1 100644 --- a/qpth/util.py +++ b/qpth/util.py @@ -36,7 +36,7 @@ def get_sizes(G, A=None): def bdiag(d): nBatch, sz = d.size() D = torch.zeros(nBatch, sz, sz).type_as(d) - I = torch.eye(sz).repeat(nBatch, 1, 1).type_as(d).byte() + I = torch.eye(sz).repeat(nBatch, 1, 1).type_as(d).bool() D[I] = d.squeeze().view(-1) return D diff --git a/test.py b/test.py index 4b81ec5..be3ec55 100755 --- a/test.py +++ b/test.py @@ -13,7 +13,7 @@ import numpy as np import numpy.random as npr import numpy.testing as npt -from numpy.testing import decorators +from numpy.testing import dec np.set_printoptions(precision=6) import numdifftools as nd @@ -247,7 +247,7 @@ def test_ir_kkt_solver(): npt.assert_allclose(dy.numpy(), dy_.numpy(), rtol=RTOL, atol=ATOL) -@npt.decorators.skipif( +@npt.dec.skipif( not torch.cuda.is_available() or not hasattr(torch, 'spbqrfactsolve')) def test_sparse_forward(): torch.manual_seed(0) @@ -300,7 +300,7 @@ def cast(m): xhats_qpf.cpu().numpy(), rtol=RTOL, atol=ATOL) -@npt.decorators.skipif( +@npt.dec.skipif( not torch.cuda.is_available() or not hasattr(torch, 'spbqrfactsolve')) def test_sparse_backward(): torch.manual_seed(0)