From ca179ab91052c3e33e96f2a45fb11d07ea1c3dc6 Mon Sep 17 00:00:00 2001 From: Armand Jordana Date: Wed, 7 Aug 2024 14:55:04 -0400 Subject: [PATCH] added bindings for dx and du --- bindings/csqp.cpp | 7 ++++--- bindings/sqp.cpp | 3 +++ include/mim_solvers/csqp.hpp | 7 +++++-- include/mim_solvers/sqp.hpp | 3 +++ 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/bindings/csqp.cpp b/bindings/csqp.cpp index 79602a1..14157aa 100644 --- a/bindings/csqp.cpp +++ b/bindings/csqp.cpp @@ -93,9 +93,10 @@ void exposeSolverCSQP() { "Additional iteration if SQP max. iter reached (default: False)") .add_property("xs", make_function(&SolverCSQP::get_xs, bp::return_value_policy()), bp::make_function(&SolverCSQP::set_xs), "xs") .add_property("us", make_function(&SolverCSQP::get_us, bp::return_value_policy()), bp::make_function(&SolverCSQP::set_us), "us") - .add_property("dx_tilde", make_function(&SolverCSQP::get_xs_tilde, bp::return_value_policy()), "dx_tilde") - .add_property("du_tilde", make_function(&SolverCSQP::get_us_tilde, bp::return_value_policy()), "du_tilde") - + .add_property("dx_tilde", make_function(&SolverCSQP::get_dx_tilde, bp::return_value_policy()), "dx_tilde") + .add_property("du_tilde", make_function(&SolverCSQP::get_du_tilde, bp::return_value_policy()), "du_tilde") + .add_property("dx", make_function(&SolverCSQP::get_dx, bp::return_value_policy()), "dx") + .add_property("du", make_function(&SolverCSQP::get_du, bp::return_value_policy()), "du") .add_property("constraint_norm", bp::make_function(&SolverCSQP::get_constraint_norm), diff --git a/bindings/sqp.cpp b/bindings/sqp.cpp index 986888d..aface03 100644 --- a/bindings/sqp.cpp +++ b/bindings/sqp.cpp @@ -53,6 +53,9 @@ void exposeSolverSQP() { .def_readwrite("fs_try", &SolverSQP::fs_try_, "fs_try") .def_readwrite("lag_mul", &SolverSQP::lag_mul_, "lagrange multipliers") + .add_property("dx", make_function(&SolverSQP::get_dx, bp::return_value_policy()), "dx") + .add_property("du", make_function(&SolverSQP::get_du, bp::return_value_policy()), "du") + .add_property("KKT", bp::make_function(&SolverSQP::get_KKT), "KKT residual norm") diff --git a/include/mim_solvers/csqp.hpp b/include/mim_solvers/csqp.hpp index f5f832b..dd94738 100644 --- a/include/mim_solvers/csqp.hpp +++ b/include/mim_solvers/csqp.hpp @@ -118,8 +118,11 @@ class SolverCSQP : public SolverDDP { const std::vector& get_xs() const { return xs_; }; const std::vector& get_us() const { return us_; }; - const std::vector& get_xs_tilde() const { return dxtilde_; }; - const std::vector& get_us_tilde() const { return dutilde_; }; + const std::vector& get_dx_tilde() const { return dxtilde_; }; + const std::vector& get_du_tilde() const { return dutilde_; }; + + const std::vector& get_dx() const { return dx_; }; + const std::vector& get_du() const { return du_; }; const std::vector& get_y() const { return y_; }; const std::vector& get_z() const { return z_; }; diff --git a/include/mim_solvers/sqp.hpp b/include/mim_solvers/sqp.hpp index 6a8e893..81eb423 100644 --- a/include/mim_solvers/sqp.hpp +++ b/include/mim_solvers/sqp.hpp @@ -101,6 +101,9 @@ class SolverSQP : public SolverDDP { const std::vector& get_xs_try() const { return xs_try_; }; const std::vector& get_us_try() const { return us_try_; }; + const std::vector& get_dx() const { return dx_; }; + const std::vector& get_du() const { return du_; }; + double get_KKT() const { return KKT_; }; double get_gap_norm() const { return gap_norm_; }; double get_xgrad_norm() const { return x_grad_norm_; };