Skip to content

Commit

Permalink
handle router reconfiguration in centipede_router
Browse files Browse the repository at this point in the history
  • Loading branch information
max-niederman committed Dec 30, 2023
1 parent f920832 commit 9899160
Show file tree
Hide file tree
Showing 7 changed files with 450 additions and 253 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

84 changes: 69 additions & 15 deletions packages/centipede_proto/src/packet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,15 @@ where
}
}

/// Interpret the message buffer as a byte slice.
pub fn as_ref(&self) -> Message<&'_ [u8], A, T> {
Message {
buffer: &self.buffer,
_auth: PhantomData,
_text: PhantomData,
}
}

/// Deconstruct the message into its underlying buffer.
pub fn to_buffer(self) -> B {
self.buffer
Expand Down Expand Up @@ -205,7 +214,10 @@ where
}

/// Encrypt the message, fill in its tag, and return its buffer.
pub fn encrypt_in_place(mut self, cipher: &ChaCha20Poly1305) -> B
pub fn encrypt_in_place(
mut self,
cipher: &ChaCha20Poly1305,
) -> Message<B, auth::Valid, text::Ciphertext>
where
B: DerefMut,
{
Expand All @@ -217,11 +229,20 @@ where

header[TAG_RANGE].copy_from_slice(&tag);

self.buffer
Message {
buffer: self.buffer,
_auth: PhantomData,
_text: PhantomData,
}
}
}

