Skip to content

Commit

Permalink
Pre shutdown hooks for GatewayClient (#5381)
Browse files Browse the repository at this point in the history
  • Loading branch information
durch authored Jan 27, 2025
1 parent ff91d46 commit 9550934
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 50 deletions.
18 changes: 10 additions & 8 deletions common/client-core/src/client/base_client/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright 2022-2023 - Nym Technologies SA <[email protected]>
// SPDX-License-Identifier: Apache-2.0

use super::mix_traffic::ClientRequestSender;
use super::received_buffer::ReceivedBufferMessage;
use super::statistics_control::StatisticsControl;
use crate::client::base_client::storage::helpers::store_client_keys;
Expand Down Expand Up @@ -645,13 +646,12 @@ where
fn start_mix_traffic_controller(
gateway_transceiver: Box<dyn GatewayTransceiver + Send>,
shutdown: TaskClient,
forget_me: ForgetMe,
) -> BatchMixMessageSender {
) -> (BatchMixMessageSender, ClientRequestSender) {
info!("Starting mix traffic controller...");
let (mix_traffic_controller, mix_tx) =
MixTrafficController::new(gateway_transceiver, forget_me);
let (mix_traffic_controller, mix_tx, client_tx) =
MixTrafficController::new(gateway_transceiver);
mix_traffic_controller.start_with_shutdown(shutdown);
mix_tx
(mix_tx, client_tx)
}

// TODO: rename it as it implies the data is persistent whilst one can use InMemBackend
Expand Down Expand Up @@ -833,10 +833,9 @@ where
// traffic stream.
// The MixTrafficController then sends the actual traffic

let message_sender = Self::start_mix_traffic_controller(
let (message_sender, client_request_sender) = Self::start_mix_traffic_controller(
gateway_transceiver,
shutdown.fork("mix_traffic_controller"),
self.forget_me,
);

// Channels that the websocket listener can use to signal downstream to the real traffic
Expand Down Expand Up @@ -911,6 +910,8 @@ where
},
stats_reporter,
task_handle: shutdown,
client_request_sender,
forget_me: self.forget_me,
})
}
}
Expand All @@ -922,6 +923,7 @@ pub struct BaseClient {
pub client_output: ClientOutputStatus,
pub client_state: ClientState,
pub stats_reporter: ClientStatsSender,

pub client_request_sender: ClientRequestSender,
pub task_handle: TaskHandle,
pub forget_me: ForgetMe,
}
60 changes: 33 additions & 27 deletions common/client-core/src/client/mix_traffic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
// SPDX-License-Identifier: Apache-2.0

use crate::client::mix_traffic::transceiver::GatewayTransceiver;
use crate::{spawn_future, ForgetMe};
use crate::spawn_future;
use log::*;
use nym_gateway_requests::ClientRequest;
use nym_sphinx::forwarding::packet::MixPacket;

pub type BatchMixMessageSender = tokio::sync::mpsc::Sender<Vec<MixPacket>>;
pub type BatchMixMessageReceiver = tokio::sync::mpsc::Receiver<Vec<MixPacket>>;
pub type ClientRequestReceiver = tokio::sync::mpsc::Receiver<ClientRequest>;
pub type ClientRequestSender = tokio::sync::mpsc::Sender<ClientRequest>;

pub mod transceiver;

Expand All @@ -23,48 +25,60 @@ pub struct MixTrafficController {
gateway_transceiver: Box<dyn GatewayTransceiver + Send>,

mix_rx: BatchMixMessageReceiver,
client_rx: ClientRequestReceiver,

// TODO: this is temporary work-around.
// in long run `gateway_client` will be moved away from `MixTrafficController` anyway.
consecutive_gateway_failure_count: usize,
forget_me: ForgetMe,
}

impl MixTrafficController {
pub fn new<T>(
gateway_transceiver: T,
forget_me: ForgetMe,
) -> (MixTrafficController, BatchMixMessageSender)
) -> (
MixTrafficController,
BatchMixMessageSender,
ClientRequestSender,
)
where
T: GatewayTransceiver + Send + 'static,
{
let (message_sender, message_receiver) =
tokio::sync::mpsc::channel(MIX_MESSAGE_RECEIVER_BUFFER_SIZE);

let (client_sender, client_receiver) = tokio::sync::mpsc::channel(1);

(
MixTrafficController {
gateway_transceiver: Box::new(gateway_transceiver),
mix_rx: message_receiver,
client_rx: client_receiver,
consecutive_gateway_failure_count: 0,
forget_me,
},
message_sender,
client_sender,
)
}

