From 430335b304b37c2d48bd693723f663bb013b6037 Mon Sep 17 00:00:00 2001 From: Max Niederman Date: Wed, 17 Jan 2024 22:26:40 -0800 Subject: [PATCH] feat(router/controller): add transactional semantics --- packages/centipede_router/src/controller.rs | 182 ++++++++++-------- packages/centipede_router/src/lib.rs | 6 +- .../centipede_router/tests/udp_threads.rs | 16 +- 3 files changed, 118 insertions(+), 86 deletions(-) diff --git a/packages/centipede_router/src/controller.rs b/packages/centipede_router/src/controller.rs index 1c138d2..33c7c2b 100644 --- a/packages/centipede_router/src/controller.rs +++ b/packages/centipede_router/src/controller.rs @@ -1,4 +1,7 @@ -use std::sync::{atomic::AtomicU64, Arc}; +use std::{ + net::SocketAddr, + sync::{atomic::AtomicU64, Arc}, +}; use chacha20poly1305::ChaCha20Poly1305; @@ -18,47 +21,52 @@ impl<'r> Controller<'r> { Self { router } } - /// Reconfigure the router by applying a function to the current configured state. - fn reconfigure(&mut self, f: impl FnOnce(&ConfiguredRouter) -> ConfiguredRouter) { - let prev = self.router.state.load(); - let next = f(prev.as_ref()); + /// Complete a transaction on the router. + /// + /// This function is the only way to mutate the router's state, + /// and there can only be one controller for a router at a time. + /// This guarantees that the state cannot be mutated concurrently. + pub fn transaction(&mut self, f: impl FnOnce(&mut Transaction) -> R) -> R { + let mut transaction = Transaction { + config: (*self.router.state.load_full()).clone(), + }; + let ret = f(&mut transaction); - self.router.state.store(Arc::new(next)); + transaction.config.generation = transaction.config.generation.wrapping_add(1); + self.router.state.store(Arc::new(transaction.config)); + + ret } +} - /// Reconfigure the router by cloning the current configured state and mutating it. - fn reconfigure_mutate(&mut self, f: impl FnOnce(&mut ConfiguredRouter)) { - self.reconfigure(|prev| { - let mut next = prev.clone(); - f(&mut next); - next - }) +pub struct Transaction { + config: ConfiguredRouter, +} + +impl Transaction { + /// Update the addresses on which to listen. + pub fn set_recv_addrs(&mut self, addrs: Vec) { + self.config.recv_addrs = addrs; } /// Insert or update a receive tunnel. pub fn upsert_receive_tunnel(&mut self, sender_id: PeerId, cipher: ChaCha20Poly1305) { - self.reconfigure_mutate(move |state| { - if let Some(tunnel) = state.recv_tunnels.get_mut(&sender_id) { - tunnel.cipher = cipher; - } else { - state.recv_tunnels.insert( - sender_id, - RecvTunnel { - cipher, - memory: Arc::new(PacketMemory::default()), - }, - ); - } - increment_generation(state); - }); + if let Some(tunnel) = self.config.recv_tunnels.get_mut(&sender_id) { + tunnel.cipher = cipher; + } else { + self.config.recv_tunnels.insert( + sender_id, + RecvTunnel { + cipher, + memory: Arc::new(PacketMemory::default()), + }, + ); + } } /// Delete a receive tunnel. pub fn delete_receive_tunnel(&mut self, sender_id: PeerId) { - self.reconfigure_mutate(move |state| { - state.recv_tunnels.remove(&sender_id); - increment_generation(state); - }); + self.config.recv_tunnels.remove(&sender_id); } /// Insert or update a send tunnel. @@ -68,37 +76,27 @@ impl<'r> Controller<'r> { cipher: ChaCha20Poly1305, links: Vec, ) { - self.reconfigure_mutate(move |state| { - if let Some(tunnel) = state.send_tunnels.get_mut(&receiver_id) { - tunnel.cipher = cipher; - tunnel.links = links; - } else { - state.send_tunnels.insert( - receiver_id, - SendTunnel { - links, - cipher, - next_sequence_number: Arc::new(AtomicU64::new(0)), - }, - ); - } - increment_generation(state); - }); + if let Some(tunnel) = self.config.send_tunnels.get_mut(&receiver_id) { + tunnel.cipher = cipher; + tunnel.links = links; + } else { + self.config.send_tunnels.insert( + receiver_id, + SendTunnel { + links, + cipher, + next_sequence_number: Arc::new(AtomicU64::new(0)), + }, + ); + } } /// Delete a send tunnel. pub fn delete_send_tunnel(&mut self, receiver_id: PeerId) { - self.reconfigure_mutate(move |state| { - state.send_tunnels.remove(&receiver_id); - increment_generation(state); - }); + self.config.send_tunnels.remove(&receiver_id); } } -fn increment_generation(state: &mut ConfiguredRouter) { - state.generation = state.generation.wrapping_add(1); -} - #[cfg(test)] mod tests { use std::{net::SocketAddr, sync::atomic::Ordering}; @@ -111,7 +109,7 @@ mod tests { #[test] fn construct() { - Router::new([0; 8]); + Router::new([0; 8], vec![]); } fn state(controller: &Controller) -> Arc { @@ -119,23 +117,42 @@ mod tests { } #[test] - fn crud_receive_tunnel() { - let mut router = Router::new([0; 8]); + fn set_recv_addrs() { + let mut router = Router::new([0; 8], vec![]); let (mut controller, _) = router.handles(0); - controller.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into())); - controller.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into())); + assert_eq!(state(&controller).recv_addrs, vec![]); + let prev_generation = state(&controller).generation; - assert!(state(&controller).recv_tunnels.contains_key(&[1; 8])); + controller.transaction(|trans| { + trans.set_recv_addrs(vec![SocketAddr::from(([0, 0, 0, 0], 0))]); + }); - controller.delete_receive_tunnel([1; 8]); + assert_eq!( + state(&controller).recv_addrs, + vec![SocketAddr::from(([0, 0, 0, 0], 0))] + ); + assert_ne!(state(&controller).generation, prev_generation); + } + #[test] + fn crud_receive_tunnel() { + let mut router = Router::new([0; 8], vec![]); + let (mut controller, _) = router.handles(0); + + controller.transaction(|trans| { + trans.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into())); + trans.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into())); + }); + assert!(state(&controller).recv_tunnels.contains_key(&[1; 8])); + + controller.transaction(|trans| trans.delete_receive_tunnel([1; 8])); assert!(!state(&controller).recv_tunnels.contains_key(&[1; 8])); } #[test] fn crud_send_tunnel() { - let mut router = Router::new([0; 8]); + let mut router = Router::new([0; 8], vec![]); let (mut controller, _) = router.handles(0); let link = Link { @@ -143,29 +160,36 @@ mod tests { remote: SocketAddr::from(([0, 0, 0, 1], 1)), }; - controller.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into()), vec![link]); - + controller.transaction(|trans| { + trans.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into()), vec![link]) + }); assert_eq!(state(&controller).send_tunnels[&[1; 8]].links, vec![link]); - controller.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into()), vec![]); - - assert_eq!(state(&controller).send_tunnels[&[1; 8]].links, vec![]); - - controller.delete_send_tunnel([1; 8]); + controller.transaction(|trans| { + trans.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into()), vec![link]); + }); + assert_eq!(state(&controller).send_tunnels[&[1; 8]].links, vec![link]); - assert!(!state(&controller).send_tunnels.contains_key(&[1; 8])); + controller.transaction(|trans| { + trans.delete_send_tunnel([1; 8]); + }); + assert!(state(&controller).send_tunnels.is_empty()); } #[test] fn receive_updates_preserve_state() { - let mut router = Router::new([0; 8]); + let mut router = Router::new([0; 8], vec![]); let (mut controller, _) = router.handles(0); - controller.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into())); + controller.transaction(|trans| { + trans.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into())); + }); state(&controller).recv_tunnels[&[1; 8]].memory.observe(0); - controller.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into())); + controller.transaction(|trans| { + trans.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into())); + }); assert_eq!( state(&controller).recv_tunnels[&[1; 8]].memory.observe(0), @@ -175,16 +199,20 @@ mod tests { #[test] fn send_updates_preserve_state() { - let mut router = Router::new([0; 8]); + let mut router = Router::new([0; 8], vec![]); let (mut controller, _) = router.handles(0); - controller.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into()), vec![]); + controller.transaction(|trans| { + trans.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into()), vec![]); + }); state(&controller).send_tunnels[&[1; 8]] .next_sequence_number .store(1, Ordering::SeqCst); - controller.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into()), vec![]); + controller.transaction(|trans| { + trans.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into()), vec![]) + }); assert_eq!( state(&controller).send_tunnels[&[1; 8]] diff --git a/packages/centipede_router/src/lib.rs b/packages/centipede_router/src/lib.rs index 4b95bca..f7a928a 100644 --- a/packages/centipede_router/src/lib.rs +++ b/packages/centipede_router/src/lib.rs @@ -31,6 +31,9 @@ struct ConfiguredRouter { /// Our local peer identifier. local_id: PeerId, + /// Addresses on which to listen for incoming packets. + recv_addrs: Vec, + /// Set of receiving tunnels, by sender identifier. recv_tunnels: HashMap, @@ -75,11 +78,12 @@ pub type PeerId = [u8; 8]; impl Router { /// Create a new router. - pub fn new(peer_id: PeerId) -> Self { + pub fn new(peer_id: PeerId, recv_addrs: Vec) -> Self { Self { state: ArcSwap::from_pointee(ConfiguredRouter { generation: 0, local_id: peer_id, + recv_addrs, recv_tunnels: HashMap::new(), send_tunnels: HashMap::new(), }), diff --git a/packages/centipede_router/tests/udp_threads.rs b/packages/centipede_router/tests/udp_threads.rs index 46516ec..9847e97 100644 --- a/packages/centipede_router/tests/udp_threads.rs +++ b/packages/centipede_router/tests/udp_threads.rs @@ -53,12 +53,10 @@ fn half_duplex_single_message() { id: [0; 8], addr_count: 1, entrypoint: Box::new(|mut ctx: PeerCtx<'_>| { - ctx.controller.upsert_send_tunnel( - [1; 8], - dummy_cipher(), - ctx.possible_links_to([1; 8]), - ); - + let links = ctx.possible_links_to([1; 8]); + ctx.controller.transaction(move |trans| { + trans.upsert_send_tunnel([1; 8], dummy_cipher(), links.clone()) + }); let mut obligations = ctx.worker.handle_outgoing(PACKET); let mut scratch = Vec::new(); @@ -82,7 +80,8 @@ fn half_duplex_single_message() { id: [1; 8], addr_count: 1, entrypoint: Box::new(|mut ctx: PeerCtx<'_>| { - ctx.controller.upsert_receive_tunnel([0; 8], dummy_cipher()); + ctx.controller + .transaction(|trans| trans.upsert_receive_tunnel([0; 8], dummy_cipher())); let packets = ctx.receive_block(); assert_eq!(packets.len(), 1, "received wrong number of packets"); @@ -148,7 +147,8 @@ impl<'r> PeerCtx<'r> { let sockets = sockets.remove(&spec.id).unwrap(); s.spawn(move || { - let mut router = Router::new(spec.id); + let mut router = + Router::new(spec.id, peer_addrs.get(&spec.id).unwrap().clone()); let (controller, workers) = router.handles(1); let worker = workers.into_iter().next().unwrap();