impl Message<Vec<u8>, auth::Valid, text::Plaintext> {
/// Measure the buffer size needed to hold a message with the given packet size.
pub const fn measure(packet_size: usize) -> usize {
PACKET_RANGE.start + packet_size
}

/// Create a new message with an empty packet in a scratch buffer using the given metadata.
pub fn new_in(sequence_number: u64, sender: [u8; 8], mut buffer: Vec<u8>) -> Self {
buffer.clear();
Expand All @@ -232,19 +253,33 @@ impl Message<Vec<u8>, auth::Valid, text::Plaintext> {
unsafe { Self::from_buffer_unchecked(buffer) }
}

/// Create a new message with an empty packet backed by a [`Vec<u8>`] using the given metadata.
pub fn new(sequence_number: u64, sender: [u8; 8]) -> Self {
/// Create a new message backed by a [`Vec<u8>`] with capacity for the given packet size using the given metadata.
pub fn new_with_capacity(sequence_number: u64, sender: [u8; 8], packet_size: usize) -> Self {
Self::new_in(
sequence_number,
sender,
Vec::with_capacity(PACKET_RANGE.start),
Vec::with_capacity(PACKET_RANGE.start + packet_size),
)
}

/// Overwrite the message's packet data from a reader.
pub fn overwrite_packet<R: io::Read>(&mut self, mut reader: R) -> io::Result<u64> {
/// Create a new message backed by a [`Vec<u8>`] using the given metadata.
pub fn new(sequence_number: u64, sender: [u8; 8]) -> Self {
Self::new_with_capacity(sequence_number, sender, 0)
}

/// Overwrite the message's packet data from an iterator.
pub fn overwrite_packet(&mut self, iter: impl IntoIterator<Item = u8>) {
let iter = iter.into_iter();

self.reserve_packet(iter.size_hint().0);

self.buffer.truncate(PACKET_RANGE.start);
io::copy(&mut reader, &mut self.buffer)
self.buffer.extend(iter);
}

/// Reserve space for a packet of the given size.
pub fn reserve_packet(&mut self, size: usize) {
self.buffer.reserve(PACKET_RANGE.start + size);
}

/// Allocate space for a packet of the given size and return a mutable reference to it.
Expand Down Expand Up @@ -280,6 +315,23 @@ const NONCE_RANGE: Range<usize> = 0..12;
const TAG_RANGE: Range<usize> = 16..32;
const PACKET_RANGE: RangeFrom<usize> = 32..;

/// A mutable reference to a [`Vec<u8>`] that can be used as a packet message buffer.
pub struct ByteVecMut<'v>(pub &'v mut Vec<u8>);

impl<'v> Deref for ByteVecMut<'v> {
type Target = [u8];

fn deref(&self) -> &Self::Target {
self.0
}
}

impl<'v> DerefMut for ByteVecMut<'v> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.0
}
}

/// An error representing a failure to parse a packet message.
#[derive(Debug, Error)]
pub enum ParseError {
Expand Down Expand Up @@ -318,7 +370,7 @@ mod tests {
assert_eq!(message.claimed_sender(), [1; 8]);
assert_eq!(message.claimed_packet_plaintext(), &[]);

message.overwrite_packet(PACKET).unwrap();
message.overwrite_packet(PACKET.iter().copied());

assert_eq!(message.claimed_packet_plaintext(), b"hello world");

Expand All @@ -333,12 +385,11 @@ mod tests {
fn encrypt_and_decrypt_in_place() {
let mut message = Message::new(1729, [42; 8]);

message.overwrite_packet(PACKET).unwrap();
message.overwrite_packet(PACKET.iter().copied());

let cipher = ChaCha20Poly1305::new((&[42; 32]).into());

let ciphertext_raw = message.encrypt_in_place(&cipher);
let ciphertext_message = Message::from_buffer(ciphertext_raw).unwrap();
let ciphertext_message = message.encrypt_in_place(&cipher);

assert_eq!(ciphertext_message.claimed_sequence_number(), 1729);
assert_eq!(ciphertext_message.claimed_sender(), [42; 8]);
Expand All @@ -353,10 +404,13 @@ mod tests {

#[test]
fn discriminate_packet() {
let message = Message::new(42, [1; 8]);
let plaintext = Message::new(42, [1; 8]);
let cipher = ChaCha20Poly1305::new((&[42; 32]).into());
let buffer = message.encrypt_in_place(&cipher);
let ciphertext = plaintext.encrypt_in_place(&cipher);

assert_eq!(discriminate(buffer).unwrap(), MessageDiscriminant::Packet);
assert_eq!(
discriminate(ciphertext.as_buffer().as_slice()).unwrap(),
MessageDiscriminant::Packet
);
}
}
1 change: 1 addition & 0 deletions packages/centipede_router/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
arc-swap = "1.6.0"
centipede_proto = { version = "0.1.0", path = "../centipede_proto" }
chacha20poly1305 = "0.10.1"
196 changes: 196 additions & 0 deletions packages/centipede_router/src/controller.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
use std::sync::{atomic::AtomicU64, Arc};

use chacha20poly1305::ChaCha20Poly1305;

use crate::{
packet_memory::PacketMemory, ConfiguredRouter, Link, PeerId, RecvTunnel, Router, SendTunnel,
};

pub struct Controller<'r> {
router: &'r Router,
}

impl<'r> Controller<'r> {
/// Create a new controller, given a router.
///
/// It is a logic error to create a controller for a router when there is already a controller for that router.
pub(crate) fn new(router: &'r Router) -> Self {
Self { router }
}

/// Reconfigure the router by applying a function to the current configured state.
fn reconfigure(&mut self, f: impl FnOnce(&ConfiguredRouter) -> ConfiguredRouter) {
let prev = self.router.state.load();
let next = f(prev.as_ref());

self.router.state.store(Arc::new(next));
}

/// Reconfigure the router by cloning the current configured state and mutating it.
fn reconfigure_mutate(&mut self, f: impl FnOnce(&mut ConfiguredRouter)) {
self.reconfigure(|prev| {
let mut next = prev.clone();
f(&mut next);
next
})
}

/// Insert or update a receive tunnel.
pub fn upsert_receive_tunnel(&mut self, sender_id: PeerId, cipher: ChaCha20Poly1305) {
self.reconfigure_mutate(move |state| {
if let Some(tunnel) = state.recv_tunnels.get_mut(&sender_id) {
tunnel.cipher = cipher;
} else {
state.recv_tunnels.insert(
sender_id,
RecvTunnel {
cipher,
memory: Arc::new(PacketMemory::default()),
},
);
}
increment_generation(state);
});
}

/// Delete a receive tunnel.
pub fn delete_receive_tunnel(&mut self, sender_id: PeerId) {
self.reconfigure_mutate(move |state| {
state.recv_tunnels.remove(&sender_id);
increment_generation(state);
});
}

/// Insert or update a send tunnel.
pub fn upsert_send_tunnel(
&mut self,
receiver_id: PeerId,
cipher: ChaCha20Poly1305,
links: Vec<Link>,
) {
self.reconfigure_mutate(move |state| {
if let Some(tunnel) = state.send_tunnels.get_mut(&receiver_id) {
tunnel.cipher = cipher;
tunnel.links = links;
} else {
state.send_tunnels.insert(
receiver_id,
SendTunnel {
links,
cipher,
next_sequence_number: Arc::new(AtomicU64::new(0)),
},
);
}
increment_generation(state);
});
}

/// Delete a send tunnel.
pub fn delete_send_tunnel(&mut self, receiver_id: PeerId) {
self.reconfigure_mutate(move |state| {
state.send_tunnels.remove(&receiver_id);
increment_generation(state);
});
}
}

fn increment_generation(state: &mut ConfiguredRouter) {
state.generation = state.generation.wrapping_add(1);
}

#[cfg(test)]
mod tests {
use std::{net::SocketAddr, sync::atomic::Ordering};

use chacha20poly1305::KeyInit;

use crate::{packet_memory::PacketRecollection, Router};

use super::*;

#[test]
fn construct() {
Router::new([0; 8], vec![]);
}

fn state<'c>(controller: &Controller) -> Arc<ConfiguredRouter> {
controller.router.state.load_full()
}

#[test]
fn crud_receive_tunnel() {
let mut router = Router::new([0; 8], vec![]);
let (mut controller, _) = router.handles(0);

controller.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into()));
controller.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into()));

assert!(state(&controller).recv_tunnels.contains_key(&[1; 8]));

controller.delete_receive_tunnel([1; 8]);

assert!(!state(&controller).recv_tunnels.contains_key(&[1; 8]));
}