pub fn new_dynamic(
gateway_transceiver: Box<dyn GatewayTransceiver + Send>,
forget_me: ForgetMe,
) -> (MixTrafficController, BatchMixMessageSender) {
) -> (
MixTrafficController,
BatchMixMessageSender,
ClientRequestSender,
) {
let (message_sender, message_receiver) =
tokio::sync::mpsc::channel(MIX_MESSAGE_RECEIVER_BUFFER_SIZE);
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(1);
(
MixTrafficController {
gateway_transceiver,
mix_rx: message_receiver,
client_rx: client_receiver,
consecutive_gateway_failure_count: 0,
forget_me,
},
message_sender,
client_sender,
)
}

Expand Down Expand Up @@ -112,6 +126,17 @@ impl MixTrafficController {
break;
}
},
client_request = self.client_rx.recv() => match client_request {
Some(client_request) => {
match self.gateway_transceiver.send_client_request(client_request).await {
Ok(_) => (),
Err(e) => error!("Failed to send client request: {}", e),
};
},
None => {
log::trace!("MixTrafficController, client request channel closed");
}
},
_ = shutdown.recv_with_delay() => {
log::trace!("MixTrafficController: Received shutdown");
break;
Expand All @@ -120,25 +145,6 @@ impl MixTrafficController {
}
shutdown.recv_timeout().await;

if self.forget_me.any() {
log::info!("Sending forget me request to the gateway");
match self
.gateway_transceiver
.send_client_request(ClientRequest::ForgetMe {
client: self.forget_me.client(),
stats: self.forget_me.stats(),
})
.await
{
Ok(_) => {
log::info!("Successfully sent forget me request to the gateway");
}
Err(err) => {
log::error!("Failed to send forget me request to the gateway: {err}");
}
}
}

log::debug!("MixTrafficController: Exiting");
});
}
Expand Down
13 changes: 4 additions & 9 deletions common/client-core/src/client/mix_traffic/transceiver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ impl<G: GatewayTransceiver + ?Sized + Send> GatewayTransceiver for Box<G> {
&mut self,
message: ClientRequest,
) -> Result<(), GatewayClientError> {
(**self).send_client_request(message).await
let _ = (**self).send_client_request(message.clone()).await?;
log::debug!("Sent client request: {:?}", message);
Ok(())
}
}

Expand Down Expand Up @@ -143,14 +145,7 @@ where
&mut self,
message: ClientRequest,
) -> Result<(), GatewayClientError> {
if let Some(shared_key) = self.gateway_client.shared_key() {
self.gateway_client
.send_websocket_message(message.encrypt(&*shared_key)?)
.await?;
Ok(())
} else {
Err(GatewayClientError::ConnectionInInvalidState)
}
self.gateway_client.send_client_request(message).await
}
}

Expand Down
13 changes: 13 additions & 0 deletions common/client-libs/gateway-client/src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,19 @@ impl<C, St> GatewayClient<C, St> {
}
}

pub async fn send_client_request(
&mut self,
message: ClientRequest,
) -> Result<(), GatewayClientError> {
if let Some(shared_key) = self.shared_key() {
let encrypted = message.encrypt(&*shared_key)?;
Box::pin(self.send_websocket_message(encrypted)).await?;
Ok(())
} else {
Err(GatewayClientError::ConnectionInInvalidState)
}
}

