diff --git a/msim-tokio/Cargo.toml b/msim-tokio/Cargo.toml index 8bf5a5f..c0fc39b 100644 --- a/msim-tokio/Cargo.toml +++ b/msim-tokio/Cargo.toml @@ -62,3 +62,4 @@ real_tokio = { git = "ssh://git@github.com/iotaledger/tokio-madsim-fork.git", br bytes = { version = "1.7" } futures = { version = "0.3", features = ["async-await"] } mio = { version = "1.0" } +libc = "0.2" diff --git a/msim-tokio/src/sim/net.rs b/msim-tokio/src/sim/net.rs index fbcc0e1..d76d765 100644 --- a/msim-tokio/src/sim/net.rs +++ b/msim-tokio/src/sim/net.rs @@ -356,7 +356,7 @@ impl TcpStream { } async fn connect_addr(addr: impl ToSocketAddrs) -> io::Result { - let ep = Arc::new(Endpoint::connect(addr).await?); + let ep = Arc::new(Endpoint::connect(libc::SOCK_STREAM, addr).await?); trace!("connect {:?}", ep.local_addr()); let remote_sock = ep.peer_addr()?; @@ -720,7 +720,7 @@ impl TcpSocket { } pub fn bind(&self, addr: StdSocketAddr) -> io::Result<()> { - let ep = Endpoint::bind_sync(addr)?; + let ep = Endpoint::bind_sync(libc::SOCK_STREAM, addr)?; *self.bind_addr.lock().unwrap() = Some(ep.into()); Ok(()) } diff --git a/msim/src/sim/net/mod.rs b/msim/src/sim/net/mod.rs index 18d29ff..d6d8a81 100644 --- a/msim/src/sim/net/mod.rs +++ b/msim/src/sim/net/mod.rs @@ -328,7 +328,7 @@ unsafe fn accept_impl( ) -> libc::c_int { let result = HostNetworkState::with_socket( sock_fd, - |socket| -> Result { + |socket| -> Result<(SocketAddr, libc::c_int), (libc::c_int, libc::c_int)> { let node = plugin::node(); let net = plugin::simulator::(); let network = net.network.lock().unwrap(); @@ -346,7 +346,8 @@ unsafe fn accept_impl( // We can't simulate blocking accept in a single-threaded simulator, so if there is no // connection waiting for us, just bail. network - .accept_connect(node, endpoint.addr) + .accept_connect(socket.ty, node, endpoint.addr) + .map(|addr| (addr, socket.ty)) .ok_or((-1, libc::ECONNABORTED)) }, ) @@ -355,18 +356,18 @@ unsafe fn accept_impl( Result::Err((-1, libc::ENOTSOCK)) }); - let remote_addr = match result { + let (remote_addr, proto) = match result { Err((ret, err)) => { trace!("error status: {} {}", ret, err); set_errno(err); return ret; } - Ok(addr) => addr, + Ok(res) => res, }; write_socket_addr(address, address_len, remote_addr); - let endpoint = Endpoint::connect_sync(remote_addr) + let endpoint = Endpoint::connect_sync(proto, remote_addr) .expect("connection failure should already have been detected"); let fd = alloc_fd(); @@ -399,7 +400,7 @@ define_sys_interceptor!( HostNetworkState::with_socket(sock_fd, |socket| { assert!(socket.endpoint.is_none(), "socket already bound"); - match Endpoint::bind_sync(socket_addr) { + match Endpoint::bind_sync(socket.ty, socket_addr) { Ok(ep) => { socket.endpoint = Some(Arc::new(ep)); 0 @@ -441,7 +442,7 @@ define_sys_interceptor!( return Err((-1, libc::EISCONN)); } - let ep = Endpoint::connect_sync(sock_addr).map_err(|e| match e.kind() { + let ep = Endpoint::connect_sync(socket.ty, sock_addr).map_err(|e| match e.kind() { io::ErrorKind::AddrInUse => (-1, libc::EADDRINUSE), io::ErrorKind::AddrNotAvailable => (-1, libc::EADDRNOTAVAIL), _ => { @@ -456,7 +457,7 @@ define_sys_interceptor!( // the other end goes away). let net = plugin::simulator::(); let network = net.network.lock().unwrap(); - if !network.signal_connect(ep.addr, sock_addr) { + if !network.signal_connect(socket.ty, ep.addr, sock_addr) { return Err((-1, libc::ECONNREFUSED)); } @@ -1005,6 +1006,7 @@ pub struct Endpoint { net: Arc, node: NodeId, addr: SocketAddr, + proto: libc::c_int, peer: Option, live_tcp_ids: Mutex>, } @@ -1020,16 +1022,17 @@ impl std::fmt::Debug for Endpoint { } impl Endpoint { - /// Bind synchronously (for UDP) - pub fn bind_sync(addr: impl ToSocketAddrs) -> io::Result { + /// Bind synchronously + pub fn bind_sync(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result { let net = plugin::simulator::(); let node = plugin::node(); let addr = addr.to_socket_addrs()?.next().unwrap(); - let addr = net.network.lock().unwrap().bind(node, addr)?; + let addr = net.network.lock().unwrap().bind(node, proto, addr)?; let ep = Endpoint { net, node, addr, + proto, peer: None, live_tcp_ids: Default::default(), }; @@ -1053,30 +1056,31 @@ impl Endpoint { } /// Creates a [`Endpoint`] from the given address. - pub async fn bind(addr: impl ToSocketAddrs) -> io::Result { + pub async fn bind(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result { let net = plugin::simulator::(); let node = plugin::node(); let addr = addr.to_socket_addrs()?.next().unwrap(); net.rand_delay().await; - let addr = net.network.lock().unwrap().bind(node, addr)?; + let addr = net.network.lock().unwrap().bind(node, proto, addr)?; Ok(Endpoint { net, node, addr, + proto, peer: None, live_tcp_ids: Default::default(), }) } /// Connects this [`Endpoint`] to a remote address. - pub async fn connect(addr: impl ToSocketAddrs) -> io::Result { + pub async fn connect(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result { let net = plugin::simulator::(); net.rand_delay().await; - Self::connect_sync(addr) + Self::connect_sync(proto, addr) } /// For libc::connect() - pub fn connect_sync(addr: impl ToSocketAddrs) -> io::Result { + pub fn connect_sync(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result { let net = plugin::simulator::(); let node = plugin::node(); let peer = addr.to_socket_addrs()?.next().unwrap(); @@ -1085,11 +1089,12 @@ impl Endpoint { } else { SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)) }; - let addr = net.network.lock().unwrap().bind(node, addr)?; + let addr = net.network.lock().unwrap().bind(node, proto, addr)?; Ok(Endpoint { net, node, addr, + proto, peer: Some(peer), live_tcp_ids: Default::default(), }) @@ -1124,7 +1129,7 @@ impl Endpoint { .network .lock() .unwrap() - .deregister_tcp_id(self.node, remote_sock, id); + .deregister_tcp_id(self.node, self.proto, remote_sock, id); } /// Returns the local socket address. @@ -1230,7 +1235,7 @@ impl Endpoint { .network .lock() .unwrap() - .send(plugin::node(), self.addr, dst, tag, data) + .send(plugin::node(), self.proto, self.addr, dst, tag, data) } /// Receives a raw message. @@ -1240,12 +1245,12 @@ impl Endpoint { #[cfg_attr(docsrs, doc(cfg(msim)))] pub async fn recv_from_raw(&self, tag: u64) -> io::Result<(Payload, SocketAddr)> { trace!("awaiting recv: {} tag={:x}", self.addr, tag); - let recver = self - .net - .network - .lock() - .unwrap() - .recv(plugin::node(), self.addr, tag); + let recver = + self.net + .network + .lock() + .unwrap() + .recv(plugin::node(), self.proto, self.addr, tag); let msg = recver .await .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "network is down"))?; @@ -1262,7 +1267,7 @@ impl Endpoint { .network .lock() .unwrap() - .recv_sync(plugin::node(), self.addr, tag) + .recv_sync(plugin::node(), self.proto, self.addr, tag) .ok_or_else(|| io::Error::new(io::ErrorKind::WouldBlock, "recv call would blck"))?; trace!( @@ -1316,12 +1321,13 @@ impl Endpoint { /// Check if there is a message waiting that can be received without blocking. /// If not, schedule a wakeup using the context. pub fn recv_ready(&self, cx: Option<&mut Context<'_>>, tag: u64) -> io::Result { - Ok(self - .net - .network - .lock() - .unwrap() - .recv_ready(cx, plugin::node(), self.addr, tag)) + Ok(self.net.network.lock().unwrap().recv_ready( + cx, + plugin::node(), + self.proto, + self.addr, + tag, + )) } } @@ -1334,7 +1340,7 @@ impl Drop for Endpoint { // avoid panic on panicking if let Ok(mut network) = self.net.network.lock() { - network.close(self.node, self.addr); + network.close(self.proto, self.node, self.addr); } } } @@ -1368,7 +1374,7 @@ mod tests { let barrier_ = barrier.clone(); node1.spawn(async move { - let net = Endpoint::bind(addr1).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap(); barrier_.wait().await; net.send_to(addr2, 1, payload!(vec![1])).await.unwrap(); @@ -1378,7 +1384,7 @@ mod tests { }); let f = node2.spawn(async move { - let net = Endpoint::bind(addr2).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr2).await.unwrap(); barrier.wait().await; let mut buf = vec![0; 0x10]; @@ -1407,14 +1413,14 @@ mod tests { let barrier_ = barrier.clone(); node1.spawn(async move { - let net = Endpoint::bind(addr1).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap(); barrier_.wait().await; net.send_to(addr2, 1, payload!(vec![1])).await.unwrap(); }); let f = node2.spawn(async move { - let net = Endpoint::bind(addr2).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr2).await.unwrap(); let mut buf = vec![0; 0x10]; timeout(Duration::from_secs(1), net.recv_from(1, &mut buf)) .await @@ -1439,7 +1445,7 @@ mod tests { let node1 = runtime.create_node().ip(addr1.ip()).build(); let f = node1.spawn(async move { - let net = Endpoint::bind(addr1).await.unwrap(); + let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap(); let err = net.recv_from(1, &mut []).await.unwrap_err(); assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe); // FIXME: should still error @@ -1462,36 +1468,47 @@ mod tests { let f = node.spawn(async move { // unspecified - let ep = Endpoint::bind("0.0.0.0:0").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, "0.0.0.0:0") + .await + .unwrap(); let addr = ep.local_addr().unwrap(); assert_eq!(addr.ip(), ip); assert_ne!(addr.port(), 0); // unspecified v6 - let ep = Endpoint::bind(":::0").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, ":::0").await.unwrap(); let addr = ep.local_addr().unwrap(); assert_eq!(addr.ip(), ip); assert_ne!(addr.port(), 0); // localhost - let ep = Endpoint::bind("127.0.0.1:0").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:0") + .await + .unwrap(); let addr = ep.local_addr().unwrap(); assert_eq!(addr.ip().to_string(), "127.0.0.1"); assert_ne!(addr.port(), 0); // localhost v6 - let ep = Endpoint::bind("::1:0").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, "::1:0").await.unwrap(); let addr = ep.local_addr().unwrap(); assert_eq!(addr.ip().to_string(), "::1"); assert_ne!(addr.port(), 0); // wrong IP - let err = Endpoint::bind("10.0.0.2:0").await.err().unwrap(); + let err = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.2:0") + .await + .err() + .unwrap(); assert_eq!(err.kind(), std::io::ErrorKind::AddrNotAvailable); // drop and reuse port - let _ = Endpoint::bind("10.0.0.1:100").await.unwrap(); - let _ = Endpoint::bind("10.0.0.1:100").await.unwrap(); + let _ = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:100") + .await + .unwrap(); + let _ = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:100") + .await + .unwrap(); }); runtime.block_on(f).unwrap(); } @@ -1508,8 +1525,12 @@ mod tests { let barrier_ = barrier.clone(); let f1 = node1.spawn(async move { - let ep1 = Endpoint::bind("127.0.0.1:1").await.unwrap(); - let ep2 = Endpoint::bind("10.0.0.1:2").await.unwrap(); + let ep1 = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:1") + .await + .unwrap(); + let ep2 = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:2") + .await + .unwrap(); barrier_.wait().await; // FIXME: ep1 should not receive messages from other node @@ -1521,7 +1542,9 @@ mod tests { ep2.recv_from(1, &mut []).await.unwrap(); }); let f2 = node2.spawn(async move { - let ep = Endpoint::bind("127.0.0.1:1").await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:1") + .await + .unwrap(); barrier.wait().await; ep.send_to("10.0.0.1:1", 1, payload!(vec![1])) @@ -1546,7 +1569,7 @@ mod tests { let barrier_ = barrier.clone(); node1.spawn(async move { - let ep = Endpoint::bind(addr1).await.unwrap(); + let ep = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap(); assert_eq!(ep.local_addr().unwrap(), addr1); barrier_.wait().await; @@ -1561,7 +1584,7 @@ mod tests { let f = node2.spawn(async move { barrier.wait().await; - let ep = Endpoint::connect(addr1).await.unwrap(); + let ep = Endpoint::connect(libc::SOCK_STREAM, addr1).await.unwrap(); assert_eq!(ep.peer_addr().unwrap(), addr1); ep.send(1, payload!(b"ping".to_vec())).await.unwrap(); diff --git a/msim/src/sim/net/network.rs b/msim/src/sim/net/network.rs index ca0f68b..3a3ac02 100644 --- a/msim/src/sim/net/network.rs +++ b/msim/src/sim/net/network.rs @@ -33,7 +33,7 @@ struct Node { /// NOTE: now a node can have at most one IP address. ip: Option, /// Sockets in the node. - sockets: HashMap>>, + sockets: HashMap>>, /// live tcp connections. live_tcp_ids: HashSet, @@ -67,6 +67,17 @@ pub struct Stat { pub msg_count: u64, } +#[derive(Debug, Hash, Eq, PartialEq)] +struct SocketKey(u16, libc::c_int); + +fn proto_str(proto: libc::c_int) -> &'static str { + match proto { + libc::SOCK_STREAM => "tcp", + libc::SOCK_DGRAM => "udp", + _ => panic!("unsupported socket type {}", proto), + } +} + impl Network { pub fn new(rand: GlobalRng, time: TimeHandle, config: NetworkConfig) -> Self { Self { @@ -179,8 +190,13 @@ impl Network { self.clogged_link.remove(&(src, dst)); } - pub fn bind(&mut self, node_id: NodeId, mut addr: SocketAddr) -> io::Result { - debug!("binding: {addr} -> {node_id}"); + pub fn bind( + &mut self, + node_id: NodeId, + proto: libc::c_int, + mut addr: SocketAddr, + ) -> io::Result { + debug!("binding ({}): {addr} -> {node_id}", proto_str(proto)); let node = self.nodes.get_mut(&node_id).expect("node not found"); // resolve IP if unspecified if addr.ip().is_unspecified() { @@ -200,7 +216,7 @@ impl Network { if addr.port() == 0 { let next_ephemeral_port = node.next_ephemeral_port; let port = (next_ephemeral_port..=u16::MAX) - .find(|port| !node.sockets.contains_key(port)) + .find(|port| !node.sockets.contains_key(&SocketKey(*port, proto))) .ok_or_else(|| { warn!("ephemeral ports exhausted"); io::Error::new(io::ErrorKind::AddrInUse, "no available ephemeral port") @@ -210,7 +226,7 @@ impl Network { addr.set_port(port); } // insert socket - match node.sockets.entry(addr.port()) { + match node.sockets.entry(SocketKey(addr.port(), proto)) { Entry::Occupied(_) => { warn!("bind() error: address already in use: {addr:?}"); return Err(io::Error::new( @@ -239,7 +255,13 @@ impl Network { ); } - pub fn deregister_tcp_id(&mut self, node: NodeId, remote_addr: &SocketAddr, tcp_id: u32) { + pub fn deregister_tcp_id( + &mut self, + node: NodeId, + proto: libc::c_int, + remote_addr: &SocketAddr, + tcp_id: u32, + ) { trace!("deregistering tcp id {} for node {}", tcp_id, node); // node may have been deleted @@ -262,7 +284,7 @@ impl Network { if let Some(socket) = self .nodes .get_mut(node_id) - .map(|node| node.sockets.get(&remote_addr.port())) + .map(|node| node.sockets.get(&SocketKey(remote_addr.port(), proto))) .tap_none(|| debug!("No node found for {node_id}")) .flatten() { @@ -280,14 +302,14 @@ impl Network { } } - pub fn signal_connect(&self, src: SocketAddr, dst: SocketAddr) -> bool { + pub fn signal_connect(&self, proto: libc::c_int, src: SocketAddr, dst: SocketAddr) -> bool { let node = self.get_node_for_addr(&dst.ip()); if node.is_none() { return false; } let node = node.unwrap(); - let dst_socket = self.nodes[&node].sockets.get(&dst.port()); + let dst_socket = self.nodes[&node].sockets.get(&SocketKey(dst.port(), proto)); if let Some(dst_socket) = dst_socket { dst_socket.lock().unwrap().signal_connect(src); @@ -297,22 +319,31 @@ impl Network { } } - pub fn accept_connect(&self, node: NodeId, listening: SocketAddr) -> Option { - let socket = self.nodes[&node].sockets.get(&listening.port()).unwrap(); + pub fn accept_connect( + &self, + proto: libc::c_int, + node: NodeId, + listening: SocketAddr, + ) -> Option { + let socket = self.nodes[&node] + .sockets + .get(&SocketKey(listening.port(), proto)) + .unwrap(); socket.lock().unwrap().accept_connect() } - pub fn close(&mut self, node_id: NodeId, addr: SocketAddr) { + pub fn close(&mut self, proto: libc::c_int, node_id: NodeId, addr: SocketAddr) { if let Some(node) = self.nodes.get_mut(&node_id) { debug!("close: {node_id} {addr}"); // TODO: simulate TIME_WAIT? - node.sockets.remove(&addr.port()); + node.sockets.remove(&SocketKey(addr.port(), proto)); } } pub fn send( &mut self, node_id: NodeId, + proto: libc::c_int, src: SocketAddr, dst: SocketAddr, tag: u64, @@ -383,7 +414,7 @@ impl Network { } } - let mailbox = match node.sockets.get(&dst.port()) { + let mailbox = match node.sockets.get(&SocketKey(dst.port(), proto)) { Some(mailbox) => Arc::downgrade(mailbox), None => { debug!("destination port not available: {dst}"); @@ -420,15 +451,27 @@ impl Network { Ok(()) } - pub fn recv(&mut self, node: NodeId, dst: SocketAddr, tag: u64) -> oneshot::Receiver { - self.nodes[&node].sockets[&dst.port()] + pub fn recv( + &mut self, + node: NodeId, + proto: libc::c_int, + dst: SocketAddr, + tag: u64, + ) -> oneshot::Receiver { + self.nodes[&node].sockets[&SocketKey(dst.port(), proto)] .lock() .unwrap() .recv(tag) } - pub fn recv_sync(&mut self, node: NodeId, dst: SocketAddr, tag: u64) -> Option { - self.nodes[&node].sockets[&dst.port()] + pub fn recv_sync( + &mut self, + node: NodeId, + proto: libc::c_int, + dst: SocketAddr, + tag: u64, + ) -> Option { + self.nodes[&node].sockets[&SocketKey(dst.port(), proto)] .lock() .unwrap() .recv_sync(tag) @@ -438,10 +481,11 @@ impl Network { &self, cx: Option<&mut Context<'_>>, node: NodeId, + proto: libc::c_int, dst: SocketAddr, tag: u64, ) -> bool { - self.nodes[&node].sockets[&dst.port()] + self.nodes[&node].sockets[&SocketKey(dst.port(), proto)] .lock() .unwrap() .recv_ready(cx, tag) diff --git a/msim/src/sim/time/interval.rs b/msim/src/sim/time/interval.rs index 6db4c0c..636827a 100644 --- a/msim/src/sim/time/interval.rs +++ b/msim/src/sim/time/interval.rs @@ -167,17 +167,42 @@ impl Interval { /// Resets the interval to complete one period after the current time. /// /// This method ignores [`MissedTickBehavior`] strategy. + /// + /// This is equivalent to calling `reset_at(Instant::now() + period)`. pub fn reset(&mut self) { self.delay.as_mut().reset(Instant::now() + self.period); } - /// Resets the interval after the specified [`std::time::Duration`] + /// Resets the interval immediately. + /// + /// This method ignores [`MissedTickBehavior`] strategy. + /// + /// This is equivalent to calling `reset_at(Instant::now())`. + pub fn reset_immediately(&mut self) { + self.delay.as_mut().reset(Instant::now()); + } + + /// Resets the interval after the specified [`std::time::Duration`]. /// /// This method ignores [`MissedTickBehavior`] strategy. + /// + /// This is equivalent to calling `reset_at(Instant::now() + after)`. pub fn reset_after(&mut self, after: Duration) { self.delay.as_mut().reset(Instant::now() + after); } + /// Resets the interval to a [`crate::time::Instant`] deadline. + /// + /// Sets the next tick to expire at the given instant. If the instant is in + /// the past, then the [`MissedTickBehavior`] strategy will be used to + /// catch up. If the instant is in the future, then the next tick will + /// complete at the given instant, even if that means that it will sleep for + /// longer than the duration of this [`Interval`]. If the [`Interval`] had + /// any missed ticks before calling this method, then those are discarded. + pub fn reset_at(&mut self, deadline: Instant) { + self.delay.as_mut().reset(deadline); + } + /// Returns the [`MissedTickBehavior`] strategy currently being used. pub fn missed_tick_behavior(&self) -> MissedTickBehavior { self.missed_tick_behavior