From 57a3dcab33ac7662361be84e1328143da1e9e5d5 Mon Sep 17 00:00:00 2001 From: HolyLow Date: Tue, 21 Jan 2025 18:18:00 +0800 Subject: [PATCH] [CELEBORN-1845][CIP-14] Add MessageDispatcher to cppClient --- cpp/celeborn/network/CMakeLists.txt | 6 +- cpp/celeborn/network/Message.h | 4 +- cpp/celeborn/network/MessageDispatcher.cpp | 257 ++++++++++++++++++ cpp/celeborn/network/MessageDispatcher.h | 113 ++++++++ cpp/celeborn/network/tests/CMakeLists.txt | 3 +- .../network/tests/MessageDispatcherTest.cpp | 200 ++++++++++++++ 6 files changed, 579 insertions(+), 4 deletions(-) create mode 100644 cpp/celeborn/network/MessageDispatcher.cpp create mode 100644 cpp/celeborn/network/MessageDispatcher.h create mode 100644 cpp/celeborn/network/tests/MessageDispatcherTest.cpp diff --git a/cpp/celeborn/network/CMakeLists.txt b/cpp/celeborn/network/CMakeLists.txt index 3a65828bdb3..1acf114bb64 100644 --- a/cpp/celeborn/network/CMakeLists.txt +++ b/cpp/celeborn/network/CMakeLists.txt @@ -15,7 +15,8 @@ add_library( network STATIC - Message.cpp) + Message.cpp + MessageDispatcher.cpp) target_include_directories(network PUBLIC ${CMAKE_BINARY_DIR}) @@ -25,6 +26,9 @@ target_link_libraries( proto utils protocol + ${WANGLE} + ${FIZZ} + ${LIBSODIUM_LIBRARY} ${FOLLY_WITH_DEPENDENCIES} ${GLOG} ${GFLAGS_LIBRARIES} diff --git a/cpp/celeborn/network/Message.h b/cpp/celeborn/network/Message.h index a4b269aad07..ddb6d99799f 100644 --- a/cpp/celeborn/network/Message.h +++ b/cpp/celeborn/network/Message.h @@ -178,7 +178,7 @@ class RpcFailure : public Message { class ChunkFetchSuccess : public Message { public: ChunkFetchSuccess( - protocol::StreamChunkSlice& streamChunkSlice, + const protocol::StreamChunkSlice& streamChunkSlice, std::unique_ptr&& body) : Message(CHUNK_FETCH_SUCCESS, std::move(body)), streamChunkSlice_(streamChunkSlice) {} @@ -201,7 +201,7 @@ class ChunkFetchSuccess : public Message { class ChunkFetchFailure : public Message { public: ChunkFetchFailure( - protocol::StreamChunkSlice& streamChunkSlice, + const protocol::StreamChunkSlice& streamChunkSlice, std::string&& errorString) : Message( CHUNK_FETCH_FAILURE, diff --git a/cpp/celeborn/network/MessageDispatcher.cpp b/cpp/celeborn/network/MessageDispatcher.cpp new file mode 100644 index 00000000000..30161483dbd --- /dev/null +++ b/cpp/celeborn/network/MessageDispatcher.cpp @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "celeborn/network/MessageDispatcher.h" + +#include "celeborn/protocol/TransportMessage.h" + +namespace celeborn { +namespace network { +void MessageDispatcher::read(Context*, std::unique_ptr toRecvMsg) { + switch (toRecvMsg->type()) { + case Message::RPC_RESPONSE: { + RpcResponse* response = reinterpret_cast(toRecvMsg.get()); + bool found = true; + auto holder = requestIdRegistry_.withLock([&](auto& registry) { + auto search = registry.find(response->requestId()); + if (search == registry.end()) { + LOG(WARNING) + << "requestId " << response->requestId() + << " not found when handling RPC_RESPONSE. Might be outdated already, ignored."; + found = false; + return MsgPromiseHolder{}; + } + auto result = std::move(search->second); + registry.erase(response->requestId()); + return std::move(result); + }); + if (found) { + holder.msgPromise.setValue(std::move(toRecvMsg)); + } + return; + } + case Message::RPC_FAILURE: { + RpcFailure* failure = reinterpret_cast(toRecvMsg.get()); + bool found = true; + auto holder = requestIdRegistry_.withLock([&](auto& registry) { + auto search = registry.find(failure->requestId()); + if (search == registry.end()) { + LOG(WARNING) + << "requestId " << failure->requestId() + << " not found when handling RPC_FAILURE. Might be outdated already, ignored."; + found = false; + return MsgPromiseHolder{}; + } + auto result = std::move(search->second); + registry.erase(failure->requestId()); + return std::move(result); + }); + LOG(ERROR) << "Rpc failed, requestId: " << failure->requestId() + << " errorMsg: " << failure->errorMsg() << std::endl; + if (found) { + holder.msgPromise.setException( + folly::exception_wrapper(std::exception())); + } + return; + } + case Message::CHUNK_FETCH_SUCCESS: { + ChunkFetchSuccess* success = + reinterpret_cast(toRecvMsg.get()); + auto streamChunkSlice = success->streamChunkSlice(); + bool found = true; + auto holder = streamChunkSliceRegistry_.withLock([&](auto& registry) { + auto search = registry.find(streamChunkSlice); + if (search == registry.end()) { + LOG(WARNING) + << "streamChunkSlice " << streamChunkSlice.toString() + << " not found when handling CHUNK_FETCH_SUCCESS. Might be outdated already, ignored."; + found = false; + return MsgPromiseHolder{}; + } + auto result = std::move(search->second); + registry.erase(streamChunkSlice); + return std::move(result); + }); + if (found) { + holder.msgPromise.setValue(std::move(toRecvMsg)); + } + return; + } + case Message::CHUNK_FETCH_FAILURE: { + ChunkFetchFailure* failure = + reinterpret_cast(toRecvMsg.get()); + auto streamChunkSlice = failure->streamChunkSlice(); + bool found = true; + auto holder = streamChunkSliceRegistry_.withLock([&](auto& registry) { + auto search = registry.find(streamChunkSlice); + if (search == registry.end()) { + LOG(WARNING) + << "streamChunkSlice " << streamChunkSlice.toString() + << " not found when handling CHUNK_FETCH_FAILURE. Might be outdated already, ignored."; + found = false; + return MsgPromiseHolder{}; + } + auto result = std::move(search->second); + registry.erase(streamChunkSlice); + return std::move(result); + }); + std::string errorMsg = fmt::format( + "fetchChunk failed, streamChunkSlice: {}, errorMsg: {}", + streamChunkSlice.toString(), + failure->errorMsg()); + LOG(ERROR) << errorMsg; + if (found) { + holder.msgPromise.setException( + folly::exception_wrapper(std::exception())); + } + return; + } + default: { + LOG(ERROR) << "unsupported msg for dispatcher"; + } + } +} + +folly::Future> MessageDispatcher::operator()( + std::unique_ptr toSendMsg) { + CELEBORN_CHECK(!closed_); + CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST); + RpcRequest* request = reinterpret_cast(toSendMsg.get()); + auto f = requestIdRegistry_.withLock( + [&](auto& registry) -> folly::Future> { + auto& holder = registry[request->requestId()]; + holder.requestTime = std::chrono::system_clock::now(); + auto& p = holder.msgPromise; + p.setInterruptHandler([requestId = request->requestId(), + this](const folly::exception_wrapper&) { + this->requestIdRegistry_.lock()->erase(requestId); + LOG(WARNING) << "rpc request interrupted, requestId: " << requestId; + }); + return p.getFuture(); + }); + + this->pipeline_->write(std::move(toSendMsg)); + + CELEBORN_CHECK(!closed_); + return f; +} + +folly::Future> +MessageDispatcher::sendFetchChunkRequest( + const protocol::StreamChunkSlice& streamChunkSlice, + std::unique_ptr toSendMsg) { + CELEBORN_CHECK(!closed_); + CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST); + auto f = streamChunkSliceRegistry_.withLock([&](auto& registry) { + auto& holder = registry[streamChunkSlice]; + holder.requestTime = std::chrono::system_clock::now(); + auto& p = holder.msgPromise; + p.setInterruptHandler( + [streamChunkSlice, this](const folly::exception_wrapper&) { + LOG(WARNING) << "fetchChunk request interrupted, " + "streamChunkSlice: " + << streamChunkSlice.toString(); + this->streamChunkSliceRegistry_.lock()->erase(streamChunkSlice); + }); + return p.getFuture(); + }); + this->pipeline_->write(std::move(toSendMsg)); + CELEBORN_CHECK(!closed_); + return f; +} + +void MessageDispatcher::sendRpcRequestWithoutResponse( + std::unique_ptr toSendMsg) { + CELEBORN_CHECK(toSendMsg->type() == Message::RPC_REQUEST); + this->pipeline_->write(std::move(toSendMsg)); +} + +void MessageDispatcher::readEOF(Context* ctx) { + LOG(ERROR) << "readEOF, start to close client"; + ctx->fireReadEOF(); + close(); +} + +void MessageDispatcher::readException( + Context* ctx, + folly::exception_wrapper e) { + LOG(ERROR) << "readException: " << e.what() << " , start to close client"; + ctx->fireReadException(std::move(e)); + close(); +} + +void MessageDispatcher::transportActive(Context* ctx) { + // Typically do nothing. + ctx->fireTransportActive(); +} + +void MessageDispatcher::transportInactive(Context* ctx) { + LOG(ERROR) << "transportInactive, start to close client"; + ctx->fireTransportInactive(); + close(); +} + +folly::Future MessageDispatcher::writeException( + Context* ctx, + folly::exception_wrapper e) { + LOG(ERROR) << "writeException: " << e.what() << " , start to close client"; + auto result = ctx->fireWriteException(std::move(e)); + close(); + return result; +} + +folly::Future MessageDispatcher::close() { + if (!closed_) { + closed_ = true; + cleanup(); + } + return ClientDispatcherBase::close(); +} + +folly::Future MessageDispatcher::close(Context* ctx) { + if (!closed_) { + closed_ = true; + cleanup(); + } + + return ClientDispatcherBase::close(ctx); +} + +void MessageDispatcher::cleanup() { + LOG(WARNING) << "Cleaning up client!"; + requestIdRegistry_.withLock([&](auto& registry) { + for (auto& [requestId, promiseHolder] : registry) { + auto errorMsg = + fmt::format("Client closed, cancel ongoing requestId {}", requestId); + LOG(WARNING) << errorMsg; + promiseHolder.msgPromise.setException(std::runtime_error(errorMsg)); + } + registry.clear(); + }); + streamChunkSliceRegistry_.withLock([&](auto& registry) { + for (auto& [streamChunkSlice, promiseHolder] : registry) { + auto errorMsg = fmt::format( + "Client closed, cancel ongoing streamChunkSlice {}", + streamChunkSlice.toString()); + LOG(WARNING) << errorMsg; + promiseHolder.msgPromise.setException(std::runtime_error(errorMsg)); + } + registry.clear(); + }); +} +} // namespace network +} // namespace celeborn diff --git a/cpp/celeborn/network/MessageDispatcher.h b/cpp/celeborn/network/MessageDispatcher.h new file mode 100644 index 00000000000..24a233a7c69 --- /dev/null +++ b/cpp/celeborn/network/MessageDispatcher.h @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "celeborn/conf/CelebornConf.h" +#include "celeborn/network/Message.h" +#include "celeborn/protocol/ControlMessages.h" +#include "celeborn/utils/CelebornUtils.h" + +namespace celeborn { +namespace network { +using SerializePipeline = + wangle::Pipeline>; + +/** + * MessageDispatcher is responsible for: + * 1. Record the connection between MessageFuture and MessagePromise. + * When a request message is sent via write(), the MessageFuture is + * recorded; then when the response message is received via read(), + * the response would be transferred to MessageFuture by fulfilling + * the corresponding MessagePromise. + * 2. Send different messages via different interfaces, and calls + * write() to send it to the network layer. A MessagePromise is + * created and recorded for each returned MessageFuture. + * 3. Receive response messages via read(), and dispatch the message + * according to the message kind, and finally fulfills the + * corresponding MessagePromise. + * 4. Handles and reports all kinds of network issues, e.g. EOF, + * inactive, exception, etc. + */ +class MessageDispatcher : public wangle::ClientDispatcherBase< + SerializePipeline, + std::unique_ptr, + std::unique_ptr> { +public: + void read(Context*, std::unique_ptr toRecvMsg) override; + + virtual folly::Future> sendRpcRequest( + std::unique_ptr toSendMsg) { + return operator()(std::move(toSendMsg)); + } + + virtual folly::Future> sendFetchChunkRequest( + const protocol::StreamChunkSlice& streamChunkSlice, + std::unique_ptr toSendMsg); + + virtual void sendRpcRequestWithoutResponse( + std::unique_ptr toSendMsg); + + folly::Future> operator()( + std::unique_ptr toSendMsg) override; + + void readEOF(Context* ctx) override; + + void readException(Context* ctx, folly::exception_wrapper e) override; + + void transportActive(Context* ctx) override; + + void transportInactive(Context* ctx) override; + + folly::Future writeException( + Context* ctx, + folly::exception_wrapper e) override; + + folly::Future close() override; + + folly::Future close(Context* ctx) override; + + bool isAvailable() override { + return !closed_; + } + +private: + void cleanup(); + + using MsgPromise = folly::Promise>; + struct MsgPromiseHolder { + MsgPromise msgPromise; + std::chrono::time_point requestTime; + }; + folly::Synchronized, std::mutex> + requestIdRegistry_; + folly::Synchronized< + std::unordered_map< + protocol::StreamChunkSlice, + MsgPromiseHolder, + protocol::StreamChunkSlice::Hasher>, + std::mutex> + streamChunkSliceRegistry_; + std::atomic closed_{false}; +}; +} // namespace network +} // namespace celeborn diff --git a/cpp/celeborn/network/tests/CMakeLists.txt b/cpp/celeborn/network/tests/CMakeLists.txt index db38fb48497..fdd2c874e22 100644 --- a/cpp/celeborn/network/tests/CMakeLists.txt +++ b/cpp/celeborn/network/tests/CMakeLists.txt @@ -16,7 +16,8 @@ add_executable( celeborn_network_test FrameDecoderTest.cpp - MessageTest.cpp) + MessageTest.cpp + MessageDispatcherTest.cpp) add_test(NAME celeborn_network_test COMMAND celeborn_network_test) diff --git a/cpp/celeborn/network/tests/MessageDispatcherTest.cpp b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp new file mode 100644 index 00000000000..8baf9f8d76a --- /dev/null +++ b/cpp/celeborn/network/tests/MessageDispatcherTest.cpp @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "celeborn/network/FrameDecoder.h" +#include "celeborn/network/MessageDispatcher.h" + +using namespace celeborn; +using namespace celeborn::network; + +namespace { +class MockHandler : public wangle::Handler< + std::unique_ptr, + std::unique_ptr, + std::unique_ptr, + std::unique_ptr> { + public: + MockHandler(std::unique_ptr& writedMsg) : writedMsg_(writedMsg) {} + + void read(Context* ctx, std::unique_ptr msg) override {} + + folly::Future write(Context* ctx, std::unique_ptr msg) + override { + writedMsg_ = std::move(msg); + return {}; + } + + private: + std::unique_ptr& writedMsg_; +}; + +SerializePipeline::Ptr createMockedPipeline(MockHandler&& mockHandler) { + auto pipeline = SerializePipeline::create(); + // FrameDecoder here is just for forming a complete pipeline to pass + // the type checking, not used here. + pipeline->addBack(FrameDecoder()); + pipeline->addBack(std::move(mockHandler)); + pipeline->finalize(); + return pipeline; +} + +std::unique_ptr toReadOnlyByteBuffer( + const std::string& content) { + auto buffer = memory::ByteBuffer::createWriteOnly(content.size()); + buffer->writeFromString(content); + return memory::ByteBuffer::toReadOnly(std::move(buffer)); +} + +} // namespace + +TEST(MessageDispatcherTest, sendRpcRequestAndReceiveResponse) { + std::unique_ptr sendedMsg; + MockHandler mockHandler(sendedMsg); + auto mockPipeline = createMockedPipeline(std::move(mockHandler)); + auto dispatcher = std::make_unique(); + dispatcher->setPipeline(mockPipeline.get()); + + const long requestId = 1001; + const std::string requestBody = "test-request-body"; + auto rpcRequest = std::make_unique( + requestId, toReadOnlyByteBuffer(requestBody)); + auto future = dispatcher->sendRpcRequest(std::move(rpcRequest)); + + EXPECT_FALSE(future.isReady()); + EXPECT_EQ(sendedMsg->type(), Message::RPC_REQUEST); + auto sendedRpcRequest = dynamic_cast(sendedMsg.get()); + EXPECT_EQ(sendedRpcRequest->body()->remainingSize(), requestBody.size()); + EXPECT_EQ( + sendedRpcRequest->body()->readToString(requestBody.size()), requestBody); + + const std::string responseBody = "test-response-body"; + auto rpcResponse = std::make_unique( + requestId, toReadOnlyByteBuffer(responseBody)); + dispatcher->read(nullptr, std::move(rpcResponse)); + + EXPECT_TRUE(future.isReady()); + auto receivedMsg = std::move(future).get(); + EXPECT_EQ(receivedMsg->type(), Message::RPC_RESPONSE); + auto receivedRpcResponse = dynamic_cast(receivedMsg.get()); + EXPECT_EQ(receivedRpcResponse->body()->remainingSize(), responseBody.size()); + EXPECT_EQ( + receivedRpcResponse->body()->readToString(responseBody.size()), + responseBody); +} + +TEST(MessageDispatcherTest, sendRpcRequestAndReceiveFailure) { + std::unique_ptr sendedMsg; + MockHandler mockHandler(sendedMsg); + auto mockPipeline = createMockedPipeline(std::move(mockHandler)); + auto dispatcher = std::make_unique(); + dispatcher->setPipeline(mockPipeline.get()); + + const long requestId = 1001; + const std::string requestBody = "test-request-body"; + auto rpcRequest = std::make_unique( + requestId, toReadOnlyByteBuffer(requestBody)); + auto future = dispatcher->sendRpcRequest(std::move(rpcRequest)); + + EXPECT_FALSE(future.isReady()); + EXPECT_EQ(sendedMsg->type(), Message::RPC_REQUEST); + auto sendedRpcRequest = dynamic_cast(sendedMsg.get()); + EXPECT_EQ(sendedRpcRequest->body()->remainingSize(), requestBody.size()); + EXPECT_EQ( + sendedRpcRequest->body()->readToString(requestBody.size()), requestBody); + + const std::string errorMsg = "test-error-msg"; + auto copiedErrorMsg = errorMsg; + auto rpcFailure = + std::make_unique(requestId, std::move(copiedErrorMsg)); + dispatcher->read(nullptr, std::move(rpcFailure)); + + EXPECT_TRUE(future.hasException()); +} + +TEST(MessageDispatcherTest, sendFetchChunkRequestAndReceiveSuccess) { + std::unique_ptr sendedMsg; + MockHandler mockHandler(sendedMsg); + auto mockPipeline = createMockedPipeline(std::move(mockHandler)); + auto dispatcher = std::make_unique(); + dispatcher->setPipeline(mockPipeline.get()); + + const protocol::StreamChunkSlice streamChunkSlice{1001, 1002, 1003, 1004}; + const long requestId = 1001; + const std::string requestBody = "test-request-body"; + auto rpcRequest = std::make_unique( + requestId, toReadOnlyByteBuffer(requestBody)); + auto future = dispatcher->sendFetchChunkRequest( + streamChunkSlice, std::move(rpcRequest)); + + EXPECT_FALSE(future.isReady()); + EXPECT_EQ(sendedMsg->type(), Message::RPC_REQUEST); + auto sendedRpcRequest = dynamic_cast(sendedMsg.get()); + EXPECT_EQ(sendedRpcRequest->body()->remainingSize(), requestBody.size()); + EXPECT_EQ( + sendedRpcRequest->body()->readToString(requestBody.size()), requestBody); + + const std::string chunkFetchSuccessBody = "test-chunk-fetch-success-body"; + auto chunkFetchSuccess = std::make_unique( + streamChunkSlice, toReadOnlyByteBuffer(chunkFetchSuccessBody)); + dispatcher->read(nullptr, std::move(chunkFetchSuccess)); + + EXPECT_TRUE(future.isReady()); + auto receivedMsg = std::move(future).get(); + EXPECT_EQ(receivedMsg->type(), Message::CHUNK_FETCH_SUCCESS); + auto receivedChunkFetchSuccess = + dynamic_cast(receivedMsg.get()); + EXPECT_EQ( + receivedChunkFetchSuccess->body()->remainingSize(), + chunkFetchSuccessBody.size()); + EXPECT_EQ( + receivedChunkFetchSuccess->body()->readToString( + chunkFetchSuccessBody.size()), + chunkFetchSuccessBody); +} + +TEST(MessageDispatcherTest, sendFetchChunkRequestAndReceiveFailure) { + std::unique_ptr sendedMsg; + MockHandler mockHandler(sendedMsg); + auto mockPipeline = createMockedPipeline(std::move(mockHandler)); + auto dispatcher = std::make_unique(); + dispatcher->setPipeline(mockPipeline.get()); + + const protocol::StreamChunkSlice streamChunkSlice{1001, 1002, 1003, 1004}; + const long requestId = 1001; + const std::string requestBody = "test-request-body"; + auto rpcRequest = std::make_unique( + requestId, toReadOnlyByteBuffer(requestBody)); + auto future = dispatcher->sendFetchChunkRequest( + streamChunkSlice, std::move(rpcRequest)); + + EXPECT_FALSE(future.isReady()); + EXPECT_EQ(sendedMsg->type(), Message::RPC_REQUEST); + auto sendedRpcRequest = dynamic_cast(sendedMsg.get()); + EXPECT_EQ(sendedRpcRequest->body()->remainingSize(), requestBody.size()); + EXPECT_EQ( + sendedRpcRequest->body()->readToString(requestBody.size()), requestBody); + + const std::string errorMsg = "test-error-msg"; + auto copiedErrorMsg = errorMsg; + auto chunkFetchFailure = std::make_unique( + streamChunkSlice, std::move(copiedErrorMsg)); + dispatcher->read(nullptr, std::move(chunkFetchFailure)); + + EXPECT_TRUE(future.hasException()); +}