async fn read_control_response(&mut self) -> Result<ServerResponse, GatewayClientError> {
// we use the fact that all request responses are Message::Text and only pushed
// sphinx packets are Message::Binary
Expand Down
2 changes: 1 addition & 1 deletion common/gateway-requests/src/types/text_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::str::FromStr;
use tungstenite::Message;

// wrapper for all encrypted requests for ease of use
#[derive(Serialize, Deserialize, Debug)]
#[derive(Serialize, Deserialize, Debug, Clone)]
#[non_exhaustive]
pub enum ClientRequest {
UpgradeKey {
Expand Down
2 changes: 1 addition & 1 deletion common/task/src/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ impl TaskClient {
const MAX_NAME_LENGTH: usize = 128;
const OVERFLOW_NAME: &'static str = "reached maximum TaskClient children name depth";

const SHUTDOWN_TIMEOUT_WAITING_FOR_SIGNAL_ON_EXIT: Duration = Duration::from_secs(5);
const SHUTDOWN_TIMEOUT_WAITING_FOR_SIGNAL_ON_EXIT: Duration = Duration::from_secs(10);

fn new(
notify: watch::Receiver<()>,
Expand Down
16 changes: 12 additions & 4 deletions nym-network-monitor/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ async fn make_clients(
loop {
if Arc::strong_count(&dropped_client) == 1 {
if let Some(client) = Arc::into_inner(dropped_client) {
// let forget_me = ClientRequest::ForgetMe {
// also_from_stats: true,
// };
let client_handle = client.into_inner();
client_handle.disconnect().await;
} else {
Expand Down Expand Up @@ -222,10 +219,12 @@ async fn main() -> Result<()> {
TOPOLOGY.get().expect("Topology not set yet!").clone(),
));

let clients_server = clients.clone();

let server_handle = tokio::spawn(async move {
let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::from_str(&args.host)?), args.port);
let server = HttpServer::new(socket, server_cancel_token);
server.run(clients).await
server.run(clients_server).await
});

info!("Waiting for message (ctrl-c to exit)");
Expand Down Expand Up @@ -259,6 +258,15 @@ async fn main() -> Result<()> {
};
}

info!("Disconnecting all clients");
let mut clients_guard = clients.write().await;
while let Some(client) = clients_guard.pop_front() {
if let Some(client) = Arc::into_inner(client) {
let client_handle = client.into_inner();
client_handle.disconnect().await;
}
}

cancel_token.cancel();

server_handle.await??;
Expand Down
2 changes: 2 additions & 0 deletions sdk/rust/nym-sdk/src/mixnet/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -800,6 +800,8 @@ where
stats_events_reporter,
started_client.task_handle,
None,
started_client.client_request_sender,
started_client.forget_me,
))
}
}
Expand Down
36 changes: 36 additions & 0 deletions sdk/rust/nym-sdk/src/mixnet/native_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ use async_trait::async_trait;
use futures::{ready, Stream, StreamExt};
use log::error;
use nym_client_core::client::base_client::GatewayConnection;
use nym_client_core::client::mix_traffic::ClientRequestSender;
use nym_client_core::client::{
base_client::{ClientInput, ClientOutput, ClientState},
inbound_messages::InputMessage,
received_buffer::ReconstructedMessagesReceiver,
};
use nym_client_core::ForgetMe;
use nym_crypto::asymmetric::identity;
use nym_gateway_requests::ClientRequest;
use nym_sphinx::addressing::clients::Recipient;
use nym_sphinx::{params::PacketType, receiver::ReconstructedMessage};
use nym_statistics_common::clients::{ClientStatsEvents, ClientStatsSender};
Expand Down Expand Up @@ -56,6 +59,8 @@ pub struct MixnetClient {

// internal state used for the `Stream` implementation
_buffered: Vec<ReconstructedMessage>,
pub(crate) client_request_sender: ClientRequestSender,
pub(crate) forget_me: ForgetMe,
}

impl MixnetClient {
Expand All @@ -70,6 +75,8 @@ impl MixnetClient {
stats_events_reporter: ClientStatsSender,
task_handle: TaskHandle,
packet_type: Option<PacketType>,
client_request_sender: ClientRequestSender,
forget_me: ForgetMe,
) -> Self {
Self {
nym_address,
Expand All @@ -82,6 +89,8 @@ impl MixnetClient {
task_handle,
packet_type,
_buffered: Vec::new(),
client_request_sender,
forget_me,
}
}

Expand Down Expand Up @@ -112,6 +121,10 @@ impl MixnetClient {
&self.nym_address
}

pub fn client_request_sender(&self) -> ClientRequestSender {
self.client_request_sender.clone()
}

/// Sign a message with the client's private identity key.
pub fn sign(&self, data: &[u8]) -> identity::Signature {
self.identity_keys.private_key().sign(data)
Expand Down Expand Up @@ -201,6 +214,15 @@ impl MixnetClient {
/// Disconnect from the mixnet. Currently it is not supported to reconnect a disconnected
/// client.
pub async fn disconnect(mut self) {
if self.forget_me.any() {
log::debug!("Sending forget me request: {:?}", self.forget_me);
match self.send_forget_me().await {
Ok(_) => (),
Err(e) => error!("Failed to send forget me request: {}", e),
};
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
}

if let TaskHandle::Internal(task_manager) = &mut self.task_handle {
task_manager.signal_shutdown().ok();
task_manager.wait_for_shutdown().await;
Expand All @@ -209,6 +231,20 @@ impl MixnetClient {
// note: it's important to take ownership of the struct as if the shutdown is `TaskHandle::External`,
// it must be dropped to finalize the shutdown
}

pub async fn send_forget_me(&self) -> Result<()> {
let client_request = ClientRequest::ForgetMe {
client: self.forget_me.client(),
stats: self.forget_me.stats(),
};
match self.client_request_sender.send(client_request).await {
Ok(_) => Ok(()),
Err(e) => {
error!("Failed to send forget me request: {}", e);
Err(Error::MessageSendingFailure)
}
}
}
}

#[derive(Clone)]
Expand Down

0 comments on commit 9550934

Please sign in to comment.