diff --git a/include/mujincontrollerclient/mujinjsonmsgpack.h b/include/mujincontrollerclient/mujinjsonmsgpack.h new file mode 100644 index 00000000..3906df71 --- /dev/null +++ b/include/mujincontrollerclient/mujinjsonmsgpack.h @@ -0,0 +1,281 @@ +#ifndef MUJIN_CONTROLLERCLIENT_JSONMSGPACK_H +#define MUJIN_CONTROLLERCLIENT_JSONMSGPACK_H +#include "mujincontrollerclient/mujinjson.h" +#include "msgpack.hpp" +#include "mujinmasterslaveclient.h" + +template +struct MsgpackParser { + /* + * Fast parser for turning msgpack -> json + * without intermediate msgpack_object repr. + */ + + bool visit_nil() + { + _objects.emplace_back(rapidjson::kNullType); + return true; + } + + bool visit_boolean(const bool value) + { + _objects.emplace_back(value); + return true; + } + + bool visit_positive_integer(const uint64_t value) + { + _objects.emplace_back(value); + return true; + } + + bool visit_negative_integer(const int64_t value) + { + _objects.emplace_back(value); + return true; + } + + bool visit_float32(const float value) + { + _objects.emplace_back(value); + return true; + } + + bool visit_float64(const double value) + { + _objects.emplace_back(value); + return true; + } + + bool visit_str(const char* const value, const uint32_t size) + { + _objects.emplace_back(value, size, _allocator); + return true; + } + + bool visit_bin(const char* const value, const uint32_t size) + { + _objects.emplace_back(value, size, _allocator); + return true; + } + + bool visit_ext(const char* value, const uint32_t valueSize) + { + msgpack::object object; + object.type = msgpack::type::EXT; + object.via.ext.ptr = value; + object.via.ext.size = valueSize - 1; + + const std::chrono::system_clock::time_point tp = object.as(); + const std::time_t parsedTime = std::chrono::system_clock::to_time_t(tp); + + // RFC 3339 Nano format + char formatted[sizeof("2006-01-02T15:04:05.999999999Z07:00")]; + + // The extension does not include timezone information. By convention, we format to local time. + tm datetime = {}; + std::size_t size = std::strftime(formatted, sizeof(formatted), "%FT%T", localtime_r(&parsedTime, &datetime)); + + // Add nanoseconds portion if present + const long nanoseconds = (std::chrono::duration_cast(tp.time_since_epoch()).count() % 1000000000 + 1000000000) % 1000000000; + if (nanoseconds != 0) { + size += sprintf(formatted + size, ".%09lu", nanoseconds); + // remove trailing zeros + while (formatted[size - 1] == '0') { + --size; + } + } + if (datetime.tm_gmtoff == 0) { + formatted[size] = 'Z'; + } else { + size += std::strftime(formatted + size, sizeof(formatted) - size, "%z", &datetime); + // fix timezone format (0000 -> 00:00) + formatted[size] = formatted[size - 1]; + formatted[size - 1] = formatted[size - 2]; + formatted[size - 2] = ':'; + } + + _objects.emplace_back(formatted, size + 1, _allocator); + return true; + } + + bool start_array(const uint32_t size) + { + _objects.emplace_back(rapidjson::kArrayType); + _objects.back().Reserve(size, _allocator); + return true; + } + + static bool start_array_item() + { + return true; + } + + bool end_array_item() + { + rapidjson::Value top = std::move(_objects.back()); + _objects.pop_back(); + + _objects.back().PushBack(top, _allocator); + return true; + } + + static bool end_array() + { + return true; + } + + bool start_map(const uint32_t size) + { + _objects.emplace_back(rapidjson::kObjectType); + _objects.back().MemberReserve(size, _allocator); + return true; + } + + static bool start_map_key() + { + return true; + } + + static bool end_map_key() + { + return true; + } + + static bool start_map_value() + { + return true; + } + + bool end_map_value() + { + rapidjson::Value value = std::move(_objects.back()); + _objects.pop_back(); + + rapidjson::Value key = std::move(_objects.back()); + _objects.pop_back(); + + _objects.back().AddMember(key, value, _allocator); + return true; + } + + static bool end_map() + { + return true; + } + + rapidjson::Value Extract() + { + if (_objects.size() != 1) { + throw msgpack::parse_error("parse error"); + } + rapidjson::Value result = std::move(_objects.back()); + _objects.pop_back(); + return result; + } + + static void parse_error(size_t /*parsed_offset*/, size_t /*error_offset*/) + { + throw msgpack::parse_error("parse error"); + } + + static void insufficient_bytes(size_t /*parsed_offset*/, size_t /*error_offset*/) + { + throw msgpack::insufficient_bytes("insufficient bytes"); + } + + explicit MsgpackParser(Allocator& allocator): _allocator(allocator) + { + } + +private: + std::vector > _objects; + Allocator& _allocator; +}; + +namespace mujinmasterslaveclient { +template +struct MessageParser > : MsgpackParser { + explicit MessageParser(rapidjson::GenericDocument document = {}): + MsgpackParser(document.GetAllocator()), + _document(std::move(document)) + { + } + + rapidjson::GenericDocument Extract() + { + MsgpackParser::Extract().Swap(_document); + return std::move(_document); + } + +private: + rapidjson::GenericDocument _document; +}; +} + +using GenericMsgpackParser = MsgpackParser; + +namespace msgpack { +MSGPACK_API_VERSION_NAMESPACE(MSGPACK_DEFAULT_API_NS) { +namespace adaptor { +template +struct pack > { + template + packer& operator()(packer& o, rapidjson::GenericValue const& v) const + { + switch (v.GetType()) { + case rapidjson::kNullType: + return o.pack_nil(); + case rapidjson::kFalseType: + return o.pack_false(); + case rapidjson::kTrueType: + return o.pack_true(); + case rapidjson::kObjectType: { + o.pack_map(v.MemberCount()); + typename rapidjson::GenericValue::ConstMemberIterator i = v.MemberBegin(), END = v.MemberEnd(); + for (; i != END; ++i) { + o.pack(i->name); + o.pack(i->value); + } + return o; + } + case rapidjson::kArrayType: { + o.pack_array(v.Size()); + typename rapidjson::GenericValue::ConstValueIterator i = v.Begin(), END = v.End(); + for (; i < END; ++i) { + o.pack(*i); + } + return o; + } + case rapidjson::kStringType: + return o.pack_str(v.GetStringLength()).pack_str_body(v.GetString(), v.GetStringLength()); + case rapidjson::kNumberType: + if (v.IsInt()) + return o.pack_int(v.GetInt()); + if (v.IsUint()) + return o.pack_unsigned_int(v.GetUint()); + if (v.IsInt64()) + return o.pack_int64(v.GetInt64()); + if (v.IsUint64()) + return o.pack_uint64(v.GetUint64()); + if (v.IsDouble()) + return o.pack_double(v.GetDouble()); + default: + return o; + } + } +}; + +template +struct pack > { + template + packer& operator()(packer& o, rapidjson::GenericDocument const& v) const + { + return o.pack(static_cast&>(v)); + } +}; +} +} +} + +#endif //MUJIN_CONTROLLERCLIENT_JSONMSGPACK_H diff --git a/include/mujincontrollerclient/mujinmasterslaveclient.h b/include/mujincontrollerclient/mujinmasterslaveclient.h new file mode 100644 index 00000000..eebafb04 --- /dev/null +++ b/include/mujincontrollerclient/mujinmasterslaveclient.h @@ -0,0 +1,103 @@ +#ifndef MUJIN_MASTERSLAVECLIENT_REQUESTSOCKET_H +#define MUJIN_MASTERSLAVECLIENT_REQUESTSOCKET_H +#include +#include +#include + +namespace mujinmasterslaveclient { +template +struct MessageParser; + +template +static zmq::message_t EncodeToFrame(const ValueType& value) +{ + msgpack::sbuffer buffer; + msgpack::pack(buffer, value); + return {buffer.data(), buffer.size()}; +} + +template +static std::vector EncodeToMessage(const ValueType& value) +{ + std::vector messages; + messages.emplace_back(EncodeToFrame(value)); + return messages; +} + +template +static ValueType DecodeFromFrame(const zmq::message_t& frame) +{ + MessageParser parser; + if (!msgpack::parse(frame.data(), frame.size(), parser)) { + throw std::invalid_argument("unable to parse"); + } + return parser.Extract(); +} + +struct RequestSocket : private zmq::socket_t { + RequestSocket(zmq::context_t& context, const std::string& address); + + void SendNoWait(std::vector&& messages); + + std::vector ReceiveNoWait(); + + [[nodiscard]] bool Poll(short events, std::chrono::milliseconds timeout); + + std::vector SendAndReceive(std::vector&& messages, std::chrono::milliseconds timeout); +}; + +template +static OutputType SendAndReceive(RequestSocket& socket, const InputType& master, const std::chrono::milliseconds timeout) +{ + std::vector frames; + frames.emplace_back(EncodeToFrame(master)); + const std::vector response = socket.SendAndReceive(std::move(frames), timeout); + if (response.size() != 1) { + throw mujinclient::MujinException("unexpected server response protocol", mujinclient::MEC_InvalidState); + } + return DecodeFromFrame(response.front()); +} + +template +static OutputType SendAndReceive(RequestSocket& socket, const InputType& master, const InputType& slave, const std::chrono::milliseconds timeout) +{ + std::vector frames; + frames.emplace_back(EncodeToFrame(master)); + frames.emplace_back(EncodeToFrame(slave)); + const std::vector response = socket.SendAndReceive(std::move(frames), timeout); + if (response.size() != 1) { + throw mujinclient::MujinException("unexpected server response protocol", mujinclient::MEC_InvalidState); + } + return DecodeFromFrame(response.front()); +} + +template +static bool SendAndReceive(RequestSocket& socket, const InputType& master, Parser& parser, const std::chrono::milliseconds timeout) +{ + std::vector frames; + frames.emplace_back(EncodeToFrame(master)); + const std::vector response = socket.SendAndReceive(std::move(frames), timeout); + if (response.size() != 1) { + throw mujinclient::MujinException("unexpected server response protocol", mujinclient::MEC_InvalidState); + } + const zmq::message_t& frame = response.front(); + return msgpack::parse(frame.data(), frame.size(), parser); +} + + +template +static bool SendAndReceive(RequestSocket& socket, const InputType& master, const InputType& slave, Parser& parser, const std::chrono::milliseconds timeout) +{ + std::vector frames; + frames.emplace_back(EncodeToFrame(master)); + frames.emplace_back(EncodeToFrame(slave)); + const std::vector response = socket.SendAndReceive(std::move(frames), timeout); + if (response.size() != 1) { + throw mujinclient::MujinException("unexpected server response protocol", mujinclient::MEC_InvalidState); + } + const zmq::message_t& frame = response.front(); + return msgpack::parse(frame.data(), frame.size(), parser); +} +} + +#endif //MUJIN_MASTERSLAVECLIENT_REQUESTSOCKET_H diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d975da3f..7fa2cce4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -29,6 +29,7 @@ set(SOURCE_FILES mujincontrollerclient.cpp mujindefinitions.cpp mujinjson.cpp + mujinmasterslaveclient.cpp utf8.h ) if (libzmq_FOUND) diff --git a/src/binpickingtask.cpp b/src/binpickingtask.cpp index b20f4d54..b3a145e1 100644 --- a/src/binpickingtask.cpp +++ b/src/binpickingtask.cpp @@ -18,6 +18,8 @@ #endif #include // for sleep #include "mujincontrollerclient/binpickingtask.h" +#include "mujincontrollerclient/mujinjsonmsgpack.h" +#include "mujincontrollerclient/mujinmasterslaveclient.h" #ifdef MUJIN_USEZMQ #include "mujincontrollerclient/zmq.hpp" @@ -563,11 +565,8 @@ BinPickingTaskResource::ResultGetJointValues::~ResultGetJointValues() { } -void BinPickingTaskResource::ResultGetJointValues::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultGetJointValues::Parse(const rapidjson::Value& v) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& v = pt["output"]; - LoadJsonValueByKey(v, "robottype", robottype); LoadJsonValueByKey(v, "jointnames", jointnames); LoadJsonValueByKey(v, "currentjointvalues", currentjointvalues); @@ -587,10 +586,8 @@ BinPickingTaskResource::ResultMoveJoints::~ResultMoveJoints() { } -void BinPickingTaskResource::ResultMoveJoints::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultMoveJoints::Parse(const rapidjson::Value& v) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& v = pt["output"]; LoadJsonValueByKey(v, "robottype", robottype); LoadJsonValueByKey(v, "timedjointvalues", timedjointvalues); LoadJsonValueByKey(v, "numpoints", numpoints); @@ -600,11 +597,8 @@ BinPickingTaskResource::ResultTransform::~ResultTransform() { } -void BinPickingTaskResource::ResultTransform::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultTransform::Parse(const rapidjson::Value& v) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& v = pt["output"]; - LoadJsonValueByKey(v, "translation", transform.translate); LoadJsonValueByKey(v, "quaternion", transform.quaternion); } @@ -613,11 +607,8 @@ BinPickingTaskResource::ResultInstObjectInfo::~ResultInstObjectInfo() { } -void BinPickingTaskResource::ResultInstObjectInfo::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultInstObjectInfo::Parse(const rapidjson::Value& rOutput) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& rOutput = pt["output"]; - LoadJsonValueByKey(rOutput, "translation", instobjecttransform.translate); LoadJsonValueByKey(rOutput, "quaternion", instobjecttransform.quaternion); instobjectobb.Parse(rOutput["obb"]); @@ -636,11 +627,8 @@ BinPickingTaskResource::ResultGetInstObjectAndSensorInfo::~ResultGetInstObjectAn { } -void BinPickingTaskResource::ResultGetInstObjectAndSensorInfo::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultGetInstObjectAndSensorInfo::Parse(const rapidjson::Value& output) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& output = pt["output"]; - mrGeometryInfos.clear(); const rapidjson::Value& instobjects = output["instobjects"]; @@ -727,11 +715,8 @@ BinPickingTaskResource::ResultGetBinpickingState::~ResultGetBinpickingState() { } -void BinPickingTaskResource::ResultGetBinpickingState::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultGetBinpickingState::Parse(const rapidjson::Value& v) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& v = pt["output"]; - { rUnitInfo.SetNull(); rUnitInfo.GetAllocator().Clear(); @@ -843,20 +828,16 @@ BinPickingTaskResource::ResultIsRobotOccludingBody::~ResultIsRobotOccludingBody( { } -void BinPickingTaskResource::ResultIsRobotOccludingBody::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultIsRobotOccludingBody::Parse(const rapidjson::Value& v) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& v = pt["output"]; if (!v.IsObject() || !v.HasMember("occluded")) { throw MujinException("Output does not have \"occluded\" attribute!", MEC_Failed); } result = GetJsonValueByKey(v, "occluded", 1) == 1; } -void BinPickingTaskResource::ResultGetPickedPositions::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultGetPickedPositions::Parse(const rapidjson::Value& v) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& v = pt["output"]; for (rapidjson::Document::ConstMemberIterator value = v.MemberBegin(); value != v.MemberEnd(); ++value) { if (std::string(value->name.GetString()) == "positions" && value->value.IsArray()) { for (rapidjson::Document::ConstValueIterator it = value->value.Begin(); it != value->value.End(); ++it) { @@ -876,10 +857,8 @@ void BinPickingTaskResource::ResultGetPickedPositions::Parse(const rapidjson::Va } } -void BinPickingTaskResource::ResultAABB::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultAABB::Parse(const rapidjson::Value& v) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& v = pt["output"]; LoadJsonValueByKey(v, "pos", pos); LoadJsonValueByKey(v, "extents", extents); if (pos.size() != 3) { @@ -890,19 +869,15 @@ void BinPickingTaskResource::ResultAABB::Parse(const rapidjson::Value& pt) } } -void BinPickingTaskResource::ResultOBB::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultOBB::Parse(const rapidjson::Value& v) { - const rapidjson::Value& v = (pt.IsObject()&&pt.HasMember("output") ? pt["output"] : pt); - LoadJsonValueByKey(v, "translation", translation); LoadJsonValueByKey(v, "extents", extents); LoadJsonValueByKey(v, "quaternion", quaternion); } -void BinPickingTaskResource::ResultComputeIkParamPosition::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultComputeIkParamPosition::Parse(const rapidjson::Value& v) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& v = pt["output"]; LoadJsonValueByKey(v, "translation", translation); LoadJsonValueByKey(v, "quaternion", quaternion); LoadJsonValueByKey(v, "direction", direction); @@ -924,10 +899,8 @@ void BinPickingTaskResource::ResultComputeIkParamPosition::Parse(const rapidjson } } -void BinPickingTaskResource::ResultComputeIKFromParameters::Parse(const rapidjson::Value& pt) +void BinPickingTaskResource::ResultComputeIKFromParameters::Parse(const rapidjson::Value& v_) { - BOOST_ASSERT(pt.IsObject() && pt.HasMember("output")); - const rapidjson::Value& v_ = pt["output"]; BOOST_ASSERT(v_.IsObject() && v_.HasMember("solutions")); const rapidjson::Value& v = v_["solutions"]; for (rapidjson::Document::ConstValueIterator it = v.Begin(); it != v.End(); ++it) { @@ -1045,12 +1018,12 @@ GenerateMoveToolByIkParamCommand(const std::string &movetype, const std::string } -void SetTrajectory(const rapidjson::Value &pt, +void SetTrajectory(const rapidjson::Value &output, std::string *pTraj) { - if (!(pt.IsObject() && pt.HasMember("output") && pt["output"].HasMember("trajectory"))) { + if (!(output.IsObject() && output.HasMember("trajectory"))) { throw MujinException("trajectory is not available in output", MEC_Failed); } - *pTraj = GetJsonValueByPath(pt, "/output/trajectory"); + *pTraj = GetJsonValueByKey(output, "trajectory"); } } @@ -1717,10 +1690,8 @@ void BinPickingTaskResource::GetGrabbed(std::vector& grabbed, const if (!robotname.empty()) { SetJsonValueByKey(pt, "robotname", robotname); } - rapidjson::Document d; - ExecuteCommand(DumpJson(pt), d, timeout); // need to check return code - BOOST_ASSERT(d.IsObject() && d.HasMember("output")); - const rapidjson::Value& v = d["output"]; + rapidjson::Document v; + ExecuteCommand(DumpJson(pt), v, timeout); // need to check return code if(v.HasMember("names") && !v["names"].IsNull()) { LoadJsonValueByKey(v, "names", grabbed); } @@ -1775,10 +1746,8 @@ void BinPickingTaskResource::GetRobotBridgeIOVariableString(const std::vector::digits10+1); @@ -1964,21 +1935,71 @@ std::string utils::GetHeartbeat(const std::string& endpoint) { zmq::context_t zmqcontext(1); zmq::socket_t socket(zmqcontext, ZMQ_SUB); socket.connect(endpoint.c_str()); - socket.setsockopt(ZMQ_SUBSCRIBE, "", 0); + socket.set(zmq::sockopt::subscribe, "m"); - zmq::pollitem_t pollitem; - memset(&pollitem, 0, sizeof(zmq::pollitem_t)); - pollitem.socket = socket; - pollitem.events = ZMQ_POLLIN; - - zmq::poll(&pollitem,1, 50); // wait 50 ms for message + zmq::pollitem_t pollitem = { + .socket = socket, + .events = ZMQ_POLLIN, + }; + zmq::poll(&pollitem, 1, 50); // wait 50 ms for message if (!(pollitem.revents & ZMQ_POLLIN)) { return ""; } zmq::message_t reply; - socket.recv(&reply); - const std::string received((char *)reply.data (), (size_t)reply.size()); + socket.recv(reply); + BOOST_ASSERT(reply.more() && reply.to_string() == "m"); + + socket.recv(reply); + socket.set(zmq::sockopt::unsubscribe, "m"); + + // FIXME: for backward compatibility, we reconstruct the old format + rapidjson::Document document = mujinmasterslaveclient::DecodeFromFrame(reply); + const rapidjson::Document::ConstMemberIterator slaves = document.FindMember("slaves"); + size_t numSlaves = 0; + if (slaves != document.MemberEnd()) { + for (rapidjson::Document::ConstValueIterator iterator = slaves->value.Begin(); iterator != slaves->value.End(); ++iterator) { + std::string topic("s"); + topic += iterator->GetString(); + socket.set(zmq::sockopt::subscribe, topic); + ++numSlaves; + } + } + document.AddMember("slavestates", rapidjson::Value(rapidjson::kObjectType), document.GetAllocator()); + rapidjson::Value &slaveStates = document["slavestates"]; + + while (numSlaves > 0) { + zmq::poll(&pollitem, 1, 50); // wait 50 ms for message + if (!(pollitem.revents & ZMQ_POLLIN)) { + break; + } + + socket.recv(reply); + BOOST_ASSERT(reply.more()); + const std::string topic = reply.to_string(); + + socket.recv(reply); + socket.set(zmq::sockopt::unsubscribe, topic); + + GenericMsgpackParser parser(document.GetAllocator()); + if (!msgpack::parse(reply.data(), reply.size(), parser)) { + throw std::invalid_argument("unable to parse"); + } + + if (topic.empty() || topic[0] != 's') { + continue; + } + + const std::string slaveRequestId = "slaverequestid-" + topic.substr(1); + if (slaveStates.FindMember(rapidjson::StringRef(slaveRequestId.data(), slaveRequestId.size())) == slaveStates.MemberEnd()) { + --numSlaves; + slaveStates.AddMember( + rapidjson::Value(slaveRequestId.data(), slaveRequestId.size(), document.GetAllocator()), + parser.Extract(), document.GetAllocator()); + } + } + + const std::string received = DumpJson(document); #ifndef _WIN32 return received; #else diff --git a/src/binpickingtaskzmq.cpp b/src/binpickingtaskzmq.cpp index 2c97fb6b..6d9bfdf2 100644 --- a/src/binpickingtaskzmq.cpp +++ b/src/binpickingtaskzmq.cpp @@ -15,12 +15,12 @@ #include "common.h" #include "controllerclientimpl.h" #include "binpickingtaskzmq.h" -#include "mujincontrollerclient/mujinzmq.h" - #include // find +#include #include "logging.h" -#include "mujincontrollerclient/mujinjson.h" +#include "mujincontrollerclient/mujinjsonmsgpack.h" +#include "mujincontrollerclient/mujinmasterslaveclient.h" MUJIN_LOGGER("mujin.controllerclientcpp.binpickingtask.zmq"); @@ -30,26 +30,13 @@ namespace mujinclient { using namespace utils; using namespace mujinjson; -class ZmqMujinControllerClient : public mujinzmq::ZmqClient +class ZmqMujinControllerClient : public mujinmasterslaveclient::RequestSocket { public: - ZmqMujinControllerClient(boost::shared_ptr context, const std::string& host, const int port); - - virtual ~ZmqMujinControllerClient(); - + ZmqMujinControllerClient(zmq::context_t &context, const std::string& host, const int port): RequestSocket(context, (boost::format("tcp://%s:%d") % host % port).str()) {} }; -ZmqMujinControllerClient::ZmqMujinControllerClient(boost::shared_ptr context, const std::string& host, const int port) : ZmqClient(host, port) -{ - _InitializeSocket(context); -} - -ZmqMujinControllerClient::~ZmqMujinControllerClient() -{ - // _DestroySocket() is called in ~ZmqClient() -} - BinPickingTaskZmqResource::BinPickingTaskZmqResource(ControllerClientPtr c, const std::string& pk, const std::string& scenepk, const std::string& tasktype) : BinPickingTaskResource(c, pk, scenepk, tasktype) { _callerid = str(boost::format("controllerclientcpp%s_zmq")%MUJINCLIENT_VERSION_STRING); @@ -61,16 +48,14 @@ BinPickingTaskZmqResource::~BinPickingTaskZmqResource() void BinPickingTaskZmqResource::Initialize(const std::string& defaultTaskParameters, const int zmqPort, const int heartbeatPort, boost::shared_ptr zmqcontext, const bool initializezmq, const double reinitializetimeout, const double timeout, const std::string& userinfo, const std::string& slaverequestid) { + _zmqmujincontrollerclient.reset(); BinPickingTaskResource::Initialize(defaultTaskParameters, zmqPort, heartbeatPort, zmqcontext, initializezmq, reinitializetimeout, timeout, userinfo, slaverequestid); if (initializezmq) { InitializeZMQ(reinitializetimeout, timeout); } - _zmqmujincontrollerclient.reset(new ZmqMujinControllerClient(_zmqcontext, _mujinControllerIp, _zmqPort)); - if (!_zmqmujincontrollerclient) { - throw MujinException(boost::str(boost::format("Failed to establish ZMQ connection to mujin controller at %s:%d")%_mujinControllerIp%_zmqPort), MEC_Failed); - } + _zmqmujincontrollerclient = std::make_unique(*_zmqcontext, _mujinControllerIp, _zmqPort); if (!_pHeartbeatMonitorThread) { _bShutdownHeartbeatMonitor = false; if (reinitializetimeout > 0 ) { @@ -91,35 +76,9 @@ void _LogTaskParametersAndThrow(const std::string& taskparameters) { void BinPickingTaskZmqResource::ExecuteCommand(const std::string& taskparameters, rapidjson::Document &pt, const double timeout /* [sec] */, const bool getresult) { - std::stringstream ss; ss << std::setprecision(std::numeric_limits::digits10+1); - ss << "{\"fnname\": \""; - ss << (_tasktype == "binpicking" ? "binpicking.RunCommand\", " : "RunCommand\", "); - - ss << "\"stamp\": " << (GetMilliTime()*1e-3) << ", "; - ss << "\"callerid\": \"" << _GetCallerId() << "\", "; - ss << "\"taskparams\": {\"tasktype\": \"" << _tasktype << "\", "; - - ss << "\"taskparameters\": " << taskparameters << ", "; - ss << "\"sceneparams\": " << _sceneparams_json << "}, "; - ss << "\"userinfo\": " << _userinfo_json; - if (_slaverequestid != "") { - ss << ", " << GetJsonString("slaverequestid", _slaverequestid); - } - ss << "}"; - std::string result_ss; - - try{ - _ExecuteCommandZMQ(ss.str(), pt, timeout, getresult); - } - catch (const MujinException& e) { - MUJIN_LOG_ERROR(e.what()); - if (e.GetCode() == MEC_Timeout) { - _LogTaskParametersAndThrow(taskparameters); - } - else { - throw; - } - } + rapidjson::Document parsing; + ParseJson(parsing, taskparameters.data(), taskparameters.size()); + return ExecuteCommand(parsing, pt, timeout); } void BinPickingTaskZmqResource::ExecuteCommand(rapidjson::Value& rTaskParameters, rapidjson::Document& rOutput, const double timeout) @@ -148,78 +107,36 @@ void BinPickingTaskZmqResource::ExecuteCommand(rapidjson::Value& rTaskParameters rCommand.AddMember(rapidjson::Document::StringRefType("userinfo"), rUserInfo, rCommand.GetAllocator()); } + rapidjson::Value rServerCommand(rapidjson::kObjectType); if (!_slaverequestid.empty()) { - mujinjson::SetJsonValueByKey(rCommand, "slaverequestid", _slaverequestid); - } - - try { - _ExecuteCommandZMQ(mujinjson::DumpJson(rCommand), rOutput, timeout); - } - catch (const MujinException& e) { - MUJIN_LOG_ERROR(e.what()); - if (e.GetCode() == MEC_Timeout) { - _LogTaskParametersAndThrow(mujinjson::DumpJson(rCommand["taskparams"])); - } - else { - throw; - } + rServerCommand.AddMember("slaverequestid", rapidjson::Document::StringRefType(_slaverequestid.data(), _slaverequestid.size()), rCommand.GetAllocator()); } -} -void BinPickingTaskZmqResource::_ExecuteCommandZMQ(const std::string& command, rapidjson::Document& rOutput, const double timeout, const bool getresult) -{ if (!_bIsInitialized) { throw MujinException("BinPicking task is not initialized, please call Initialzie() first.", MEC_Failed); } - if (!_zmqmujincontrollerclient) { + if (_zmqmujincontrollerclient == nullptr) { MUJIN_LOG_ERROR("zmqcontrollerclient is not initialized! initialize"); - _zmqmujincontrollerclient.reset(new ZmqMujinControllerClient(_zmqcontext, _mujinControllerIp, _zmqPort)); + _zmqmujincontrollerclient = std::make_unique(*_zmqcontext, _mujinControllerIp, _zmqPort); } - std::string result_ss; + GenericMsgpackParser parser(rOutput.GetAllocator()); try { - result_ss = _zmqmujincontrollerclient->Call(command, timeout); + if (!mujinmasterslaveclient::SendAndReceive(*_zmqmujincontrollerclient, rServerCommand, rCommand, parser, std::chrono::duration_cast(std::chrono::duration(timeout)))) { + throw MujinException("Cannot deserialize response", MEC_InvalidState); + } } catch (const MujinException& e) { MUJIN_LOG_ERROR(e.what()); - if (e.GetCode() == MEC_ZMQNoResponse) { - MUJIN_LOG_INFO("reinitializing zmq connection with the slave"); - _zmqmujincontrollerclient.reset(new ZmqMujinControllerClient(_zmqcontext, _mujinControllerIp, _zmqPort)); - if (!_zmqmujincontrollerclient) { - throw MujinException(boost::str(boost::format("Failed to establish ZMQ connection to mujin controller at %s:%d")%_mujinControllerIp%_zmqPort), MEC_Failed); - } + if (e.GetCode() == MEC_Timeout) { + _LogTaskParametersAndThrow(mujinjson::DumpJson(rCommand["taskparams"])); } else { throw; } } - - try { - ParseJson(rOutput, result_ss); - } - catch(const std::exception& ex) { - MUJIN_LOG_ERROR(str(boost::format("Could not parse result %s")%result_ss)); - throw; - } - if( rOutput.IsObject() && rOutput.HasMember("error")) { - std::string error = GetJsonValueByKey(rOutput["error"], "errorcode"); - std::string description = GetJsonValueByKey(rOutput["error"], "description"); - if ( error.size() > 0 ) { - std::string serror; - if ( description.size() > 0 ) { - serror = description; - } - else { - serror = error; - } - if( serror.size() > 1000 ) { - MUJIN_LOG_ERROR(str(boost::format("truncated original error message from %d")%serror.size())); - serror = serror.substr(0,1000); - } - throw MujinException(str(boost::format("Error when calling binpicking.RunCommand: %s")%serror), MEC_BinPickingError); - } - } + parser.Extract().Swap(rOutput); } void BinPickingTaskZmqResource::InitializeZMQ(const double reinitializetimeout, const double timeout) @@ -230,8 +147,9 @@ void BinPickingTaskZmqResource::_HeartbeatMonitorThread(const double reinitializ { MUJIN_LOG_DEBUG(str(boost::format("starting controller %s monitoring thread on port %d for slaverequestid=%s.")%_mujinControllerIp%_heartbeatPort%_slaverequestid)); boost::shared_ptr socket; - BinPickingTaskResource::ResultHeartBeat heartbeat; - heartbeat._slaverequestid = _slaverequestid; + ResultGetBinpickingState taskstate; + std::vector buffer(1024 * 100); + rapidjson::Document::AllocatorType allocator(buffer.data(), buffer.size()); while (!_bShutdownHeartbeatMonitor) { if (!!socket) { socket->close(); @@ -245,7 +163,7 @@ void BinPickingTaskZmqResource::_HeartbeatMonitorThread(const double reinitializ std::stringstream ss; ss << std::setprecision(std::numeric_limits::digits10+1); ss << _heartbeatPort; socket->connect (("tcp://"+ _mujinControllerIp+":"+ss.str()).c_str()); - socket->setsockopt(ZMQ_SUBSCRIBE, "", 0); + socket->set(zmq::sockopt::subscribe, "s"+_slaverequestid); zmq::pollitem_t pollitem; memset(&pollitem, 0, sizeof(zmq::pollitem_t)); @@ -256,29 +174,41 @@ void BinPickingTaskZmqResource::_HeartbeatMonitorThread(const double reinitializ zmq::poll(&pollitem,1, 50); // wait 50 ms for message if (pollitem.revents & ZMQ_POLLIN) { zmq::message_t reply; - socket->recv(&reply); - std::string replystring((char *)reply.data (), (size_t)reply.size()); - rapidjson::Document pt(rapidjson::kObjectType); - try{ - std::stringstream replystring_ss(replystring); - ParseJson(pt, replystring_ss.str()); - heartbeat.Parse(pt); - { - boost::mutex::scoped_lock lock(_mutexTaskState); - _taskstate = heartbeat.taskstate; - } - //BINPICKING_LOG_ERROR(replystring); + socket->recv(reply); + if (!reply.more()) { + MUJIN_LOG_ERROR("unknown protocol"); + continue; + } + socket->recv(reply); + if (reply.more()) { + MUJIN_LOG_ERROR("unknown protocol"); + continue; + } - if (heartbeat.status != std::string("lost") && heartbeat.status.size() > 1) { - lastheartbeat = GetMilliTime(); + allocator.Clear(); + GenericMsgpackParser parser(allocator); + try { + if (!msgpack::parse(reply.data(), reply.size(), parser)) { + throw std::runtime_error("unable to parse"); + } + const rapidjson::Value pt = parser.Extract(); + const rapidjson::Value::ConstMemberIterator iterator = pt.FindMember("taskstate"); + if (iterator != pt.MemberEnd()) { + taskstate.Parse(iterator->value); } } catch (std::exception const &e) { - MUJIN_LOG_ERROR("HeartBeat reply is not JSON"); - MUJIN_LOG_ERROR(replystring); + MUJIN_LOG_ERROR("HeartBeat reply is not expected"); MUJIN_LOG_ERROR(e.what()); continue; } + { + boost::mutex::scoped_lock lock(_mutexTaskState); + _taskstate = taskstate; + } + //BINPICKING_LOG_ERROR(replystring); + + lastheartbeat = GetMilliTime(); } } if (!_bShutdownHeartbeatMonitor) { diff --git a/src/binpickingtaskzmq.h b/src/binpickingtaskzmq.h index 4a22ff90..835a957e 100644 --- a/src/binpickingtaskzmq.h +++ b/src/binpickingtaskzmq.h @@ -38,7 +38,6 @@ class MUJINCLIENT_API BinPickingTaskZmqResource : public BinPickingTaskResource void ExecuteCommand(const std::string& taskparameters, rapidjson::Document &pt, const double timeout /* [sec] */=0.0, const bool getresult=true) override; virtual void ExecuteCommand(rapidjson::Value& rTaskParameters, rapidjson::Document& rOutput, const double timeout /* second */=5.0) override; - void _ExecuteCommandZMQ(const std::string& command, rapidjson::Document& rOutput, const double timeout /* second */=5.0, const bool getresult=true); void Initialize(const std::string& defaultTaskParameters, const int zmqPort, const int heartbeatPort, boost::shared_ptr zmqcontext, const bool initializezmq=false, const double reinitializetimeout=10, const double timeout=0, const std::string& userinfo="{}", const std::string& slaverequestid="") override; @@ -46,7 +45,7 @@ class MUJINCLIENT_API BinPickingTaskZmqResource : public BinPickingTaskResource void _HeartbeatMonitorThread(const double reinitializetimeout, const double commandtimeout); private: - ZmqMujinControllerClientPtr _zmqmujincontrollerclient; + std::unique_ptr _zmqmujincontrollerclient; }; } // namespace mujinclient diff --git a/src/mujinmasterslaveclient.cpp b/src/mujinmasterslaveclient.cpp new file mode 100644 index 00000000..313c5a4e --- /dev/null +++ b/src/mujinmasterslaveclient.cpp @@ -0,0 +1,104 @@ +#include "mujincontrollerclient/mujinmasterslaveclient.h" +#include + +namespace mujinmasterslaveclient { +RequestSocket::RequestSocket(zmq::context_t& context, const std::string& address): socket_t(context, zmq::socket_type::req) +{ + // turn on tcp keepalive, do these configuration before bind + set(zmq::sockopt::tcp_keepalive, 1); + + // the interval between the last data packet sent (simple ACKs are not considered data) and the first + // keepalive probe; after the connection is marked to need keepalive, this counter is not used any further + set(zmq::sockopt::tcp_keepalive_idle, 2); + + // the interval between subsequent keepalive probes, regardless of what the connection + // has exchanged in the meantime + set(zmq::sockopt::tcp_keepalive_intvl, 2); + + // the number of unacknowledged probes to send before considering the connection dead + // and notifying the application layer + set(zmq::sockopt::tcp_keepalive_cnt, 2); + + connect(address); +} + +void RequestSocket::SendNoWait(std::vector&& messages) +{ + for (std::vector::iterator iterator = messages.begin(); iterator != messages.end(); ++iterator) { + if (!send(*iterator, iterator + 1 == messages.end() ? zmq::send_flags::dontwait : (zmq::send_flags::sndmore | zmq::send_flags::dontwait)).has_value()) { + throw mujinclient::MujinException("unable to send zmq message", mujinclient::MEC_InvalidState); + } + } +} + +std::vector RequestSocket::ReceiveNoWait() +{ + std::vector frames; + bool hasMore = true; + while (hasMore) { + frames.emplace_back(); + zmq::message_t& message = frames.back(); + if (!recv(message, zmq::recv_flags::dontwait).has_value()) { + throw mujinclient::MujinException("unable to receive zmq message", mujinclient::MEC_InvalidState); + } + hasMore = message.more(); + } + return frames; +} + +bool RequestSocket::Poll(const short events, const std::chrono::milliseconds timeout) +{ + zmq::pollitem_t item = { + .socket = handle(), + .events = events, + }; + return zmq::poll(&item, 1, timeout); +} + +std::vector RequestSocket::SendAndReceive( + std::vector&& messages, + std::chrono::milliseconds timeout) +{ + if (messages.empty()) { + throw mujinclient::MujinException("given messages is empty", mujinclient::MEC_InvalidArguments); + } + + if (timeout.count() < 0) { + throw mujinclient::MujinException("timeout cannot be negative", mujinclient::MEC_InvalidArguments); + } + + const std::chrono::steady_clock::time_point deadline = std::chrono::steady_clock::now() + timeout; + + if (!Poll(ZMQ_POLLOUT, timeout)) { + throw mujinclient::MujinException("timeout trying to send request", mujinclient::MEC_Timeout); + } + SendNoWait(std::move(messages)); + + timeout = std::chrono::duration_cast(deadline - std::chrono::steady_clock::now()); + if (timeout.count() < 0) { + throw mujinclient::MujinException("timeout trying to send request", mujinclient::MEC_Timeout); + } + + if (!Poll(ZMQ_POLLIN, timeout)) { + throw mujinclient::MujinException("timeout trying to receive response", mujinclient::MEC_Timeout); + } + std::vector response = ReceiveNoWait(); + assert(!response.empty()); // never going to happen + + const boost::string_view statusFrame = boost::string_view(response.front().data(), response.front().size()); + + if (statusFrame == "f") { + if (response.size() != 2) { + throw mujinclient::MujinException("unexpcted number of frames in error", mujinclient::MEC_InvalidState); + } + throw std::logic_error(response.back().to_string()); + } + + if (statusFrame != "t") { + throw mujinclient::MujinException("unexpected response protocol", mujinclient::MEC_InvalidState); + } + + response.erase(response.begin()); + return response; +} +}