Skip to content

Commit

Permalink
apply CR
Browse files Browse the repository at this point in the history
  • Loading branch information
dignifiedquire committed Jan 8, 2025
1 parent 0dbbdb5 commit 8f70daf
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 29 deletions.
60 changes: 48 additions & 12 deletions iroh-relay/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::{

use anyhow::{bail, Context as _, Result};
use clap::Parser;
use futures_lite::FutureExt;
use iroh_base::NodeId;
use iroh_relay::{
defaults::{
Expand Down Expand Up @@ -176,17 +177,15 @@ struct Config {
access: AccessConfig,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
enum AccessConfig {
/// Allows everyone
#[serde(rename = "everyone")]
#[default]
Everyone,
/// Allows only these nodes.
#[serde(rename = "allowlist")]
Allowlist(Vec<NodeId>),
/// Allows everyone, except these nodes.
#[serde(rename = "denylist")]
Denylist(Vec<NodeId>),
}

Expand All @@ -195,21 +194,31 @@ impl From<AccessConfig> for iroh_relay::server::AccessConfig {
match cfg {
AccessConfig::Everyone => iroh_relay::server::AccessConfig::Everyone,
AccessConfig::Allowlist(allow_list) => {
let allow_list = Arc::new(allow_list);
iroh_relay::server::AccessConfig::Restricted(Box::new(move |node_id| {
if allow_list.contains(&node_id) {
iroh_relay::server::Access::Allow
} else {
iroh_relay::server::Access::Deny
let allow_list = allow_list.clone();
async move {
if allow_list.contains(&node_id) {
iroh_relay::server::Access::Allow
} else {
iroh_relay::server::Access::Deny
}
}
.boxed()
}))
}
AccessConfig::Denylist(deny_list) => {
let deny_list = Arc::new(deny_list);
iroh_relay::server::AccessConfig::Restricted(Box::new(move |node_id| {
if deny_list.contains(&node_id) {
iroh_relay::server::Access::Deny
} else {
iroh_relay::server::Access::Allow
let deny_list = deny_list.clone();
async move {
if deny_list.contains(&node_id) {
iroh_relay::server::Access::Deny
} else {
iroh_relay::server::Access::Allow
}
}
.boxed()
}))
}
}
Expand Down Expand Up @@ -686,6 +695,9 @@ mod metrics {
mod tests {
use std::num::NonZeroU32;

use iroh_base::SecretKey;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use testresult::TestResult;

use super::*;
Expand Down Expand Up @@ -723,4 +735,28 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_access_config() -> TestResult {
let config = "
access = \"everyone\"
";
let config = Config::from_str(config)?;
assert_eq!(config.access, AccessConfig::Everyone);

let mut rng = ChaCha8Rng::seed_from_u64(0);
let node_id = SecretKey::generate(&mut rng).public();

let config = format!(
"
access.allowlist = [
\"{node_id}\",
]
"
);
let config = Config::from_str(dbg!(&config))?;
assert_eq!(config.access, AccessConfig::Allowlist(vec![node_id]));

Ok(())
}
}
72 changes: 56 additions & 16 deletions iroh-relay/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{fmt, future::Future, net::SocketAddr, num::NonZeroU32, pin::Pin, sync:

use anyhow::{anyhow, bail, Context, Result};
use derive_more::Debug;
use futures_lite::StreamExt;
use futures_lite::{future::Boxed, StreamExt};
use http::{
response::Builder as ResponseBuilder, HeaderMap, Method, Request, Response, StatusCode,
};
Expand Down Expand Up @@ -132,16 +132,17 @@ pub enum AccessConfig {
Everyone,
/// Only nodes for which the function returns `Access::Allow`.
#[debug("restricted")]
Restricted(Box<dyn Fn(NodeId) -> Access + Send + Sync + 'static>),
Restricted(Box<dyn Fn(NodeId) -> Boxed<Access> + Send + Sync + 'static>),
}

impl AccessConfig {
/// Is this node allowed?
pub fn is_allowed(&self, node: NodeId) -> bool {
pub async fn is_allowed(&self, node: NodeId) -> bool {
match self {
Self::Everyone => true,
Self::Restricted(check) => {
matches!(check(node), Access::Allow)
let res = check(node).await;
matches!(res, Access::Allow)
}
}
}
Expand Down Expand Up @@ -807,6 +808,7 @@ mod tests {
use std::{net::Ipv4Addr, time::Duration};

use bytes::Bytes;
use futures_lite::FutureExt;
use futures_util::SinkExt;
use http::header::UPGRADE;
use iroh_base::{NodeId, SecretKey};
Expand Down Expand Up @@ -1145,7 +1147,7 @@ mod tests {
}

#[tokio::test]
async fn test_relay_access_reject() {
async fn test_relay_access_control() -> Result<()> {
let _guard = iroh_test::logging::setup();

let a_secret_key = SecretKey::generate(rand::thread_rng());
Expand All @@ -1158,13 +1160,16 @@ mod tests {
limits: Default::default(),
key_cache_capacity: Some(1024),
access: AccessConfig::Restricted(Box::new(move |node_id| {
info!("checking {}", node_id);
// reject node a
if node_id == a_key {
Access::Deny
} else {
Access::Allow
async move {
info!("checking {}", node_id);
// reject node a
if node_id == a_key {
Access::Deny
} else {
Access::Allow
}
}
.boxed()
})),
}),
quic: None,
Expand All @@ -1174,14 +1179,13 @@ mod tests {
.await
.unwrap();
let relay_url = format!("http://{}", server.http_addr().unwrap());
let relay_url: RelayUrl = relay_url.parse().unwrap();
let relay_url: RelayUrl = relay_url.parse()?;

// set up client a
let resolver = crate::dns::default_resolver().clone();
let mut client_a = ClientBuilder::new(relay_url.clone(), a_secret_key, resolver)
.connect()
.await
.unwrap();
.await?;

// the next message should be the rejection of the connection
tokio::time::timeout(Duration::from_millis(500), async move {
Expand All @@ -1194,7 +1198,43 @@ mod tests {
}
}
})
.await
.unwrap();
.await?;

// test that another client has access

// set up client b
let b_secret_key = SecretKey::generate(rand::thread_rng());
let b_key = b_secret_key.public();

let resolver = crate::dns::default_resolver().clone();
let mut client_b = ClientBuilder::new(relay_url.clone(), b_secret_key, resolver)
.connect()
.await?;

// set up client c
let c_secret_key = SecretKey::generate(rand::thread_rng());
let c_key = c_secret_key.public();

let resolver = crate::dns::default_resolver().clone();
let mut client_c = ClientBuilder::new(relay_url.clone(), c_secret_key, resolver)
.connect()
.await?;

// send message from b to c
let msg = Bytes::from("hello, c");
let res = try_send_recv(&mut client_b, &mut client_c, c_key, msg.clone()).await?;

if let ReceivedMessage::ReceivedPacket {
remote_node_id,
data,
} = res
{
assert_eq!(b_key, remote_node_id);
assert_eq!(msg, data);
} else {
panic!("client_c received unexpected message {res:?}");
}

Ok(())
}
}
2 changes: 1 addition & 1 deletion iroh-relay/src/server/http_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ impl Inner {
.context("unable to receive client information")?;

trace!("accept: checking access: {:?}", self.access);
if !self.access.is_allowed(client_key) {
if !self.access.is_allowed(client_key).await {
io.send(Frame::Health {
problem: Bytes::from_static(b"not authenticated"),
})
Expand Down

0 comments on commit 8f70daf

Please sign in to comment.