From cc9e4e6e883777dcd428265c22bdbd6cdb8e5660 Mon Sep 17 00:00:00 2001
From: Friedel Ziegelmayer
Date: Tue, 26 Nov 2024 18:20:15 +0100
Subject: [PATCH] feat(iroh-net): allow the underlying UdpSockets to be rebound (#2946)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Description

In order to handle suspension and exits on mobile, we need to rebind our UDP
sockets when they break. This PR adds the ability to rebind the socket on
errors, and does so automatically on known suspension errors for iOS.

When reviewing this, please look specifically at how long locks are held, as
this is the most sensitive part of this code.

Some references for these errors:

- https://github.com/libevent/libevent/pull/1031
- https://github.com/n0-computer/iroh/issues/2939

### TODOs

- [x] code cleanup
- [x] testing on actual iOS apps, to see if this actually fixes the issues
- [ ] potentially handle the port still being in use? this needs some more thought

Closes #2939

## Breaking Changes

The overall API for `netwatch::UdpSocket` has changed entirely; everything
else is the same.

## Notes & open questions

- I have tried putting this logic higher in the stack, but unfortunately that
  did not work out.
- We might not want to infinitely rebind a socket if the same error happens
  over and over again; it is unclear how to handle this.

## Change checklist

- [ ] Self-review.
- [ ] Documentation updates following the [style guide](https://rust-lang.github.io/rfcs/1574-more-api-documentation-conventions.html#appendix-a-full-conventions-text), if relevant.
- [ ] Tests if relevant.
- [ ] All breaking changes documented.

---------

Co-authored-by: Philipp Krüger
---
 Cargo.lock                               |   2 +
 iroh-net-report/src/reportgen/hairpin.rs |   2 +-
 iroh-net/src/magicsock.rs                |  90 ++-
 iroh-net/src/magicsock/udp_conn.rs       | 114 +--
 net-tools/netwatch/Cargo.toml            |  33 +-
 net-tools/netwatch/src/udp.rs            | 887 +++++++++++++++++++++--
 net-tools/portmapper/src/nat_pmp.rs      |   6 +-
 net-tools/portmapper/src/pcp.rs          |   6 +-
 8 files changed, 953 insertions(+), 187 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 5378414353..a511cede31 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3324,11 +3324,13 @@ name = "netwatch"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "atomic-waker",
  "bytes",
  "derive_more",
  "futures-lite 2.5.0",
  "futures-sink",
  "futures-util",
+ "iroh-quinn-udp",
  "libc",
  "netdev",
  "netlink-packet-core",
diff --git a/iroh-net-report/src/reportgen/hairpin.rs b/iroh-net-report/src/reportgen/hairpin.rs
index dc730a7c9a..17fd49e4f5 100644
--- a/iroh-net-report/src/reportgen/hairpin.rs
+++ b/iroh-net-report/src/reportgen/hairpin.rs
@@ -121,7 +121,7 @@ impl Actor {
             .context("net_report actor gone")?;
         msg_response_rx.await.context("net_report actor died")?;
 
-        if let Err(err) = socket.send_to(&stun::request(txn), dst).await {
+        if let Err(err) = socket.send_to(&stun::request(txn), dst.into()).await {
             warn!(%dst, "failed to send hairpin check");
             return Err(err.into());
         }
diff --git a/iroh-net/src/magicsock.rs b/iroh-net/src/magicsock.rs
index f4870f8377..87964a3108 100644
--- a/iroh-net/src/magicsock.rs
+++ b/iroh-net/src/magicsock.rs
@@ -36,7 +36,7 @@ use futures_util::stream::BoxStream;
 use iroh_base::key::NodeId;
 use iroh_metrics::{inc, inc_by};
 use iroh_relay::protos::stun;
-use netwatch::{interfaces, ip::LocalAddresses, netmon};
+use netwatch::{interfaces, ip::LocalAddresses, netmon, UdpSocket};
 use quinn::AsyncUdpSocket;
 use rand::{seq::SliceRandom, Rng, SeedableRng};
 use smallvec::{smallvec, SmallVec};
@@ -441,11 +441,8 @@ impl MagicSock {
         // Right now however we have one single poller behaving the same for each
         // connection. It checks all paths and returns Poll::Ready as soon as any path is
         // ready.
-        let ipv4_poller = Arc::new(self.pconn4.clone()).create_io_poller();
-        let ipv6_poller = self
-            .pconn6
-            .as_ref()
-            .map(|sock| Arc::new(sock.clone()).create_io_poller());
+        let ipv4_poller = self.pconn4.create_io_poller();
+        let ipv6_poller = self.pconn6.as_ref().map(|sock| sock.create_io_poller());
         let relay_sender = self.relay_actor_sender.clone();
         Box::pin(IoPoller {
             ipv4_poller,
@@ -1091,10 +1088,9 @@ impl MagicSock {
                 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
                     // This is the socket .try_send_disco_message_udp used.
                     let sock = self.conn_for_addr(dst)?;
-                    let sock = Arc::new(sock.clone());
-                    let mut poller = sock.create_io_poller();
-                    match poller.as_mut().poll_writable(cx)? {
-                        Poll::Ready(()) => continue,
+                    match sock.as_socket_ref().poll_writable(cx) {
+                        Poll::Ready(Ok(())) => continue,
+                        Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                         Poll::Pending => return Poll::Pending,
                     }
                 }
@@ -1408,6 +1404,9 @@ impl Handle {
         let net_reporter =
             net_report::Client::new(Some(port_mapper.clone()), dns_resolver.clone())?;
 
+        let pconn4_sock = pconn4.as_socket();
+        let pconn6_sock = pconn6.as_ref().map(|p| p.as_socket());
+
         let (actor_sender, actor_receiver) = mpsc::channel(256);
         let (relay_actor_sender, relay_actor_receiver) = mpsc::channel(256);
         let (udp_disco_sender, mut udp_disco_receiver) = mpsc::channel(256);
@@ -1431,9 +1430,9 @@ impl Handle {
             ipv6_reported: Arc::new(AtomicBool::new(false)),
             relay_map,
             my_relay: Default::default(),
-            pconn4: pconn4.clone(),
-            pconn6: pconn6.clone(),
             net_reporter: net_reporter.addr(),
+            pconn4,
+            pconn6,
             disco_secrets: DiscoSecrets::default(),
             node_map,
             relay_actor_sender: relay_actor_sender.clone(),
@@ -1481,8 +1480,8 @@ impl Handle {
             periodic_re_stun_timer: new_re_stun_timer(false),
             net_info_last: None,
             port_mapper,
-            pconn4,
-            pconn6,
+            pconn4: pconn4_sock,
+            pconn6: pconn6_sock,
             no_v4_send: false,
             net_reporter,
             network_monitor,
@@ -1720,8 +1719,8 @@ struct Actor {
     net_info_last: Option<NetInfo>,
 
     // The underlying UDP sockets used to send/rcv packets.
-    pconn4: UdpConn,
-    pconn6: Option<UdpConn>,
+    pconn4: Arc<UdpSocket>,
+    pconn6: Option<Arc<UdpSocket>>,
 
     /// The NAT-PMP/PCP/UPnP prober/client, for requesting port mappings from NAT devices.
     port_mapper: portmapper::Client,
@@ -1861,6 +1860,14 @@ impl Actor {
                 debug!("link change detected: major? {}", is_major);
 
                 if is_major {
+                    if let Err(err) = self.pconn4.rebind() {
+                        warn!("failed to rebind Udp IPv4 socket: {:?}", err);
+                    };
+                    if let Some(ref pconn6) = self.pconn6 {
+                        if let Err(err) = pconn6.rebind() {
+                            warn!("failed to rebind Udp IPv6 socket: {:?}", err);
+                        };
+                    }
                     self.msock.dns_resolver.clear_cache();
                     self.msock.re_stun("link-change-major");
                     self.close_stale_relay_connections().await;
@@ -1893,14 +1900,6 @@ impl Actor {
                 self.port_mapper.deactivate();
                 self.relay_actor_cancel_token.cancel();
 
-                // Ignore errors from pconnN
-                // They will frequently have been closed already by a call to connBind.Close.
- debug!("stopping connections"); - if let Some(ref conn) = self.pconn6 { - conn.close().await.ok(); - } - self.pconn4.close().await.ok(); - debug!("shutdown complete"); return true; } @@ -2206,8 +2205,8 @@ impl Actor { } let relay_map = self.msock.relay_map.clone(); - let pconn4 = Some(self.pconn4.as_socket()); - let pconn6 = self.pconn6.as_ref().map(|p| p.as_socket()); + let pconn4 = Some(self.pconn4.clone()); + let pconn6 = self.pconn6.clone(); debug!("requesting net_report report"); match self @@ -3099,6 +3098,45 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_regression_network_change_rebind_wakes_connection_driver( + ) -> testresult::TestResult { + let _ = iroh_test::logging::setup(); + let m1 = MagicStack::new(RelayMode::Disabled).await?; + let m2 = MagicStack::new(RelayMode::Disabled).await?; + + println!("Net change"); + m1.endpoint.magic_sock().force_network_change(true).await; + tokio::time::sleep(Duration::from_secs(1)).await; // wait for socket rebinding + + let _guard = mesh_stacks(vec![m1.clone(), m2.clone()]).await?; + + let _handle = AbortOnDropHandle::new(tokio::spawn({ + let endpoint = m2.endpoint.clone(); + async move { + while let Some(incoming) = endpoint.accept().await { + println!("Incoming first conn!"); + let conn = incoming.await?; + conn.closed().await; + } + + testresult::TestResult::Ok(()) + } + })); + + println!("first conn!"); + let conn = m1 + .endpoint + .connect(m2.endpoint.node_addr().await?, ALPN) + .await?; + println!("Closing first conn"); + conn.close(0u32.into(), b"bye lolz"); + conn.closed().await; + println!("Closed first conn"); + + Ok(()) + } + #[tokio::test(flavor = "multi_thread")] async fn test_two_devices_roundtrip_network_change() -> Result<()> { time::timeout( diff --git a/iroh-net/src/magicsock/udp_conn.rs b/iroh-net/src/magicsock/udp_conn.rs index 2c23d44f5b..8626c3fcec 100644 --- a/iroh-net/src/magicsock/udp_conn.rs +++ b/iroh-net/src/magicsock/udp_conn.rs @@ -1,25 +1,22 @@ use std::{ fmt::Debug, - future::Future, io, net::SocketAddr, pin::Pin, sync::Arc, - task::{ready, Context, Poll}, + task::{Context, Poll}, }; use anyhow::{bail, Context as _}; use netwatch::UdpSocket; use quinn::AsyncUdpSocket; -use quinn_udp::{Transmit, UdpSockRef}; -use tokio::io::Interest; -use tracing::{debug, trace}; +use quinn_udp::Transmit; +use tracing::debug; /// A UDP socket implementing Quinn's [`AsyncUdpSocket`]. 
-#[derive(Clone, Debug)]
+#[derive(Debug, Clone)]
 pub struct UdpConn {
     io: Arc<UdpSocket>,
-    inner: Arc<quinn_udp::UdpSocketState>,
 }
 
 impl UdpConn {
@@ -27,43 +24,34 @@ impl UdpConn {
         self.io.clone()
     }
 
+    pub(super) fn as_socket_ref(&self) -> &UdpSocket {
+        &self.io
+    }
+
     pub(super) fn bind(addr: SocketAddr) -> anyhow::Result<Self> {
         let sock = bind(addr)?;
-        let state = quinn_udp::UdpSocketState::new(quinn_udp::UdpSockRef::from(&sock))?;
-        Ok(Self {
-            io: Arc::new(sock),
-            inner: Arc::new(state),
-        })
+
+        Ok(Self { io: Arc::new(sock) })
     }
 
     pub fn port(&self) -> u16 {
         self.local_addr().map(|p| p.port()).unwrap_or_default()
     }
 
-    #[allow(clippy::unused_async)]
-    pub async fn close(&self) -> Result<(), io::Error> {
-        // Nothing to do atm
-        Ok(())
+    pub(super) fn create_io_poller(&self) -> Pin<Box<dyn quinn::UdpPoller>> {
+        Box::pin(IoPoller {
+            io: self.io.clone(),
+        })
     }
 }
 
 impl AsyncUdpSocket for UdpConn {
     fn create_io_poller(self: Arc<Self>) -> Pin<Box<dyn quinn::UdpPoller>> {
-        let sock = self.io.clone();
-        Box::pin(IoPoller {
-            next_waiter: move || {
-                let sock = sock.clone();
-                async move { sock.writable().await }
-            },
-            waiter: None,
-        })
+        (*self).create_io_poller()
     }
 
     fn try_send(&self, transmit: &Transmit<'_>) -> io::Result<()> {
-        self.io.try_io(Interest::WRITABLE, || {
-            let sock_ref = UdpSockRef::from(&self.io);
-            self.inner.send(sock_ref, transmit)
-        })
+        self.io.try_send_quinn(transmit)
     }
 
     fn poll_recv(
@@ -72,24 +60,7 @@ impl AsyncUdpSocket for UdpConn {
         bufs: &mut [io::IoSliceMut<'_>],
         meta: &mut [quinn_udp::RecvMeta],
     ) -> Poll<io::Result<usize>> {
-        loop {
-            ready!(self.io.poll_recv_ready(cx))?;
-            if let Ok(res) = self.io.try_io(Interest::READABLE, || {
-                self.inner.recv(Arc::as_ref(&self.io).into(), bufs, meta)
-            }) {
-                for meta in meta.iter().take(res) {
-                    trace!(
-                        src = %meta.addr,
-                        len = meta.len,
-                        count = meta.len / meta.stride,
-                        dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(),
-                        "UDP recv"
-                    );
-                }
-
-                return Poll::Ready(Ok(res));
-            }
-        }
+        self.io.poll_recv_quinn(cx, bufs, meta)
     }
 
     fn local_addr(&self) -> io::Result<SocketAddr> {
@@ -97,15 +68,15 @@ impl AsyncUdpSocket for UdpConn {
     }
 
     fn may_fragment(&self) -> bool {
-        self.inner.may_fragment()
+        self.io.may_fragment()
     }
 
     fn max_transmit_segments(&self) -> usize {
-        self.inner.max_gso_segments()
+        self.io.max_gso_segments()
     }
 
     fn max_receive_segments(&self) -> usize {
-        self.inner.gro_segments()
+        self.io.gro_segments()
     }
 }
 
@@ -147,49 +118,14 @@ fn bind(mut addr: SocketAddr) -> anyhow::Result<UdpSocket> {
 }
 
 /// Poller for when the socket is writable.
-///
-/// The tricky part is that we only have `tokio::net::UdpSocket::writable()` to create the
-/// waiter we need, which does not return a named future type. In order to be able to store
-/// this waiter in a struct without boxing we need to specify the future itself as a type
-/// parameter, which we can only do if we introduce a second type parameter which returns
-/// the future. So we end up with a function which we do not need, but it makes the types
-/// work.
-#[derive(derive_more::Debug)]
-#[pin_project::pin_project]
-struct IoPoller<F, Fut>
-where
-    F: Fn() -> Fut + Send + Sync + 'static,
-    Fut: Future<Output = io::Result<()>> + Send + Sync + 'static,
-{
-    /// Function which can create a new waiter if there is none.
-    #[debug("next_waiter")]
-    next_waiter: F,
-    /// The waiter which tells us when the socket is writable.
- #[debug("waiter")] - #[pin] - waiter: Option, +#[derive(Debug)] +struct IoPoller { + io: Arc, } -impl quinn::UdpPoller for IoPoller -where - F: Fn() -> Fut + Send + Sync + 'static, - Fut: Future> + Send + Sync + 'static, -{ +impl quinn::UdpPoller for IoPoller { fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - let mut this = self.project(); - if this.waiter.is_none() { - this.waiter.set(Some((this.next_waiter)())); - } - let result = this - .waiter - .as_mut() - .as_pin_mut() - .expect("just set") - .poll(cx); - if result.is_ready() { - this.waiter.set(None); - } - result + self.io.poll_writable(cx) } } diff --git a/net-tools/netwatch/Cargo.toml b/net-tools/netwatch/Cargo.toml index 38637d45b6..2a0050666d 100644 --- a/net-tools/netwatch/Cargo.toml +++ b/net-tools/netwatch/Cargo.toml @@ -14,6 +14,7 @@ workspace = true [dependencies] anyhow = { version = "1" } +atomic-waker = "1.1.2" bytes = "1.7" futures-lite = "2.3" futures-sink = "0.3.25" @@ -21,10 +22,22 @@ futures-util = "0.3.25" libc = "0.2.139" netdev = "0.30.0" once_cell = "1.18.0" +quinn-udp = { package = "iroh-quinn-udp", version = "0.5.5" } socket2 = "0.5.3" thiserror = "1" time = "0.3.20" -tokio = { version = "1", features = ["io-util", "macros", "sync", "rt", "net", "fs", "io-std", "signal", "process", "time"] } +tokio = { version = "1", features = [ + "io-util", + "macros", + "sync", + "rt", + "net", + "fs", + "io-std", + "signal", + "process", + "time", +] } tokio-util = { version = "0.7", features = ["rt"] } tracing = "0.1" @@ -36,12 +49,26 @@ rtnetlink = "0.13.0" [target.'cfg(target_os = "windows")'.dependencies] wmi = "0.13" -windows = { version = "0.51", features = ["Win32_NetworkManagement_IpHelper", "Win32_Foundation", "Win32_NetworkManagement_Ndis", "Win32_Networking_WinSock"] } +windows = { version = "0.51", features = [ + "Win32_NetworkManagement_IpHelper", + "Win32_Foundation", + "Win32_NetworkManagement_Ndis", + "Win32_Networking_WinSock", +] } serde = { version = "1", features = ["derive"] } derive_more = { version = "1.0.0", features = ["debug"] } [dev-dependencies] -tokio = { version = "1", features = ["io-util", "sync", "rt", "net", "fs", "macros", "time", "test-util"] } +tokio = { version = "1", features = [ + "io-util", + "sync", + "rt", + "net", + "fs", + "macros", + "time", + "test-util", +] } [package.metadata.docs.rs] all-features = true diff --git a/net-tools/netwatch/src/udp.rs b/net-tools/netwatch/src/udp.rs index 3aba36277f..ab9f130402 100644 --- a/net-tools/netwatch/src/udp.rs +++ b/net-tools/netwatch/src/udp.rs @@ -1,147 +1,910 @@ -use std::net::SocketAddr; +use std::{ + future::Future, + io, + net::SocketAddr, + pin::Pin, + sync::{atomic::AtomicBool, RwLock, RwLockReadGuard, TryLockError}, + task::{Context, Poll}, +}; -use anyhow::{ensure, Context, Result}; -use tracing::warn; +use atomic_waker::AtomicWaker; +use quinn_udp::Transmit; +use tokio::io::Interest; +use tracing::{debug, trace, warn}; use super::IpFamily; -/// Wrapper around a tokio UDP socket that handles the fact that -/// on drop `libc::close` can block for UDP sockets. +/// Wrapper around a tokio UDP socket. #[derive(Debug)] -pub struct UdpSocket(Option); +pub struct UdpSocket { + socket: RwLock, + recv_waker: AtomicWaker, + send_waker: AtomicWaker, + /// Set to true, when an error occurred, that means we need to rebind the socket. + is_broken: AtomicBool, +} /// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it /// is the max supported by a default configuration of macOS. 
 /// Some platforms will silently clamp the value.
 const SOCKET_BUFFER_SIZE: usize = 7 << 20;
 
 impl UdpSocket {
     /// Bind only Ipv4 on any interface.
-    pub fn bind_v4(port: u16) -> Result<Self> {
+    pub fn bind_v4(port: u16) -> io::Result<Self> {
         Self::bind(IpFamily::V4, port)
     }
 
     /// Bind only Ipv6 on any interface.
-    pub fn bind_v6(port: u16) -> Result<Self> {
+    pub fn bind_v6(port: u16) -> io::Result<Self> {
         Self::bind(IpFamily::V6, port)
     }
 
     /// Bind only Ipv4 on localhost.
-    pub fn bind_local_v4(port: u16) -> Result<Self> {
+    pub fn bind_local_v4(port: u16) -> io::Result<Self> {
         Self::bind_local(IpFamily::V4, port)
     }
 
     /// Bind only Ipv6 on localhost.
-    pub fn bind_local_v6(port: u16) -> Result<Self> {
+    pub fn bind_local_v6(port: u16) -> io::Result<Self> {
         Self::bind_local(IpFamily::V6, port)
     }
 
     /// Bind to the given port only on localhost.
-    pub fn bind_local(network: IpFamily, port: u16) -> Result<Self> {
+    pub fn bind_local(network: IpFamily, port: u16) -> io::Result<Self> {
         let addr = SocketAddr::new(network.local_addr(), port);
-        Self::bind_raw(addr).with_context(|| format!("{addr:?}"))
+        Self::bind_raw(addr)
     }
 
     /// Bind to the given port and listen on all interfaces.
-    pub fn bind(network: IpFamily, port: u16) -> Result<Self> {
+    pub fn bind(network: IpFamily, port: u16) -> io::Result<Self> {
         let addr = SocketAddr::new(network.unspecified_addr(), port);
-        Self::bind_raw(addr).with_context(|| format!("{addr:?}"))
+        Self::bind_raw(addr)
    }
 
     /// Bind to any provided [`SocketAddr`].
-    pub fn bind_full(addr: impl Into<SocketAddr>) -> Result<Self> {
+    pub fn bind_full(addr: impl Into<SocketAddr>) -> io::Result<Self> {
         Self::bind_raw(addr)
     }
 
-    fn bind_raw(addr: impl Into<SocketAddr>) -> Result<Self> {
-        let addr = addr.into();
+    /// Is the socket broken and in need of a rebind?
+    pub fn is_broken(&self) -> bool {
+        self.is_broken.load(std::sync::atomic::Ordering::Acquire)
+    }
+
+    /// Marks this socket as needing a rebind.
+    fn mark_broken(&self) {
+        self.is_broken
+            .store(true, std::sync::atomic::Ordering::Release);
+    }
+
+    /// Rebind the underlying socket.
+    pub fn rebind(&self) -> io::Result<()> {
+        {
+            let mut guard = self.socket.write().unwrap();
+            guard.rebind()?;
+
+            // Clear errors
+            self.is_broken
+                .store(false, std::sync::atomic::Ordering::Release);
+
+            drop(guard);
+        }
+
+        // wakeup
+        self.wake_all();
+
+        Ok(())
+    }
+
+    fn bind_raw(addr: impl Into<SocketAddr>) -> io::Result<Self> {
+        let socket = SocketState::bind(addr.into())?;
+
+        Ok(UdpSocket {
+            socket: RwLock::new(socket),
+            recv_waker: AtomicWaker::default(),
+            send_waker: AtomicWaker::default(),
+            is_broken: AtomicBool::new(false),
+        })
+    }
+
+    /// Receives a single datagram message on the socket from the remote address
+    /// to which it is connected. On success, returns the number of bytes read.
+    ///
+    /// The function must be called with a valid byte array `buf` of sufficient
+    /// size to hold the message bytes. If a message is too long to fit in the
+    /// supplied buffer, excess bytes may be discarded.
+    ///
+    /// The [`connect`] method will connect this socket to a remote address.
+    /// This method will fail if the socket is not connected.
+    ///
+    /// [`connect`]: method@Self::connect
+    pub fn recv<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFut<'a, 'b> {
+        RecvFut {
+            socket: self,
+            buffer,
+        }
+    }
+
+    /// Receives a single datagram message on the socket. On success, returns
+    /// the number of bytes read and the origin.
+    ///
+    /// The function must be called with a valid byte array `buf` of sufficient
+    /// size to hold the message bytes. If a message is too long to fit in the
+    /// supplied buffer, excess bytes may be discarded.
+    pub fn recv_from<'a, 'b>(&'b self, buffer: &'a mut [u8]) -> RecvFromFut<'a, 'b> {
+        RecvFromFut {
+            socket: self,
+            buffer,
+        }
+    }
+
+    /// Sends data on the socket to the remote address that the socket is
+    /// connected to.
+    ///
+    /// The [`connect`] method will connect this socket to a remote address.
+    /// This method will fail if the socket is not connected.
+    ///
+    /// [`connect`]: method@Self::connect
+    ///
+    /// # Return
+    ///
+    /// On success, the number of bytes sent is returned, otherwise, the
+    /// encountered error is returned.
+    pub fn send<'a, 'b>(&'b self, buffer: &'a [u8]) -> SendFut<'a, 'b> {
+        SendFut {
+            socket: self,
+            buffer,
+        }
+    }
+
+    /// Sends data on the socket to the given address. On success, returns the
+    /// number of bytes written.
+    pub fn send_to<'a, 'b>(&'b self, buffer: &'a [u8], to: SocketAddr) -> SendToFut<'a, 'b> {
+        SendToFut {
+            socket: self,
+            buffer,
+            to,
+        }
+    }
+
+    /// Connects the UDP socket, setting the default destination for `send()` and
+    /// limiting packets that are read via `recv` from the address specified in
+    /// `addr`.
+    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
+        tracing::info!("connecting to {}", addr);
+        let guard = self.socket.read().unwrap();
+        let (socket_tokio, _state) = guard.try_get_connected()?;
+
+        let sock_ref = socket2::SockRef::from(&socket_tokio);
+        sock_ref.connect(&socket2::SockAddr::from(addr))?;
+
+        Ok(())
+    }
+
+    /// Returns the local address of this socket.
+    pub fn local_addr(&self) -> io::Result<SocketAddr> {
+        let guard = self.socket.read().unwrap();
+        let (socket, _state) = guard.try_get_connected()?;
+
+        socket.local_addr()
+    }
+
+    /// Closes the socket, and waits for the underlying `libc::close` call to be finished.
+    pub async fn close(&self) {
+        let socket = self.socket.write().unwrap().close();
+        self.wake_all();
+        if let Some((sock, _)) = socket {
+            let std_sock = sock.into_std();
+            let res = tokio::runtime::Handle::current()
+                .spawn_blocking(move || {
+                    // Calls libc::close, which can block
+                    drop(std_sock);
+                })
+                .await;
+            if let Err(err) = res {
+                warn!("failed to close socket: {:?}", err);
+            }
+        }
+    }
+
+    /// Check if this socket is closed.
+    pub fn is_closed(&self) -> bool {
+        self.socket.read().unwrap().is_closed()
+    }
+
+    /// Handle potential read errors, updating internal state.
+    ///
+    /// Returns `Some(error)` if the error is fatal, otherwise `None`.
+    fn handle_read_error(&self, error: io::Error) -> Option<io::Error> {
+        match error.kind() {
+            io::ErrorKind::NotConnected => {
+                // This indicates the underlying socket is broken, and we should attempt to rebind it
+                self.mark_broken();
+                None
+            }
+            _ => Some(error),
+        }
+    }
+
+    /// Handle potential write errors, updating internal state.
+    ///
+    /// Returns `Some(error)` if the error is fatal, otherwise `None`.
+    fn handle_write_error(&self, error: io::Error) -> Option<io::Error> {
+        match error.kind() {
+            io::ErrorKind::BrokenPipe => {
+                // This indicates the underlying socket is broken, and we should attempt to rebind it
+                self.mark_broken();
+                None
+            }
+            _ => Some(error),
+        }
+    }
+
+    /// Try to get a read lock for the socket, but don't block trying to acquire it.
+    fn poll_read_socket(
+        &self,
+        waker: &AtomicWaker,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<RwLockReadGuard<'_, SocketState>> {
+        let guard = match self.socket.try_read() {
+            Ok(guard) => guard,
+            Err(TryLockError::Poisoned(e)) => panic!("socket lock poisoned: {e}"),
+            Err(TryLockError::WouldBlock) => {
+                waker.register(cx.waker());
+
+                match self.socket.try_read() {
+                    Ok(guard) => {
+                        // we're actually fine, no need to cause a spurious wakeup
+                        waker.take();
+                        guard
+                    }
+                    Err(TryLockError::Poisoned(e)) => panic!("socket lock poisoned: {e}"),
+                    Err(TryLockError::WouldBlock) => {
+                        // Ok fine, we registered our waker, the lock is really locked,
+                        // we can return pending.
+                        return Poll::Pending;
+                    }
+                }
+            }
+        };
+        Poll::Ready(guard)
+    }
+
+    fn wake_all(&self) {
+        self.recv_waker.wake();
+        self.send_waker.wake();
+    }
+
+    /// Checks if the socket needs a rebind, and if so does it.
+    ///
+    /// Returns an error if a rebind was needed but failed.
+    fn maybe_rebind(&self) -> io::Result<()> {
+        if self.is_broken() {
+            self.rebind()?;
+        }
+        Ok(())
+    }
+
+    /// Poll for writable
+    pub fn poll_writable(&self, cx: &mut std::task::Context<'_>) -> Poll<io::Result<()>> {
+        loop {
+            if let Err(err) = self.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard = futures_lite::ready!(self.poll_read_socket(&self.send_waker, cx));
+            let (socket, _state) = guard.try_get_connected()?;
+
+            match socket.poll_send_ready(cx) {
+                Poll::Pending => {
+                    self.send_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => return Poll::Ready(Ok(())),
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = self.handle_write_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+
+    /// Send a quinn-based `Transmit`.
+    pub fn try_send_quinn(&self, transmit: &Transmit<'_>) -> io::Result<()> {
+        loop {
+            self.maybe_rebind()?;
+
+            let guard = match self.socket.try_read() {
+                Ok(guard) => guard,
+                Err(TryLockError::Poisoned(e)) => {
+                    panic!("lock poisoned: {:?}", e);
+                }
+                Err(TryLockError::WouldBlock) => {
+                    return Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
+                }
+            };
+            let (socket, state) = guard.try_get_connected()?;
+
+            let res = socket.try_io(Interest::WRITABLE, || state.send(socket.into(), transmit));
+
+            match res {
+                Ok(()) => return Ok(()),
+                Err(err) => match self.handle_write_error(err) {
+                    Some(err) => return Err(err),
+                    None => {
+                        continue;
+                    }
+                },
+            }
+        }
+    }
+
+    /// quinn-based `poll_recv`
+    pub fn poll_recv_quinn(
+        &self,
+        cx: &mut Context,
+        bufs: &mut [io::IoSliceMut<'_>],
+        meta: &mut [quinn_udp::RecvMeta],
+    ) -> Poll<io::Result<usize>> {
+        loop {
+            if let Err(err) = self.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard = futures_lite::ready!(self.poll_read_socket(&self.recv_waker, cx));
+            let (socket, state) = guard.try_get_connected()?;
+
+            match socket.poll_recv_ready(cx) {
+                Poll::Pending => {
+                    self.recv_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    // We are ready to read, continue
+                }
+                Poll::Ready(Err(err)) => match self.handle_read_error(err) {
+                    Some(err) => return Poll::Ready(Err(err)),
+                    None => {
+                        continue;
+                    }
+                },
+            }
+
+            let res = socket.try_io(Interest::READABLE, || state.recv(socket.into(), bufs, meta));
+            match res {
+                Ok(count) => {
+                    for meta in meta.iter().take(count) {
+                        trace!(
+                            src = %meta.addr,
+                            len = meta.len,
+                            count = meta.len / meta.stride,
+                            dst = %meta.dst_ip.map(|x| x.to_string()).unwrap_or_default(),
+                            "UDP recv"
+                        );
+                    }
+                    return Poll::Ready(Ok(count));
+                }
+                Err(err) => {
+                    // ignore spurious wakeups
+                    if err.kind() == io::ErrorKind::WouldBlock {
+                        continue;
+                    }
+                    match self.handle_read_error(err) {
+                        Some(err) => return Poll::Ready(Err(err)),
+                        None => {
+                            continue;
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    /// Whether transmitted datagrams might get fragmented by the IP layer.
+    ///
+    /// Returns `false` on targets which employ e.g. the `IPV6_DONTFRAG` socket option.
+    pub fn may_fragment(&self) -> bool {
+        let guard = self.socket.read().unwrap();
+        guard.may_fragment()
+    }
+
+    /// The maximum amount of segments which can be transmitted if a platform
+    /// supports Generic Send Offload (GSO).
+    ///
+    /// This is 1 if the platform doesn't support GSO. Subject to change if errors are detected
+    /// while using GSO.
+    pub fn max_gso_segments(&self) -> usize {
+        let guard = self.socket.read().unwrap();
+        guard.max_gso_segments()
+    }
+
+    /// The number of segments to read when GRO is enabled. Used as a factor to
+    /// compute the receive buffer size.
+    ///
+    /// Returns 1 if the platform doesn't support GRO.
+    pub fn gro_segments(&self) -> usize {
+        let guard = self.socket.read().unwrap();
+        guard.gro_segments()
+    }
+}
+
+/// Receive future
+#[derive(Debug)]
+pub struct RecvFut<'a, 'b> {
+    socket: &'b UdpSocket,
+    buffer: &'a mut [u8],
+}
+
+impl Future for RecvFut<'_, '_> {
+    type Output = io::Result<usize>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        let Self { socket, buffer } = &mut *self;
+
+        loop {
+            if let Err(err) = socket.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx));
+            let (inner_socket, _state) = guard.try_get_connected()?;
+
+            match inner_socket.poll_recv_ready(cx) {
+                Poll::Pending => {
+                    self.socket.recv_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    let res = inner_socket.try_recv(buffer);
+                    if let Err(err) = res {
+                        if err.kind() == io::ErrorKind::WouldBlock {
+                            continue;
+                        }
+                        if let Some(err) = socket.handle_read_error(err) {
+                            return Poll::Ready(Err(err));
+                        }
+                        continue;
+                    }
+                    return Poll::Ready(res);
+                }
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = socket.handle_read_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+}
+
+/// Receive future
+#[derive(Debug)]
+pub struct RecvFromFut<'a, 'b> {
+    socket: &'b UdpSocket,
+    buffer: &'a mut [u8],
+}
+
+impl Future for RecvFromFut<'_, '_> {
+    type Output = io::Result<(usize, SocketAddr)>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        let Self { socket, buffer } = &mut *self;
+
+        loop {
+            if let Err(err) = socket.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard = futures_lite::ready!(socket.poll_read_socket(&socket.recv_waker, cx));
+            let (inner_socket, _state) = guard.try_get_connected()?;
+
+            match inner_socket.poll_recv_ready(cx) {
+                Poll::Pending => {
+                    self.socket.recv_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    let res = inner_socket.try_recv_from(buffer);
+                    if let Err(err) = res {
+                        if err.kind() == io::ErrorKind::WouldBlock {
+                            continue;
+                        }
+                        if let Some(err) = socket.handle_read_error(err) {
+                            return Poll::Ready(Err(err));
+                        }
+                        continue;
+                    }
+                    return Poll::Ready(res);
+                }
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = socket.handle_read_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+}
+
+/// Send future
+#[derive(Debug)]
+pub struct SendFut<'a, 'b> {
+    socket: &'b UdpSocket,
+    buffer: &'a [u8],
+}
+
+impl Future for SendFut<'_, '_> {
+    type Output = io::Result<usize>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        loop {
+            if let Err(err) = self.socket.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard =
+                futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx));
+            let (socket, _state) = guard.try_get_connected()?;
+
+            match socket.poll_send_ready(cx) {
+                Poll::Pending => {
+                    self.socket.send_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    let res = socket.try_send(self.buffer);
+                    if let Err(err) = res {
+                        if err.kind() == io::ErrorKind::WouldBlock {
+                            continue;
+                        }
+                        if let Some(err) = self.socket.handle_write_error(err) {
+                            return Poll::Ready(Err(err));
+                        }
+                        continue;
+                    }
+                    return Poll::Ready(res);
+                }
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = self.socket.handle_write_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+}
+
+/// Send future
+#[derive(Debug)]
+pub struct SendToFut<'a, 'b> {
+    socket: &'b UdpSocket,
+    buffer: &'a [u8],
+    to: SocketAddr,
+}
+
+impl Future for SendToFut<'_, '_> {
+    type Output = io::Result<usize>;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
+        loop {
+            if let Err(err) = self.socket.maybe_rebind() {
+                return Poll::Ready(Err(err));
+            }
+
+            let guard =
+                futures_lite::ready!(self.socket.poll_read_socket(&self.socket.send_waker, cx));
+            let (socket, _state) = guard.try_get_connected()?;
+
+            match socket.poll_send_ready(cx) {
+                Poll::Pending => {
+                    self.socket.send_waker.register(cx.waker());
+                    return Poll::Pending;
+                }
+                Poll::Ready(Ok(())) => {
+                    let res = socket.try_send_to(self.buffer, self.to);
+                    if let Err(err) = res {
+                        if err.kind() == io::ErrorKind::WouldBlock {
+                            continue;
+                        }
+
+                        if let Some(err) = self.socket.handle_write_error(err) {
+                            return Poll::Ready(Err(err));
+                        }
+                        continue;
+                    }
+                    return Poll::Ready(res);
+                }
+                Poll::Ready(Err(err)) => {
+                    if let Some(err) = self.socket.handle_write_error(err) {
+                        return Poll::Ready(Err(err));
+                    }
+                    continue;
+                }
+            }
+        }
+    }
+}
+
+#[derive(Debug)]
+enum SocketState {
+    Connected {
+        socket: tokio::net::UdpSocket,
+        state: quinn_udp::UdpSocketState,
+        /// The addr we are binding to.
+        addr: SocketAddr,
+    },
+    Closed {
+        last_max_gso_segments: usize,
+        last_gro_segments: usize,
+        last_may_fragment: bool,
+    },
+}
+
+impl SocketState {
+    fn try_get_connected(
+        &self,
+    ) -> io::Result<(&tokio::net::UdpSocket, &quinn_udp::UdpSocketState)> {
+        match self {
+            Self::Connected {
+                socket,
+                state,
+                addr: _,
+            } => Ok((socket, state)),
+            Self::Closed { .. } => {
+                warn!("socket closed");
+                Err(io::Error::new(io::ErrorKind::BrokenPipe, "socket closed"))
+            }
+        }
+    }
+
+    fn bind(addr: SocketAddr) -> io::Result<Self> {
         let network = IpFamily::from(addr.ip());
         let socket = socket2::Socket::new(
             network.into(),
             socket2::Type::DGRAM,
             Some(socket2::Protocol::UDP),
-        )
-        .context("socket create")?;
+        )?;
 
         if let Err(err) = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE) {
-            warn!(
+            debug!(
                 "failed to set recv_buffer_size to {}: {:?}",
                 SOCKET_BUFFER_SIZE, err
             );
         }
         if let Err(err) = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE) {
-            warn!(
+            debug!(
                 "failed to set send_buffer_size to {}: {:?}",
                 SOCKET_BUFFER_SIZE, err
             );
         }
         if network == IpFamily::V6 {
             // Avoid dualstack
-            socket.set_only_v6(true).context("only IPv6")?;
+            socket.set_only_v6(true)?;
         }
 
         // Binding must happen before calling quinn, otherwise `local_addr`
         // is not yet available on all OSes.
-        socket.bind(&addr.into()).context("binding")?;
+        socket.bind(&addr.into())?;
 
         // Ensure nonblocking
-        socket.set_nonblocking(true).context("nonblocking: true")?;
+        socket.set_nonblocking(true)?;
 
         let socket: std::net::UdpSocket = socket.into();
 
         // Convert into tokio UdpSocket
-        let socket = tokio::net::UdpSocket::from_std(socket).context("conversion to tokio")?;
-
-        if addr.port() != 0 {
-            let local_addr = socket.local_addr().context("local addr")?;
-            ensure!(
-                local_addr.port() == addr.port(),
-                "wrong port bound: {:?}: wanted: {} got {}",
-                network,
-                addr.port(),
-                local_addr.port(),
-            );
+        let socket = tokio::net::UdpSocket::from_std(socket)?;
+        let socket_ref = quinn_udp::UdpSockRef::from(&socket);
+        let socket_state = quinn_udp::UdpSocketState::new(socket_ref)?;
+
+        let local_addr = socket.local_addr()?;
+        if addr.port() != 0 && local_addr.port() != addr.port() {
+            return Err(io::Error::new(
+                io::ErrorKind::Other,
+                format!(
+                    "wrong port bound: {:?}: wanted: {} got {}",
+                    network,
+                    addr.port(),
+                    local_addr.port(),
+                ),
+            ));
         }
-        Ok(UdpSocket(Some(socket)))
+
+        Ok(Self::Connected {
+            socket,
+            state: socket_state,
+            addr: local_addr,
+        })
     }
-}
 
-#[cfg(unix)]
-impl std::os::fd::AsFd for UdpSocket {
-    fn as_fd(&self) -> std::os::fd::BorrowedFd<'_> {
-        self.0.as_ref().expect("not dropped").as_fd()
+    fn rebind(&mut self) -> io::Result<()> {
+        let (addr, closed_state) = match self {
+            Self::Connected { state, addr, .. } => {
+                let s = SocketState::Closed {
+                    last_max_gso_segments: state.max_gso_segments(),
+                    last_gro_segments: state.gro_segments(),
+                    last_may_fragment: state.may_fragment(),
+                };
+                (*addr, s)
+            }
+            Self::Closed { .. } => {
+                return Err(io::Error::new(
+                    io::ErrorKind::Other,
+                    "socket is closed and cannot be rebound",
+                ));
+            }
+        };
+        debug!("rebinding {}", addr);
+
+        *self = closed_state;
+        *self = Self::bind(addr)?;
+
+        Ok(())
     }
-}
 
-#[cfg(windows)]
-impl std::os::windows::io::AsSocket for UdpSocket {
-    fn as_socket(&self) -> std::os::windows::io::BorrowedSocket<'_> {
-        self.0.as_ref().expect("not dropped").as_socket()
+    fn is_closed(&self) -> bool {
+        matches!(self, Self::Closed { .. })
+    }
+
+    fn close(&mut self) -> Option<(tokio::net::UdpSocket, quinn_udp::UdpSocketState)> {
+        match self {
+            Self::Connected { state, .. } => {
+                let s = SocketState::Closed {
+                    last_max_gso_segments: state.max_gso_segments(),
+                    last_gro_segments: state.gro_segments(),
+                    last_may_fragment: state.may_fragment(),
+                };
+                let Self::Connected { socket, state, .. } = std::mem::replace(self, s) else {
+                    unreachable!("just checked");
+                };
+                Some((socket, state))
+            }
+            Self::Closed { .. } => None,
+        }
     }
-}
 
-impl From for UdpSocket {
-    fn from(socket: tokio::net::UdpSocket) -> Self {
-        Self(Some(socket))
+    fn may_fragment(&self) -> bool {
+        match self {
+            Self::Connected { state, .. } => state.may_fragment(),
+            Self::Closed {
+                last_may_fragment, ..
+            } => *last_may_fragment,
+        }
     }
-}
 
-impl std::ops::Deref for UdpSocket {
-    type Target = tokio::net::UdpSocket;
+    fn max_gso_segments(&self) -> usize {
+        match self {
+            Self::Connected { state, .. } => state.max_gso_segments(),
+            Self::Closed {
+                last_max_gso_segments,
+                ..
+            } => *last_max_gso_segments,
+        }
+    }
 
-    fn deref(&self) -> &Self::Target {
-        self.0.as_ref().expect("only removed on drop")
+    fn gro_segments(&self) -> usize {
+        match self {
+            Self::Connected { state, .. } => state.gro_segments(),
+            Self::Closed {
+                last_gro_segments, ..
+            } => *last_gro_segments,
+        }
+    }
 }
 
 impl Drop for UdpSocket {
     fn drop(&mut self) {
-        let std_sock = self.0.take().expect("not yet dropped").into_std();
+        trace!("dropping UdpSocket");
+        if let Some((socket, _)) = self.socket.write().unwrap().close() {
+            if let Ok(handle) = tokio::runtime::Handle::try_current() {
+                // No wakeup after dropping the write lock here, since we're getting dropped.
+                // This will be `None` if `close` was called before.
+                let std_sock = socket.into_std();
+                handle.spawn_blocking(move || {
+                    // Calls libc::close, which can block
+                    drop(std_sock);
+                });
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use anyhow::Context;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_reconnect() -> anyhow::Result<()> {
+        let (s_b, mut r_b) = tokio::sync::mpsc::channel(16);
+        let handle_a = tokio::task::spawn(async move {
+            let socket = UdpSocket::bind_local(IpFamily::V4, 0)?;
+            let addr = socket.local_addr()?;
+            s_b.send(addr).await?;
+            println!("socket bound to {:?}", addr);
+
+            let mut buffer = [0u8; 16];
+            for i in 0..100 {
+                println!("-- tick {i}");
+                let read = socket.recv_from(&mut buffer).await;
+                match read {
+                    Ok((count, addr)) => {
+                        println!("got {:?}", &buffer[..count]);
+                        println!("sending {:?} to {:?}", &buffer[..count], addr);
+                        socket.send_to(&buffer[..count], addr).await?;
+                    }
+                    Err(err) => {
+                        eprintln!("error reading: {:?}", err);
+                    }
+                }
+            }
+            socket.close().await;
+            anyhow::Ok(())
+        });
+
+        let socket = UdpSocket::bind_local(IpFamily::V4, 0)?;
+        let first_addr = socket.local_addr()?;
+        println!("socket2 bound to {:?}", socket.local_addr()?);
+
+        let addr = r_b.recv().await.unwrap();
-        // Only spawn_blocking if we are inside a tokio runtime, otherwise we just drop.
-        if let Ok(handle) = tokio::runtime::Handle::try_current() {
-            handle.spawn_blocking(move || {
-                // Calls libc::close, which can block
-                drop(std_sock);
-            });
+        let mut buffer = [0u8; 16];
+        for i in 0u8..100 {
+            println!("round one - {}", i);
+            socket.send_to(&[i][..], addr).await.context("send")?;
+            let (count, from) = socket.recv_from(&mut buffer).await.context("recv")?;
+            assert_eq!(addr, from);
+            assert_eq!(count, 1);
+            assert_eq!(buffer[0], i);
+
+            // check for errors
+            assert!(!socket.is_broken());
+
+            // rebind
+            socket.rebind()?;
+
+            // check that the socket has the same address as before
+            assert_eq!(socket.local_addr()?, first_addr);
         }
+
+        handle_a.await.ok();
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_udp_mark_broken() -> anyhow::Result<()> {
+        let socket_a = UdpSocket::bind_local(IpFamily::V4, 0)?;
+        let addr_a = socket_a.local_addr()?;
+        println!("socket bound to {:?}", addr_a);
+
+        let socket_b = UdpSocket::bind_local(IpFamily::V4, 0)?;
+        let addr_b = socket_b.local_addr()?;
+        println!("socket bound to {:?}", addr_b);
+
+        let handle = tokio::task::spawn(async move {
+            let mut buffer = [0u8; 16];
+            for _ in 0..2 {
+                match socket_b.recv_from(&mut buffer).await {
+                    Ok((count, addr)) => {
+                        println!("got {:?} from {:?}", &buffer[..count], addr);
+                    }
+                    Err(err) => {
+                        eprintln!("error recv: {:?}", err);
+                    }
+                }
+            }
+        });
+        socket_a.send_to(&[0][..], addr_b).await?;
+        socket_a.mark_broken();
+        assert!(socket_a.is_broken());
+        socket_a.send_to(&[0][..], addr_b).await?;
+        assert!(!socket_a.is_broken());
+
+        handle.await?;
+        Ok(())
     }
 }
diff --git a/net-tools/portmapper/src/nat_pmp.rs b/net-tools/portmapper/src/nat_pmp.rs
index a44c4aeb7e..b859729923 100644
--- a/net-tools/portmapper/src/nat_pmp.rs
+++ b/net-tools/portmapper/src/nat_pmp.rs
@@ -51,7 +51,7 @@ impl Mapping {
     ) -> anyhow::Result<Self> {
         // create the socket and send the request
         let socket = UdpSocket::bind_full((local_ip, 0))?;
-        socket.connect((gateway, protocol::SERVER_PORT)).await?;
+        socket.connect((gateway, protocol::SERVER_PORT).into())?;
 
         let req = Request::Mapping {
             proto: MapProtocol::Udp,
@@ -124,7 +124,7 @@ impl Mapping {
 
         // create the socket and send the request
         let socket = UdpSocket::bind_full((local_ip, 0))?;
-        socket.connect((gateway, protocol::SERVER_PORT)).await?;
+        socket.connect((gateway, protocol::SERVER_PORT).into())?;
 
         let req = Request::Mapping {
             proto: MapProtocol::Udp,
@@ -167,7 +167,7 @@ async fn probe_available_fallible(
 ) -> anyhow::Result<Response> {
     // create the socket and send the request
     let socket = UdpSocket::bind_full((local_ip, 0))?;
-    socket.connect((gateway, protocol::SERVER_PORT)).await?;
+    socket.connect((gateway, protocol::SERVER_PORT).into())?;
 
     let req = Request::ExternalAddress;
     socket.send(&req.encode()).await?;
diff --git a/net-tools/portmapper/src/pcp.rs b/net-tools/portmapper/src/pcp.rs
index 0f2fe789f5..2019bc3ca5 100644
--- a/net-tools/portmapper/src/pcp.rs
+++ b/net-tools/portmapper/src/pcp.rs
@@ -54,7 +54,7 @@ impl Mapping {
     ) -> anyhow::Result<Self> {
         // create the socket and send the request
         let socket = UdpSocket::bind_full((local_ip, 0))?;
-        socket.connect((gateway, protocol::SERVER_PORT)).await?;
+        socket.connect((gateway, protocol::SERVER_PORT).into())?;
 
         let mut nonce = [0u8; 12];
         rand::thread_rng().fill_bytes(&mut nonce);
@@ -144,7 +144,7 @@ impl Mapping {
 
         // create the socket and send the request
         let socket = UdpSocket::bind_full((local_ip, 0))?;
-        socket.connect((gateway, protocol::SERVER_PORT)).await?;
+        socket.connect((gateway, protocol::SERVER_PORT).into())?;
 
         let local_port = local_port.into();
         let req = protocol::Request::mapping(nonce, local_port, local_ip, None, None, 0);
@@ -188,7 +188,7 @@ async fn probe_available_fallible(
 ) -> anyhow::Result<Response> {
     // create the socket and send the request
     let socket = UdpSocket::bind_full((local_ip, 0))?;
-    socket.connect((gateway, protocol::SERVER_PORT)).await?;
+    socket.connect((gateway, protocol::SERVER_PORT).into())?;
 
     let req = protocol::Request::announce(local_ip.to_ipv6_mapped());
     socket.send(&req.encode()).await?;
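
## Usage example

For reviewers, a minimal sketch of how the reworked `netwatch::UdpSocket` is intended to be used, based only on the API surface added in this patch. The `IpFamily` import path and the tokio `main` wiring are illustrative assumptions, and error handling is simplified:

```rust
use netwatch::{IpFamily, UdpSocket}; // import paths assumed for illustration

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Bind on localhost with an OS-assigned port.
    let socket = UdpSocket::bind_local(IpFamily::V4, 0)?;
    let addr = socket.local_addr()?;

    // send_to/recv_from are now futures on &self rather than methods on a
    // Deref'd tokio socket, so they keep working across a rebind.
    socket.send_to(b"ping", addr).await?;
    let mut buf = [0u8; 16];
    let (n, from) = socket.recv_from(&mut buf).await?;
    assert_eq!((&buf[..n], from), (&b"ping"[..], addr));

    // On known suspension errors the socket marks itself broken and the next
    // operation rebinds automatically; a rebind can also be forced, e.g. on a
    // major link change. The local address is preserved.
    socket.rebind()?;
    assert_eq!(socket.local_addr()?, addr);

    // close() waits for the potentially blocking libc::close to finish.
    socket.close().await;
    Ok(())
}
```

This mirrors the `test_reconnect` test above: sends and receives continue to work across a `rebind()`, and the socket keeps its port, which is why callers such as the magicsock can hold an `Arc<UdpSocket>` instead of re-creating connections after a network change.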