Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pubsub): Implement Ping/Pong Mechanism to Improve Connection Reliability #3845

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions pubsub-client/src/nonblocking/pubsub_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
//! ```

use {
crate::pubsub_client::{DEFAULT_MAX_FAILED_PINGS, DEFAULT_PING_DURATION_SECONDS},
futures_util::{
future::{ready, BoxFuture, FutureExt},
sink::SinkExt,
Expand Down Expand Up @@ -197,7 +198,7 @@ use {
net::TcpStream,
sync::{mpsc, oneshot},
task::JoinHandle,
time::{sleep, Duration},
time::{interval, Duration, Interval},
},
tokio_stream::wrappers::UnboundedReceiverStream,
tokio_tungstenite::{
Expand Down Expand Up @@ -249,6 +250,16 @@ pub enum PubsubClientError {
UnexpectedGetVersionResponse(String),
}

impl PubsubClientError {
pub fn is_timeout(&self) -> bool {
matches!(
self,
PubsubClientError::WsError(tungstenite::Error::Io(ref err))
if err.kind() == std::io::ErrorKind::WouldBlock
)
}
}

type UnsubscribeFn = Box<dyn FnOnce() -> BoxFuture<'static, ()> + Send>;
type SubscribeResponseMsg =
Result<(mpsc::UnboundedReceiver<Value>, UnsubscribeFn), PubsubClientError>;
Expand Down Expand Up @@ -500,6 +511,10 @@ impl PubsubClient {
let mut subscriptions = BTreeMap::new();
let (unsubscribe_sender, mut unsubscribe_receiver) = mpsc::unbounded_channel();

let mut ping_interval: Interval =
interval(Duration::from_secs(DEFAULT_PING_DURATION_SECONDS));
let mut elapsed_pings: usize = 0usize;

loop {
tokio::select! {
// Send close on shutdown signal
Expand All @@ -510,8 +525,15 @@ impl PubsubClient {
break;
},
// Send a `Message::Ping` every 10s when there is no other communication
() = sleep(Duration::from_secs(10)) => {
_ = ping_interval.tick() => {
ws.send(Message::Ping(Vec::new())).await?;
elapsed_pings += 1;

if elapsed_pings > DEFAULT_MAX_FAILED_PINGS {
info!("No pong received after {} pings. Closing connection...", DEFAULT_MAX_FAILED_PINGS);
ws.close(Some(CloseFrame { code: CloseCode::Normal, reason: "No pong received".into() })).await?;
break;
}
},
// Read message for subscribe
Some((operation, params, response_sender)) = subscribe_receiver.recv() => {
Expand Down Expand Up @@ -547,13 +569,20 @@ impl PubsubClient {

// Get text from the message
let text = match msg {
Message::Text(text) => text,
Message::Text(text) => {
elapsed_pings = 0;
text
},
Message::Binary(_data) => continue, // Ignore
Message::Ping(data) => {
ws.send(Message::Pong(data)).await?;
elapsed_pings = 0;
continue
},
Message::Pong(_data) => {
elapsed_pings = 0;
continue
},
Message::Pong(_data) => continue,
Message::Close(_frame) => break,
Message::Frame(_frame) => continue,
};
Expand Down
87 changes: 71 additions & 16 deletions pubsub-client/src/pubsub_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ use {
marker::PhantomData,
net::TcpStream,
sync::{
atomic::{AtomicBool, Ordering},
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, RwLock,
},
thread::{sleep, JoinHandle},
Expand All @@ -125,6 +125,11 @@ use {
url::Url,
};

/// Interval, in seconds, between the keep-alive WebSocket pings sent by the
/// pubsub clients.
pub const DEFAULT_PING_DURATION_SECONDS: u64 = 10;
/// Maximum number of consecutive pings that may go unanswered before the
/// connection is considered stale and closed.
///
/// The clients compare their unanswered-ping counter against this value with
/// `>`, so the connection is closed once more than this many pings have been
/// sent without the counter being reset by an incoming message.
pub const DEFAULT_MAX_FAILED_PINGS: usize = 3;

0xIchigo marked this conversation as resolved.
Show resolved Hide resolved
/// A subscription.
///
/// The subscription is unsubscribed on drop, and note that unsubscription (and
Expand Down Expand Up @@ -211,24 +216,29 @@ where
fn read_message(
writable_socket: &Arc<RwLock<WebSocket<MaybeTlsStream<TcpStream>>>>,
) -> Result<Option<T>, PubsubClientError> {
let message = writable_socket.write().unwrap().read()?;
if message.is_ping() {
return Ok(None);
}
let message_text = &message.into_text()?;
if let Ok(json_msg) = serde_json::from_str::<Map<String, Value>>(message_text) {
if let Some(Object(params)) = json_msg.get("params") {
if let Some(result) = params.get("result") {
if let Ok(x) = serde_json::from_value::<T>(result.clone()) {
return Ok(Some(x));
match writable_socket.write().unwrap().read() {
Ok(message) => {
if message.is_ping() || message.is_pong() {
return Ok(None);
}

let message_text = &message.into_text()?;
if let Ok(json_msg) = serde_json::from_str::<Map<String, Value>>(message_text) {
if let Some(Object(params)) = json_msg.get("params") {
if let Some(result) = params.get("result") {
if let Ok(x) = serde_json::from_value::<T>(result.clone()) {
return Ok(Some(x));
}
}
}
}

Err(PubsubClientError::UnexpectedMessageError(format!(
"msg={message_text}"
)))
}
Err(err) => Err(PubsubClientError::WsError(err)),
}

Err(PubsubClientError::UnexpectedMessageError(format!(
"msg={message_text}"
)))
}

/// Shut down the internal message receiver and wait for its thread to exit.
Expand Down Expand Up @@ -795,15 +805,60 @@ impl PubsubClient {
T: DeserializeOwned,
F: Fn(T) + Send + 'static,
{
let ping_interval: Duration = Duration::from_secs(DEFAULT_PING_DURATION_SECONDS);
let max_failed_pings: usize = DEFAULT_MAX_FAILED_PINGS;
let mut last_ping_time: std::time::Instant = std::time::Instant::now();
let elapsed_pings: Arc<AtomicUsize> = Arc::new(AtomicUsize::new(0));

loop {
if exit.load(Ordering::Relaxed) {
break;
}

// Send ping if the interval has passed
if last_ping_time.elapsed() >= ping_interval {
if let Err(err) = socket.write().unwrap().send(Message::Ping(vec![])) {
info!("Error sending ping: {:?}", err);
break;
}

last_ping_time = std::time::Instant::now();
let pings = elapsed_pings.fetch_add(1, Ordering::Relaxed) + 1;

// Check if max_failed_pings has been exceeded
if pings > max_failed_pings {
info!(
"No pong received after {} pings. Closing connection...",
max_failed_pings
);

let _ = socket.write().unwrap().close(None);
break;
}
}

let mut ws = socket.write().unwrap();
let maybe_tls_stream = ws.get_mut();

// We can only set a read time out safely if it's a plain TCP connection
if let MaybeTlsStream::Plain(tcp_stream) = maybe_tls_stream {
if let Err(e) = tcp_stream.set_read_timeout(Some(Duration::from_millis(500))) {
info!("Failed to set read timeout on TcpStream: {:?}", e);
}
}

match PubsubClientSubscription::read_message(socket) {
Ok(Some(message)) => handler(message),
Ok(Some(message)) => {
elapsed_pings.store(0, Ordering::Relaxed);
handler(message)
}
Ok(None) => {
// Nothing useful: we received a ping or pong control message
elapsed_pings.store(0, Ordering::Relaxed);
}
Err(ref err) if err.is_timeout() => {
// Read timed out - continue the loop
continue;
}
Err(err) => {
info!("receive error: {:?}", err);
Expand Down
Loading