From 6fb6eb4934f5079e9d16782c7666d524610464d2 Mon Sep 17 00:00:00 2001 From: Alex Pyattaev Date: Fri, 17 Jan 2025 18:33:18 +0000 Subject: [PATCH] Switch some functions in net-utils to tokio --- Cargo.lock | 2 + net-utils/Cargo.toml | 2 + net-utils/src/lib.rs | 282 +++++++++++++++++++++++++--------------- programs/sbf/Cargo.lock | 2 + svm/examples/Cargo.lock | 2 + 5 files changed, 188 insertions(+), 102 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 906564f8316908..1d6378dbb14708 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7993,7 +7993,9 @@ dependencies = [ name = "solana-net-utils" version = "2.2.0" dependencies = [ + "anyhow", "bincode", + "bytes", "clap 3.2.23", "crossbeam-channel", "log", diff --git a/net-utils/Cargo.toml b/net-utils/Cargo.toml index f5468118a6150e..4c177d68e0f423 100644 --- a/net-utils/Cargo.toml +++ b/net-utils/Cargo.toml @@ -10,7 +10,9 @@ license = { workspace = true } edition = { workspace = true } [dependencies] +anyhow = { workspace = true } bincode = { workspace = true } +bytes = { workspace = true } clap = { version = "3.1.5", features = ["cargo"], optional = true } crossbeam-channel = { workspace = true } log = { workspace = true } diff --git a/net-utils/src/lib.rs b/net-utils/src/lib.rs index d4d89abf097686..94cb3e3ce3a6da 100644 --- a/net-utils/src/lib.rs +++ b/net-utils/src/lib.rs @@ -1,19 +1,26 @@ //! The `net_utils` module assists with networking #![allow(clippy::arithmetic_side_effects)] + #[cfg(feature = "dev-context-only-utils")] use tokio::net::UdpSocket as TokioUdpSocket; use { - crossbeam_channel::unbounded, + anyhow::{anyhow, bail}, + bytes::{BufMut, BytesMut}, log::*, rand::{thread_rng, Rng}, socket2::{Domain, SockAddr, Socket, Type}, std::{ collections::{BTreeMap, HashSet}, - io::{self, Read, Write}, + io::{self}, net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener, TcpStream, ToSocketAddrs, UdpSocket}, sync::{Arc, RwLock}, time::{Duration, Instant}, }, + tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::TcpSocket, + sync::oneshot, + }, url::Url, }; @@ -39,93 +46,128 @@ pub const MINIMUM_VALIDATOR_PORT_RANGE_WIDTH: u16 = 17; // VALIDATOR_PORT_RANGE pub(crate) const HEADER_LENGTH: usize = 4; pub(crate) const IP_ECHO_SERVER_RESPONSE_LENGTH: usize = HEADER_LENGTH + 23; -fn ip_echo_server_request( - ip_echo_server_addr: &SocketAddr, +async fn ip_echo_server_request( + ip_echo_server_addr: SocketAddr, msg: IpEchoServerMessage, -) -> Result { + bind_address: Option, +) -> anyhow::Result { let timeout = Duration::new(5, 0); - TcpStream::connect_timeout(ip_echo_server_addr, timeout) - .and_then(|mut stream| { - // Start with HEADER_LENGTH null bytes to avoid looking like an HTTP GET/POST request - let mut bytes = vec![0; HEADER_LENGTH]; - - bytes.append(&mut bincode::serialize(&msg).expect("serialize IpEchoServerMessage")); - - // End with '\n' to make this request look HTTP-ish and tickle an error response back - // from an HTTP server - bytes.push(b'\n'); - - stream.set_read_timeout(Some(Duration::new(10, 0)))?; - stream.write_all(&bytes)?; - stream.shutdown(std::net::Shutdown::Write)?; - let mut data = vec![0u8; IP_ECHO_SERVER_RESPONSE_LENGTH]; - let _ = stream.read(&mut data[..])?; - Ok(data) - }) - .and_then(|data| { - // It's common for users to accidentally confuse the validator's gossip port and JSON - // RPC port. Attempt to detect when this occurs by looking for the standard HTTP - // response header and provide the user with a helpful error message - if data.len() < HEADER_LENGTH { - return Err(io::Error::new( - io::ErrorKind::Other, - format!("Response too short, received {} bytes", data.len()), - )); - } + let socket = tokio::net::TcpSocket::new_v4()?; + if let Some(addr) = bind_address { + socket.bind(SocketAddr::new(addr, 0))?; + } - let response_header: String = - data[0..HEADER_LENGTH].iter().map(|b| *b as char).collect(); - if response_header != "\0\0\0\0" { - if response_header == "HTTP" { - let http_response = data.iter().map(|b| *b as char).collect::(); - return Err(io::Error::new( - io::ErrorKind::Other, - format!( - "Invalid gossip entrypoint. {ip_echo_server_addr} looks to be an HTTP port: {http_response}" - ), - )); - } - return Err(io::Error::new( - io::ErrorKind::Other, - format!( - "Invalid gossip entrypoint. {ip_echo_server_addr} provided an invalid response header: '{response_header}'" - ), - )); + async fn do_make_request( + socket: TcpSocket, + ip_echo_server_addr: SocketAddr, + msg: IpEchoServerMessage, + ) -> anyhow::Result { + let mut stream = socket.connect(ip_echo_server_addr).await?; + // Start with HEADER_LENGTH null bytes to avoid looking like an HTTP GET/POST request + let mut bytes = BytesMut::with_capacity(IP_ECHO_SERVER_RESPONSE_LENGTH); + bytes.extend_from_slice(&[0u8; HEADER_LENGTH]); + bytes.extend_from_slice(&bincode::serialize(&msg)?); + + // End with '\n' to make this request look HTTP-ish and tickle an error response back + // from an HTTP server + bytes.put_u8(b'\n'); + stream.write_all(&bytes).await?; + stream.flush().await?; + + bytes.clear(); + let _n = stream.read_buf(&mut bytes).await?; + stream.shutdown().await?; + + Ok(bytes) + } + + let response = + tokio::time::timeout(timeout, do_make_request(socket, ip_echo_server_addr, msg)).await??; + // It's common for users to accidentally confuse the validator's gossip port and JSON + // RPC port. Attempt to detect when this occurs by looking for the standard HTTP + // response header and provide the user with a helpful error message + if response.len() < HEADER_LENGTH { + bail!("Response too short, received {} bytes", response.len()); + } + + let (response_header, body) = + response + .split_first_chunk::() + .ok_or(anyhow::anyhow!( + "Not enough data in the response from {ip_echo_server_addr}!" + ))?; + let payload = match response_header { + [0, 0, 0, 0] => bincode::deserialize(&response[HEADER_LENGTH..])?, + [b'H', b'T', b'T', b'P'] => { + let http_response = std::str::from_utf8(body); + match http_response { + Ok(r) => bail!("Invalid gossip entrypoint. {ip_echo_server_addr} looks to be an HTTP port replying with {r}"), + Err(_) => bail!("Invalid gossip entrypoint. {ip_echo_server_addr} looks to be an HTTP port."), } + } + _ => { + bail!("Invalid gossip entrypoint. {ip_echo_server_addr} provided unexpected header bytes {response_header:?} "); + } + }; - bincode::deserialize(&data[HEADER_LENGTH..]).map_err(|err| { - io::Error::new( - io::ErrorKind::Other, - format!("Failed to deserialize: {err:?}"), - ) - }) - }) - .map_err(|err| err.to_string()) + Ok(payload) } /// Determine the public IP address of this machine by asking an ip_echo_server at the given /// address pub fn get_public_ip_addr(ip_echo_server_addr: &SocketAddr) -> Result { - let resp = ip_echo_server_request(ip_echo_server_addr, IpEchoServerMessage::default())?; + get_public_ip_addr_with_binding(ip_echo_server_addr, None).map_err(|e| e.to_string()) +} + +/// Determine the public IP address of this machine by asking an ip_echo_server at the given +/// address +pub fn get_public_ip_addr_with_binding( + ip_echo_server_addr: &SocketAddr, + bind_address: Option, +) -> anyhow::Result { + let fut = ip_echo_server_request( + *ip_echo_server_addr, + IpEchoServerMessage::default(), + bind_address, + ); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + let resp = rt.block_on(fut)?; Ok(resp.address) } pub fn get_cluster_shred_version(ip_echo_server_addr: &SocketAddr) -> Result { - let resp = ip_echo_server_request(ip_echo_server_addr, IpEchoServerMessage::default())?; - resp.shred_version - .ok_or_else(|| String::from("IP echo server does not return a shred-version")) + get_cluster_shred_version_with_binding(ip_echo_server_addr, None) + .map_err(|_| String::from("IP echo server does not return a shred-version")) } +pub fn get_cluster_shred_version_with_binding( + ip_echo_server_addr: &SocketAddr, + bind_address: Option, +) -> anyhow::Result { + let fut = ip_echo_server_request( + *ip_echo_server_addr, + IpEchoServerMessage::default(), + bind_address, + ); + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + let resp = rt.block_on(fut)?; + resp.shred_version + .ok_or_else(|| anyhow!("IP echo server does not return a shred-version")) +} // Checks if any of the provided TCP/UDP ports are not reachable by the machine at // `ip_echo_server_addr` -const DEFAULT_TIMEOUT_SECS: u64 = 5; +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5); const DEFAULT_RETRY_COUNT: usize = 5; -fn do_verify_reachable_ports( - ip_echo_server_addr: &SocketAddr, +async fn do_verify_reachable_ports( + ip_echo_server_addr: SocketAddr, tcp_listeners: Vec<(u16, TcpListener)>, udp_sockets: &[&UdpSocket], - timeout: u64, + timeout: Duration, udp_retry_count: usize, ) -> bool { info!( @@ -137,36 +179,50 @@ fn do_verify_reachable_ports( let _ = ip_echo_server_request( ip_echo_server_addr, IpEchoServerMessage::new(&tcp_ports, &[]), + None, ) + .await .map_err(|err| warn!("ip_echo_server request failed: {}", err)); let mut ok = true; - let timeout = Duration::from_secs(timeout); + let mut checkers = Vec::new(); - // Wait for a connection to open on each TCP port + // since we do not know if tcp_listeners are nonblocking, we have to run them in native threads. for (port, tcp_listener) in tcp_listeners { - let (sender, receiver) = unbounded(); let listening_addr = tcp_listener.local_addr().unwrap(); + let (sender, receiver) = oneshot::channel(); let thread_handle = std::thread::Builder::new() .name(format!("solVrfyTcp{port:05}")) .spawn(move || { debug!("Waiting for incoming connection on tcp/{}", port); match tcp_listener.incoming().next() { - Some(_) => sender - .send(()) - .unwrap_or_else(|err| warn!("send failure: {}", err)), + Some(_) => { + // ignore errors here since this can only happen if a timeout was detected. + // timeout drops the receiver part of the channel resulting in failure to send. + let _ = sender.send(()); + } None => warn!("tcp incoming failed"), } }) .unwrap(); - match receiver.recv_timeout(timeout) { - Ok(_) => { - info!("tcp/{} is reachable", port); + + // Set the timeout on the receiver + let receiver = tokio::time::timeout(timeout, receiver); + checkers.push((listening_addr, thread_handle, receiver)); + } + + for (listening_addr, thread_handle, receiver) in checkers { + match receiver.await { + Ok(Ok(_)) => { + info!("tcp/{} is reachable", listening_addr.port()); } - Err(err) => { + Ok(Err(_v)) => { + unreachable!("The receive on oneshot channel should never fail"); + } + Err(_t) => { error!( - "Received no response at tcp/{}, check your port configuration: {}", - port, err + "Received no response at tcp/{}, check your port configuration", + listening_addr.port() ); // Ugh, std rustc doesn't provide accepting with timeout or restoring original // nonblocking-status of sockets because of lack of getter, only the setter... @@ -176,15 +232,16 @@ fn do_verify_reachable_ports( ok = false; } } - // ensure to reap the thread - thread_handle.join().unwrap(); + thread_handle.join().expect("Thread should exit cleanly") } if !ok { - // No retries for TCP, abort on the first failure - return ok; + // No retries for TCP, abort on any failure + return false; } + // now check UDP ports + let mut ok = true; let mut udp_ports: BTreeMap<_, _> = BTreeMap::new(); udp_sockets.iter().for_each(|udp_socket| { let port = udp_socket.local_addr().unwrap().port(); @@ -218,7 +275,9 @@ fn do_verify_reachable_ports( let _ = ip_echo_server_request( ip_echo_server_addr, IpEchoServerMessage::new(&[], &checked_ports), + None, ) + .await .map_err(|err| warn!("ip_echo_server request failed: {}", err)); // Spawn threads at once! @@ -300,13 +359,18 @@ pub fn verify_reachable_ports( tcp_listeners: Vec<(u16, TcpListener)>, udp_sockets: &[&UdpSocket], ) -> bool { - do_verify_reachable_ports( - ip_echo_server_addr, + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("Tokio builder should be able to reliably create a current thread runtime"); + let fut = do_verify_reachable_ports( + *ip_echo_server_addr, tcp_listeners, udp_sockets, - DEFAULT_TIMEOUT_SECS, + DEFAULT_TIMEOUT, DEFAULT_RETRY_COUNT, - ) + ); + rt.block_on(fut) } pub fn parse_port_or_addr(optstr: Option<&str>, default_addr: SocketAddr) -> SocketAddr { @@ -780,8 +844,14 @@ pub fn bind_more_with_config( #[cfg(test)] mod tests { - use {super::*, std::net::Ipv4Addr}; + use {super::*, std::net::Ipv4Addr, tokio::runtime::Runtime}; + fn runtime() -> Runtime { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("Can not create a runtime") + } #[test] fn test_response_length() { let resp = IpEchoServerResponse { @@ -957,10 +1027,13 @@ mod tests { let server_ip_echo_addr = server_udp_socket.local_addr().unwrap(); assert_eq!( - get_public_ip_addr(&server_ip_echo_addr), - parse_host("127.0.0.1"), + get_public_ip_addr_with_binding(&server_ip_echo_addr, None).unwrap(), + parse_host("127.0.0.1").unwrap(), + ); + assert_eq!( + get_cluster_shred_version_with_binding(&server_ip_echo_addr, None).unwrap(), + 42 ); - assert_eq!(get_cluster_shred_version(&server_ip_echo_addr), Ok(42)); assert!(verify_reachable_ports(&server_ip_echo_addr, vec![], &[],)); } @@ -982,10 +1055,13 @@ mod tests { let ip_echo_server_addr = server_udp_socket.local_addr().unwrap(); assert_eq!( - get_public_ip_addr(&ip_echo_server_addr), - parse_host("127.0.0.1"), + get_public_ip_addr_with_binding(&ip_echo_server_addr, None).unwrap(), + parse_host("127.0.0.1").unwrap(), + ); + assert_eq!( + get_cluster_shred_version_with_binding(&ip_echo_server_addr, None).unwrap(), + 65535 ); - assert_eq!(get_cluster_shred_version(&ip_echo_server_addr), Ok(65535)); assert!(verify_reachable_ports( &ip_echo_server_addr, vec![(client_port, client_tcp_listener)], @@ -1008,13 +1084,14 @@ mod tests { let (correct_client_port, (_client_udp_socket, client_tcp_listener)) = bind_common_in_range_with_config(ip_addr, (3200, 3250), config).unwrap(); - assert!(!do_verify_reachable_ports( - &server_ip_echo_addr, + let rt = runtime(); + assert!(!rt.block_on(do_verify_reachable_ports( + server_ip_echo_addr, vec![(correct_client_port, client_tcp_listener)], &[], - 2, + Duration::from_secs(2), 3, - )); + ))); } #[test] @@ -1032,13 +1109,14 @@ mod tests { let (_correct_client_port, (client_udp_socket, _client_tcp_listener)) = bind_common_in_range_with_config(ip_addr, (3200, 3250), config).unwrap(); - assert!(!do_verify_reachable_ports( - &server_ip_echo_addr, + let rt = runtime(); + assert!(!rt.block_on(do_verify_reachable_ports( + server_ip_echo_addr, vec![], &[&client_udp_socket], - 2, + Duration::from_secs(2), 3, - )); + ))); } #[test] diff --git a/programs/sbf/Cargo.lock b/programs/sbf/Cargo.lock index de95dd72e0d5d4..870f7eb5e18d3d 100644 --- a/programs/sbf/Cargo.lock +++ b/programs/sbf/Cargo.lock @@ -6323,7 +6323,9 @@ version = "2.2.0" name = "solana-net-utils" version = "2.2.0" dependencies = [ + "anyhow", "bincode", + "bytes", "crossbeam-channel", "log", "nix", diff --git a/svm/examples/Cargo.lock b/svm/examples/Cargo.lock index 99d7d319e985cf..0c5c0d989ba1c1 100644 --- a/svm/examples/Cargo.lock +++ b/svm/examples/Cargo.lock @@ -6153,7 +6153,9 @@ version = "2.2.0" name = "solana-net-utils" version = "2.2.0" dependencies = [ + "anyhow", "bincode", + "bytes", "crossbeam-channel", "log", "nix",