#[test]
fn crud_send_tunnel() {
let mut router = Router::new([0; 8], vec![]);
let (mut controller, _) = router.handles(0);

let link = Link {
local: SocketAddr::from(([0, 0, 0, 0], 0)),
remote: SocketAddr::from(([0, 0, 0, 1], 1)),
};

controller.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into()), vec![link]);

assert_eq!(state(&controller).send_tunnels[&[1; 8]].links, vec![link]);

controller.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into()), vec![]);

assert_eq!(state(&controller).send_tunnels[&[1; 8]].links, vec![]);

controller.delete_send_tunnel([1; 8]);

assert!(!state(&controller).send_tunnels.contains_key(&[1; 8]));
}

#[test]
fn receive_updates_preserve_state() {
let mut router = Router::new([0; 8], vec![]);
let (mut controller, _) = router.handles(0);

controller.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into()));

state(&controller).recv_tunnels[&[1; 8]].memory.observe(0);

controller.upsert_receive_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into()));

assert_eq!(
state(&controller).recv_tunnels[&[1; 8]].memory.observe(0),
PacketRecollection::Seen
)
}

#[test]
fn send_updates_preserve_state() {
let mut router = Router::new([0; 8], vec![]);
let (mut controller, _) = router.handles(0);

controller.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[0; 32]).into()), vec![]);

state(&controller).send_tunnels[&[1; 8]]
.next_sequence_number
.store(1, Ordering::SeqCst);

controller.upsert_send_tunnel([1; 8], ChaCha20Poly1305::new((&[1; 32]).into()), vec![]);

assert_eq!(
state(&controller).send_tunnels[&[1; 8]]
.next_sequence_number
.load(Ordering::SeqCst),
1
);
}
}
Loading

0 comments on commit 9899160

Please sign in to comment.