diff --git a/packages/centipede_control/src/lib.rs b/packages/centipede_control/src/lib.rs index 9ec004d..08c4b2f 100644 --- a/packages/centipede_control/src/lib.rs +++ b/packages/centipede_control/src/lib.rs @@ -1 +1,3 @@ +#![feature(btree_extract_if)] + pub mod pure; diff --git a/packages/centipede_control/src/pure.rs b/packages/centipede_control/src/pure.rs index 15c94b5..b0233af 100644 --- a/packages/centipede_control/src/pure.rs +++ b/packages/centipede_control/src/pure.rs @@ -4,7 +4,6 @@ use std::{ collections::{BTreeMap, HashMap, HashSet}, - mem, net::SocketAddr, ops::Deref, time::{Duration, SystemTime}, @@ -62,8 +61,8 @@ enum PeerState { /// Addresses on which we're listening for incoming packets, and will advertise to the peer. local_addrs: HashSet, - /// Addresses we've received heartbeats from, by the time they were last received. - tx_remote_addrs: BTreeMap>, + /// Addresses we've received messages from, by the time they were last received. + remote_addrs: BTreeMap>, /// The maximum time we're willing to wait between the peer's heartbeats. rx_max_heartbeat_interval: Duration, @@ -82,8 +81,11 @@ enum PeerState { /// Addresses on which we're listening for incoming packets, by the next time we should send a heartbeat. queued_heartbeats: BTreeMap>, - /// Addresses we've received heartbeats from, by the time they were last received. - tx_remote_addrs: BTreeMap>, + /// Addresses we've received messages from, by the time they were last received. + remote_addrs: BTreeMap>, + + /// Remote addresses to which we have send tunnels. + sending_to: HashSet, /// The maximum time we're willing to wait between the peer's heartbeats. rx_max_heartbeat_interval: Duration, @@ -163,17 +165,47 @@ impl Controller { *state_rx_local_addrs = local_addrs; } Some(PeerState::Connected { - local_addrs: old_local_addrs, + local_addrs: state_local_addrs, queued_heartbeats, .. }) => { - for addr in local_addrs { - // Check if the address is new. I.e., if it's already queued for a heartbeat. - // Otherwise, we immediately queue a heartbeat for it. - if !old_local_addrs.insert(addr) { - queued_heartbeats.entry(now).or_default().insert(addr); + // Iterate over all the new addresses. + for &addr in local_addrs.difference(&state_local_addrs) { + // Queue a heartbeat for each new address. + queued_heartbeats.entry(now).or_default().insert(addr); + } + + // Iterate over all the old addresses. + for &addr in state_local_addrs.difference(&local_addrs) { + // Remove the address from the router config if it's no longer in use. + let count = self + .recv_addrs + .get_mut(&addr) + .expect("local_addrs should only contain addresses in recv_addrs"); + *count -= 1; + if *count == 0 { + self.recv_addrs.remove(&addr); + + self.router_config.recv_addrs.remove(&addr); + self.router_config_changed = true; } + + // Remove the addresses from the heartbeat queue. + queued_heartbeats.values_mut().for_each(|addrs| { + addrs.remove(&addr); + }); + + // Remove the address from the send tunnel. + self.router_config + .send_tunnels + .get_mut(&public_key_to_peer_id(&public_key)) + .unwrap() + .links + .retain(|link| link.local != addr); + self.router_config_changed = true; } + + *state_local_addrs = local_addrs; } } } @@ -184,12 +216,12 @@ impl Controller { /// /// * `now` - the current time. /// * `public_key` - the public key of the peer. - /// * `tx_remote_addrs` - addresses to try to send initiation messages to. + /// * `remote_addrs` - addresses to try to send initiation messages to. pub fn initiate( &mut self, now: SystemTime, public_key: ed25519_dalek::VerifyingKey, - tx_remote_addrs: Vec, + remote_addrs: Vec, ) { // Get the old state, ensuring that we know the peer. let old_state = self @@ -214,7 +246,7 @@ impl Controller { }, ); for &local_addr in &local_addrs { - for &remote_addr in &tx_remote_addrs { + for &remote_addr in &remote_addrs { self.send_queue.push(OutgoingMessage { from: local_addr, to: remote_addr, @@ -231,7 +263,7 @@ impl Controller { ecdh_secret, local_addrs, // We shouldn't start sending to the remote addresses until we've received a response. - tx_remote_addrs: BTreeMap::new(), + remote_addrs: BTreeMap::new(), rx_max_heartbeat_interval, }, ); @@ -244,7 +276,41 @@ impl Controller { /// * `now` - the current time. /// * `public_key` - the public key of the peer. pub fn disconnect(&mut self, now: SystemTime, public_key: ed25519_dalek::VerifyingKey) { - todo!() + // Right now, we just clean up all references to the peer. + // In the future, we might want to also send a disconnect message to the peer. + + // Remove the control state. + let (local_addrs, _) = self + .peers + .remove(&public_key) + .expect("cannot disconnect from an already disconnected peer") + .forget_connection_and_destructure(); + + // Remove the send tunnel. + self.router_config + .send_tunnels + .remove(&public_key_to_peer_id(&public_key)); + + // Remove the recieve tunnel. + self.router_config + .recv_tunnels + .remove(&public_key_to_peer_id(&public_key)); + + // Remove any receive addresses that are no longer in use. + for addr in local_addrs { + let count = self + .recv_addrs + .get_mut(&addr) + .expect("local_addrs should only contain addresses in recv_addrs"); + + *count -= 1; + if *count == 0 { + self.recv_addrs.remove(&addr); + + self.router_config.recv_addrs.remove(&addr); + self.router_config_changed = true; + } + } } /// Handle an incoming message, transitioning the state machine. @@ -341,10 +407,12 @@ impl Controller { // Create the heartbeat queue, with the initial heartbeats queued. queued_heartbeats: [(now, local_addrs.clone())].into_iter().collect(), local_addrs, - // We have not yet received any heartbeats, but we know that we can send to the iniating address. - tx_remote_addrs: [(now, [incoming.from].into_iter().collect())] + // Because we received the initiate, we know the peer's address. + remote_addrs: [(now, [incoming.from].into_iter().collect())] .into_iter() .collect(), + // However, we shouldn't send packets to it until we've received an initiate acknowledgement and know that the peer knows the cipher. + sending_to: HashSet::new(), // Use the max heartbeat interval from the `listen` call. rx_max_heartbeat_interval, // Aim to beat three times in each interval, in case packets are dropped, but @@ -364,7 +432,7 @@ impl Controller { handshake_timestamp, ecdh_secret, local_addrs, - mut tx_remote_addrs, + mut remote_addrs, rx_max_heartbeat_interval, }, Content::InitiateAcknowledge { @@ -387,14 +455,39 @@ impl Controller { ); self.router_config_changed = true; - // The acknowledgement counts as a heartbeat, so we ensure its source is in the tx_remote_addrs. - for addr_set in tx_remote_addrs.values_mut() { + // Initialize sending tunnels to remote addresses. + // Note that all of the remote addresses in `remote_addrs` actually sent heartbeats, + // since any other message is either ignored or would have caused a state transition. + let sending_to = remote_addrs + .values() + .flatten() + .copied() + .collect::>(); + self.router_config.send_tunnels.insert( + public_key_to_peer_id(message.sender()), + centipede_router::config::SendTunnel { + cipher: cipher.clone(), + links: local_addrs + .iter() + .copied() + .flat_map(|local| { + sending_to + .iter() + .copied() + .map(move |remote| centipede_router::Link { local, remote }) + }) + .collect(), + }, + ); + self.router_config_changed = true; + + // The address that sent the initiate acknowledgement counts as a remote address. + // However, we shouldn't send packets to it until we know that the peer knows the cipher. + // Therefore, we don't add a link to the send tunnel yet. + for addr_set in remote_addrs.values_mut() { addr_set.remove(&incoming.from); } - tx_remote_addrs - .entry(now) - .or_default() - .insert(incoming.from); + remote_addrs.entry(now).or_default().insert(incoming.from); PeerState::Connected { handshake_timestamp, @@ -402,7 +495,8 @@ impl Controller { // Create the heartbeat queue, with the initial heartbeats queued. queued_heartbeats: [(now, local_addrs.clone())].into_iter().collect(), local_addrs, - tx_remote_addrs, + remote_addrs, + sending_to, rx_max_heartbeat_interval, // Aim to beat three times in each interval, in case packets are dropped, but // don't beat more than once per second. @@ -410,9 +504,42 @@ impl Controller { .min(Duration::from_secs(1)), } } + (old_state, Content::InitiateAcknowledge { .. }) => { + // We don't want to accept the incoming initiate acknowledgement, so we just put the old state back. + old_state + } - _ => todo!(), + (mut state, Content::Heartbeat) => { + // Delay expiration of the remote address. + if let PeerState::Connected { remote_addrs, .. } + | PeerState::Initiating { remote_addrs, .. } = &mut state + { + for addr_set in remote_addrs.values_mut() { + addr_set.remove(&incoming.from); + } + remote_addrs.entry(now).or_default().insert(incoming.from); + } + + // If we are connected, add a send link if we don't have one yet. + if let PeerState::Connected { sending_to, .. } = &mut state { + if sending_to.insert(incoming.from) { + self.router_config + .send_tunnels + .get_mut(&public_key_to_peer_id(message.sender())) + .unwrap() + .links + .insert(centipede_router::Link { + local: incoming.from, + remote: incoming.from, + }); + self.router_config_changed = true; + } + } + + state + } }; + self.peers.insert(*message.sender(), new_state); } @@ -427,7 +554,7 @@ impl Controller { if let PeerState::Connected { tx_heartbeat_interval, queued_heartbeats, - tx_remote_addrs, + remote_addrs, .. } = peer_state { @@ -439,16 +566,13 @@ impl Controller { { let (_, addrs) = queued_heartbeats.pop_first().unwrap(); + let message = Message::new(&self.private_key, *peer_key, Content::Heartbeat); for &local_addr in &addrs { - for &remote_addr in tx_remote_addrs.values().flatten() { + for &remote_addr in remote_addrs.values().flatten() { self.send_queue.push(OutgoingMessage { from: local_addr, to: remote_addr, - message: Message::new( - &self.private_key, - *peer_key, - Content::Heartbeat, - ), + message: message.clone(), }); } } @@ -461,38 +585,30 @@ impl Controller { .insert(now.checked_add(*tx_heartbeat_interval).unwrap(), to_requeue); } - // expire old tx_remote_addrs + // expire old remote_addrs if let PeerState::Initiating { - tx_remote_addrs, + remote_addrs, rx_max_heartbeat_interval, .. } | PeerState::Connected { - tx_remote_addrs, + remote_addrs, rx_max_heartbeat_interval, .. } = peer_state { - while let Some(first_entry) = tx_remote_addrs.first_entry() { - if now.duration_since(*first_entry.key()).unwrap() >= *rx_max_heartbeat_interval - { - let to_remove = first_entry.remove(); - for addr in to_remove { - let count = self.recv_addrs.get_mut(&addr).expect( - "tx_remote_addrs should only contain addresses in recv_addrs", - ); + let to_remove: HashSet<_> = remote_addrs + .extract_if(|t, _| *t + *rx_max_heartbeat_interval < now) + .flat_map(|(_, addrs)| addrs) + .collect(); - *count -= 1; - if *count == 0 { - self.recv_addrs.remove(&addr); - - self.router_config.recv_addrs.remove(&addr); - self.router_config_changed = true; - } - } - } else { - break; - } + if !to_remove.is_empty() { + self.router_config + .send_tunnels + .get_mut(&public_key_to_peer_id(peer_key)) + .unwrap() + .links + .retain(|link| !to_remove.contains(&link.remote)); } } } @@ -895,7 +1011,16 @@ mod tests { router_config .send_tunnels .get(&public_key_to_peer_id(&peer_key.verifying_key())) - .is_none(), + .is_some(), + "controller should have a recv tunnel after initiating and receiving an incoming initiate acknowledgement" + ); + assert!( + router_config + .send_tunnels + .get(&public_key_to_peer_id(&peer_key.verifying_key())) + .unwrap() + .links + .is_empty(), "controller cannot know where to send packets until receiving heartbeats" );