diff --git a/src/client.rs b/src/client.rs index 51f3854b..6f317266 100644 --- a/src/client.rs +++ b/src/client.rs @@ -266,39 +266,46 @@ impl Client { self.http_version == http::Version::HTTP_2 } - #[cfg(unix)] async fn client( &self, addr: (std::net::IpAddr, u16), url: &Url, ) -> Result { - if url.scheme() == "https" { - self.tls_client(addr, url).await - } else if let Some(socket_path) = &self.unix_socket { - Ok(Stream::Unix( - tokio::net::UnixStream::connect(socket_path).await?, - )) - } else { - let stream = tokio::net::TcpStream::connect(addr).await?; - stream.set_nodelay(true)?; - // stream.set_keepalive(std::time::Duration::from_secs(1).into())?; - Ok(Stream::Tcp(stream)) - } - } + // TODO: Allow the connect timeout to be configured + let timeout_duration = tokio::time::Duration::from_secs(5); - #[cfg(not(unix))] - async fn client( - &self, - addr: (std::net::IpAddr, u16), - url: &Url, - ) -> Result { if url.scheme() == "https" { - self.tls_client(addr, url).await - } else { - let stream = tokio::net::TcpStream::connect(addr).await?; - stream.set_nodelay(true)?; - // stream.set_keepalive(std::time::Duration::from_secs(1).into())?; - Ok(Stream::Tcp(stream)) + // If we do not put a timeout here then the connections attempts will + // linger long past the configured timeout + let stream = tokio::time::timeout(timeout_duration, self.tls_client(addr, url)).await; + return match stream { + Ok(Ok(stream)) => Ok(stream), + Ok(Err(err)) => Err(err), + Err(_) => Err(ClientError::Timeout), + }; + } + #[cfg(unix)] + if let Some(socket_path) = &self.unix_socket { + let stream = tokio::time::timeout( + timeout_duration, + tokio::net::UnixStream::connect(socket_path), + ) + .await; + return match stream { + Ok(Ok(stream)) => Ok(Stream::Unix(stream)), + Ok(Err(err)) => Err(ClientError::IoError(err)), + Err(_) => Err(ClientError::Timeout), + }; + } + let stream = + tokio::time::timeout(timeout_duration, tokio::net::TcpStream::connect(addr)).await; + match stream { + Ok(Ok(stream)) => { + stream.set_nodelay(true)?; + Ok(Stream::Tcp(stream)) + } + Ok(Err(err)) => Err(ClientError::IoError(err)), + Err(_) => Err(ClientError::Timeout), } } @@ -423,6 +430,8 @@ impl Client { .await .is_err() { + // This gets hit when the connection for HTTP/1.1 faults + // This re-connects start = std::time::Instant::now(); let addr = self.dns.lookup(&url, &mut client_state.rng).await?; let dns_lookup = std::time::Instant::now(); @@ -698,6 +707,20 @@ fn is_too_many_open_files(res: &Result) -> bool { .unwrap_or(false) } +/// Check error was any Hyper error (primarily for HTTP2 connection errors) +fn is_hyper_error(res: &Result) -> bool { + res.as_ref() + .err() + .map(|err| match err { + // REVIEW: IoErrors, if indicating the underlying connection has failed, + // should also cause a stop of HTTP2 requests + ClientError::IoError(_) => true, + ClientError::HyperError(_) => true, + _ => false, + }) + .unwrap_or(false) +} + async fn setup_http2(client: &Client) -> Result<(ConnectionTime, ClientStateHttp2), ClientError> { let mut rng = StdRng::from_entropy(); let url = client.url_generator.generate(&mut rng)?; @@ -1048,35 +1071,50 @@ pub async fn work_until( let client = client.clone(); let report_tx = report_tx.clone(); tokio::spawn(async move { - match setup_http2(&client).await { - Ok((connection_time, client_state)) => { - let futures = (0..n_http2_parallel) - .map(|_| { - let client = client.clone(); - let report_tx = report_tx.clone(); - let mut client_state = client_state.clone(); - tokio::spawn(async move { - loop { - let mut res = - client.work_http2(&mut client_state).await; - let is_cancel = is_too_many_open_files(&res); - set_connection_time(&mut res, connection_time); - report_tx.send_async(res).await.unwrap(); - if is_cancel { - break; + // Keep trying to establish or re-establish connections up to the deadline + loop { + match setup_http2(&client).await { + Ok((connection_time, client_state)) => { + // Setup the parallel workers for each HTTP2 connection + let futures = (0..n_http2_parallel) + .map(|_| { + let client = client.clone(); + let report_tx = report_tx.clone(); + let mut client_state = client_state.clone(); + tokio::spawn(async move { + // This is where HTTP2 loops to make all the requests for a given client and worker + loop { + let mut res = + client.work_http2(&mut client_state).await; + let is_cancel = is_too_many_open_files(&res); + let is_hyper_error = is_hyper_error(&res); + set_connection_time(&mut res, connection_time); + report_tx.send_async(res).await.unwrap(); + if is_cancel || is_hyper_error { + break is_cancel; + } } - } + }) }) - }) - .collect::>(); - - tokio::time::sleep_until(dead_line.into()).await; - for f in futures { - f.abort(); + .chain(std::iter::once(tokio::spawn(async move { + tokio::time::sleep_until(dead_line.into()).await; + true + }))) + .collect::>(); + + let (is_cancel, _, rest) = + futures::future::select_all(futures).await; + for f in rest { + f.abort(); + } + + if matches!(is_cancel, Ok(true)) { + break; + } } - } - Err(err) => report_tx.send_async(Err(err)).await.unwrap(), + Err(err) => report_tx.send_async(Err(err)).await.unwrap(), + } } }) }) @@ -1092,6 +1130,7 @@ pub async fn work_until( let mut client_state = ClientStateHttp1::default(); tokio::spawn(async move { loop { + // This is where HTTP1 loops to make all the requests for a given client let res = client.work_http1(&mut client_state).await; let is_cancel = is_too_many_open_files(&res); report_tx.send_async(res).await.unwrap();