From 9acbef618556990a3d4d92dade3aa9589f1b6b7b Mon Sep 17 00:00:00 2001 From: Max Niederman Date: Fri, 19 Jan 2024 18:17:48 -0800 Subject: [PATCH] feat(worker): implement event loop --- Cargo.lock | 2 + packages/centipede_router/src/worker.rs | 1 - packages/centipede_worker/Cargo.toml | 4 +- packages/centipede_worker/src/lib.rs | 183 +++++++++++++++++++++-- packages/centipede_worker/src/sockets.rs | 21 ++- 5 files changed, 191 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6cc1e98..8636dc8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -71,8 +71,10 @@ dependencies = [ name = "centipede_worker" version = "0.1.0" dependencies = [ + "centipede_proto", "centipede_router", "hypertube", + "log", "mio", "socket2", "thiserror", diff --git a/packages/centipede_router/src/worker.rs b/packages/centipede_router/src/worker.rs index 4098b07..2acebe4 100644 --- a/packages/centipede_router/src/worker.rs +++ b/packages/centipede_router/src/worker.rs @@ -1,4 +1,3 @@ -use arc_swap::access::Access; use centipede_proto::{ marker::{auth, text}, PacketMessage, diff --git a/packages/centipede_worker/Cargo.toml b/packages/centipede_worker/Cargo.toml index 91634b6..3d88b23 100644 --- a/packages/centipede_worker/Cargo.toml +++ b/packages/centipede_worker/Cargo.toml @@ -6,8 +6,10 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +centipede_proto = { version = "0.1.0", path = "../centipede_proto" } centipede_router = { version = "0.1.0", path = "../centipede_router" } hypertube = "0.2.2" -mio = { version = "0.8.10", features = ["os-poll"] } +log = "0.4.20" +mio = { version = "0.8.10", features = ["os-poll", "os-ext"] } socket2 = { version = "0.5.5", features = ["all"] } thiserror = "1.0.56" diff --git a/packages/centipede_worker/src/lib.rs b/packages/centipede_worker/src/lib.rs index b9678d6..1cbc84f 100644 --- a/packages/centipede_worker/src/lib.rs +++ b/packages/centipede_worker/src/lib.rs @@ -1,7 +1,20 @@ -use std::sync::atomic::AtomicBool; +#![feature(maybe_uninit_uninit_array)] +#![feature(maybe_uninit_slice)] -use centipede_router::worker::WorkerHandle; +use std::{ + io, + mem::{self, MaybeUninit}, + os::fd::AsRawFd, + sync::atomic::{AtomicBool, Ordering}, + task::Poll, + time::Duration, +}; + +use centipede_proto::{MessageDiscriminant, PacketMessage}; +use centipede_router::worker::{ConfigChanged, WorkerHandle}; +use mio::unix::SourceFd; use sockets::Sockets; +use thiserror::Error; mod sockets; @@ -10,34 +23,184 @@ pub struct Worker<'r> { /// The underlying handle to the router. router_handle: WorkerHandle<'r>, + /// The TUN queue. + tun_queue: hypertube::Queue<'r, false>, + /// Sockets in use by the worker. sockets: Sockets, /// A [`mio::Poll`] instance to use for polling sockets. poll: mio::Poll, - - /// A buffer of events to handle. - events: mio::Events, } impl<'r> Worker<'r> { /// Create a new worker. - pub fn new(router_handle: WorkerHandle<'r>) -> Self { + pub fn new(router_handle: WorkerHandle<'r>, tun_queue: hypertube::Queue<'r, false>) -> Self { Self { router_handle, + tun_queue, sockets: Sockets::new(), poll: mio::Poll::new().unwrap(), - events: mio::Events::with_capacity(1024), } } /// Wait for at least one event and handle it. - pub fn wait_and_handle(&mut self) { - todo!() + /// + /// Mutably borrows an event buffer for scratch space, to avoid reallocating it. + pub fn wait_and_handle(&mut self, events: &mut mio::Events) -> Result<(), Error> { + if let Some(change) = self.router_handle.check_config() { + self.handle_config_change(change)?; + } + + events.clear(); + self.poll + .poll(events, Some(Duration::from_secs(1))) + .map_err(Error::Poll)?; + + for event in &*events { + match event.token() { + // FIXME: ensure one event source cannot starve the others + TUN_TOKEN => self.handle_tun_readable()?, + mio::Token(idx) => self.handle_socket_readable(idx)?, + } + } + + Ok(()) } /// Handle events repeatedly until a shutdown is requested. - pub fn handle_until(&mut self, shutdown: &AtomicBool) { + pub fn handle_until(&mut self, shutdown: &AtomicBool) -> Result<(), Error> { + let mut events_scratch = mio::Events::with_capacity(1024); + + loop { + if shutdown.load(Ordering::Relaxed) { + break; + } + + self.wait_and_handle(&mut events_scratch)?; + } + + Ok(()) + } + + /// Handle a configuration change. + fn handle_config_change(&mut self, change: ConfigChanged) -> Result<(), Error> { + self.sockets + .update(change.recv_addrs().chain(change.send_addrs()))?; + + for (i, _) in change.recv_addrs().enumerate() { + self.poll + .registry() + .register( + &mut SourceFd(&self.sockets.resolve_index(i).unwrap().as_raw_fd()), + mio::Token(i), + mio::Interest::READABLE, + ) + .unwrap(); + } + + Ok(()) + } + + /// Handle a readable event on the TUN device. + fn handle_tun_readable(&mut self) -> Result<(), Error> { + // TODO: optimize this + let mut read_buf = [0; PACKET_BUFFER_SIZE]; + let mut write_buf = vec![0; PACKET_BUFFER_SIZE]; + + while let Poll::Ready(n) = self.tun_queue.read(&mut read_buf).map_err(Error::ReadTun)? { + let buf = &mut read_buf[..n]; + + let mut obligations = self.router_handle.handle_outgoing(buf); + + while let Some(obligation) = obligations.resume(mem::take(&mut write_buf)) { + let socket = self + .sockets + .resolve_or_bind_local_addr(obligation.link().local)?; + + socket + .send_to( + obligation.message().as_buffer(), + &obligation.link().remote.into(), + ) + .map_err(Error::WriteSocket)?; + + write_buf = obligation.fulfill(); + } + } + + Ok(()) + } + + /// Handle a readable event on a socket. + fn handle_socket_readable(&mut self, idx: usize) -> Result<(), Error> { + let socket = self + .sockets + .resolve_index(idx) + .expect("invalid socket index"); + + let mut buf: [MaybeUninit; PACKET_BUFFER_SIZE] = MaybeUninit::uninit_array(); + + loop { + match socket.recv(&mut buf) { + Ok(n) => { + // SAFETY: we just read `n` bytes into the buffer. + let msg = unsafe { MaybeUninit::slice_assume_init_mut(&mut buf[..n]) }; + + match centipede_proto::discriminate(&*msg) { + Ok(MessageDiscriminant::Control) => todo!(), + Ok(MessageDiscriminant::Packet) => { + let packet = match PacketMessage::from_buffer(msg) { + Ok(packet) => packet, + Err(e) => { + log::warn!("failed to parse packet message: {}", e); + continue; + } + }; + + if let Some(obligation) = self.router_handle.handle_incoming(packet) { + // TODO: ensure writes complete + self.tun_queue + .write(obligation.packet()) + .map_err(Error::WriteTun)?; + } + } + Err(e) => { + log::warn!("failed to parse message: {}", e); + continue; + } + } + } + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => break, + Err(e) => return Err(Error::ReadSocket(e))?, + } + } + todo!() } } + +const TUN_TOKEN: mio::Token = mio::Token(usize::MAX); + +const PACKET_BUFFER_SIZE: usize = 65536; + +#[derive(Debug, Error)] +pub enum Error { + #[error("failed to poll for events")] + Poll(#[from] io::Error), + + #[error(transparent)] + Sockets(#[from] sockets::SocketsError), + + #[error("failed to read from TUN device")] + ReadTun(#[source] io::Error), + + #[error("failed to write to UDP socket")] + WriteSocket(#[source] io::Error), + + #[error("failed to read from UDP socket")] + ReadSocket(#[source] io::Error), + + #[error("failed to write to TUN device")] + WriteTun(#[source] io::Error), +} diff --git a/packages/centipede_worker/src/sockets.rs b/packages/centipede_worker/src/sockets.rs index adf2254..820b0dd 100644 --- a/packages/centipede_worker/src/sockets.rs +++ b/packages/centipede_worker/src/sockets.rs @@ -51,10 +51,15 @@ impl Sockets { .try_clone() .map_err(SocketsError::DuplicateSocketFd)? } - None => { - stats.opened += 1; - bind_socket(addr).map_err(SocketsError::BindSocket)? - } + None => match self.by_local_addr.get(&addr) { + Some(&index) => self.arena[index] + .try_clone() + .map_err(SocketsError::DuplicateSocketFd)?, + None => { + stats.opened += 1; + bind_socket(addr).map_err(SocketsError::BindSocket)? + } + }, }; let index = self.arena.len(); @@ -72,7 +77,7 @@ impl Sockets { pub fn resolve_or_bind_local_addr( &mut self, addr: SocketAddr, - ) -> Result<&Socket, SocketsError> { + ) -> Result<&mut Socket, SocketsError> { let index = match self.by_local_addr.get(&addr) { Some(&index) => index, None => { @@ -82,12 +87,12 @@ impl Sockets { index } }; - Ok(&self.arena[index]) + Ok(&mut self.arena[index]) } /// Resolve an index to a socket. - pub fn resolve_index(&mut self, index: usize) -> Option<&Socket> { - self.arena.get(index) + pub fn resolve_index(&mut self, index: usize) -> Option<&mut Socket> { + self.arena.get_mut(index) } }