Skip to content

Commit

Permalink
pre-lookup addresses and use static resolver
Browse files Browse the repository at this point in the history
  • Loading branch information
eaypek-tfh committed Jan 12, 2025
1 parent 5afaecc commit 05d4b8b
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 27 deletions.
2 changes: 1 addition & 1 deletion iris-mpc-store/src/s3_importer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ const SINGLE_ELEMENT_SIZE: usize = IRIS_CODE_LENGTH * mem::size_of::<u16>() * 2
+ MASK_CODE_LENGTH * mem::size_of::<u16>() * 2
+ mem::size_of::<u32>(); // 75 KB

const MAX_RANGE_SIZE: usize = 100; // Download chunks in sub-chunks of 100 elements = 7.5 MB
const MAX_RANGE_SIZE: usize = 200; // Download chunks in sub-chunks of 100 elements = 7.5 MB

pub struct S3StoredIris {
#[allow(dead_code)]
Expand Down
88 changes: 62 additions & 26 deletions iris-mpc/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use aws_sdk_s3::{config::Builder as S3ConfigBuilder, Client as S3Client};
use aws_sdk_sns::{types::MessageAttributeValue, Client as SNSClient};
use aws_sdk_sqs::{config::Region, Client};
use aws_smithy_experimental::hyper_1_0::{CryptoMode, HyperClientBuilder};
use aws_smithy_runtime_api::client::dns::{DnsFuture, ResolveDns, ResolveDnsError};
use aws_smithy_runtime_api::client::dns::{DnsFuture, ResolveDns};
use axum::{response::IntoResponse, routing::get, Router};
use clap::Parser;
use eyre::{eyre, Context};
Expand Down Expand Up @@ -60,7 +60,7 @@ use std::{
net::IpAddr,
panic,
sync::{
atomic::{AtomicBool, Ordering},
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, LazyLock, Mutex,
},
time,
Expand Down Expand Up @@ -672,66 +672,101 @@ async fn main() -> eyre::Result<()> {
}

struct StaticResolver {
resolver: Arc<TokioAsyncResolver>,
ips: Arc<Vec<IpAddr>>,
current: Arc<AtomicUsize>,
}

impl StaticResolver {
fn new() -> Self {
let resolver =
TokioAsyncResolver::tokio(ResolverConfig::default(), ResolverOpts::default());
StaticResolver {
resolver: Arc::new(resolver),
fn new(ips: Vec<IpAddr>) -> Self {
assert!(
!ips.is_empty(),
"StaticResolver requires at least one IP address."
);
Self {
ips: Arc::new(ips),
current: Arc::new(AtomicUsize::new(0)),
}
}
}

impl Debug for StaticResolver {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
// Load the current index atomically
let current_index = self.current.load(Ordering::SeqCst);
f.debug_struct("StaticResolver")
.field("resolver", &"Arc<Resolver>")
.field("ips", &self.ips)
.field("current_index", &current_index)
.finish()
}
}

impl Clone for StaticResolver {
fn clone(&self) -> Self {
StaticResolver {
resolver: Arc::clone(&self.resolver),
Self {
ips: Arc::clone(&self.ips),
current: Arc::clone(&self.current),
}
}
}

impl ResolveDns for StaticResolver {
fn resolve_dns<'a>(&'a self, _name: &'a str) -> DnsFuture<'a> {
let resolver = Arc::clone(&self.resolver);
let hostname = _name.to_string();
// Clone the Arc to ensure the IP list remains accessible within the async block
let ips = Arc::clone(&self.ips);
let current_index = self.current.fetch_add(1, Ordering::SeqCst);
let index = current_index % ips.len();
let selected_ip = ips[index];

// Create the async block that performs DNS resolution
// Create the async block that returns the selected IP
let future = async move {
match resolver.lookup_ip(&hostname).await {
Ok(lookup_result) => {
let ips: Vec<IpAddr> = lookup_result.iter().collect();
tracing::info!("Resolved host {} to {:?}", hostname, ips);
Ok(ips)
}
Err(e) => Err(ResolveDnsError::new(format!(
"Failed to resolve {}: {}",
hostname, e
))),
}
tracing::info!("Returning IP {:?} for host {}", selected_ip, _name);
Ok(vec![selected_ip])
};

// Wrap the future into DnsFuture
DnsFuture::new(Box::pin(future))
}
}

async fn resolve_export_bucket_ips(host: String) -> eyre::Result<Vec<IpAddr>> {
let mut all_ips = vec![];
let mut resolver_opts = ResolverOpts::default();
resolver_opts.positive_max_ttl = Some(time::Duration::from_millis(10));
let resolver = TokioAsyncResolver::tokio(ResolverConfig::default(), resolver_opts);
loop {
// Check if we've collected enough unique IPs
if all_ips.len() >= 10 {
break;
}
match resolver.lookup_ip(&host).await {
Ok(lookup_result) => {
let ips: Vec<IpAddr> = lookup_result.iter().collect();
tracing::info!("Resolved {:?} for host {}", ips, host);
for ip in lookup_result.iter() {
// Attempt to insert the IP into the HashSet
if !all_ips.contains(&ip) {
all_ips.push(ip);
tracing::info!("Added IP {:?} for host {}", ip, host);
}
}
}
Err(e) => {
tracing::error!("Failed to resolve host {}: {}", host, e);
}
}
tokio::time::sleep(Duration::from_millis(20)).await;
}
Ok(all_ips)
}

async fn server_main(config: Config) -> eyre::Result<()> {
let shutdown_handler = Arc::new(ShutdownHandler::new(
config.shutdown_last_results_sync_timeout_secs,
));
shutdown_handler.wait_for_shutdown_signal().await;

let shares_bucket_host = format!("{}.s3.{}.amazonaws.com", config.shares_bucket_name, REGION);
let shares_bucket_ips = resolve_export_bucket_ips(shares_bucket_host);
// Load batch_size config
*CURRENT_BATCH_SIZE.lock().unwrap() = config.max_batch_size;
let max_sync_lookback: usize = config.max_batch_size * 2;
Expand All @@ -753,7 +788,8 @@ async fn server_main(config: Config) -> eyre::Result<()> {
// Increase S3 retries to 5
// let resolver = Resolver::new(ResolverConfig::default(),
// ResolverOpts::default()).unwrap();
let static_resolver = StaticResolver::new();

let static_resolver = StaticResolver::new(shares_bucket_ips.await?);
let client = HyperClientBuilder::new()
.crypto_mode(CryptoMode::Ring)
.build_with_resolver(static_resolver);
Expand Down

0 comments on commit 05d4b8b

Please sign in to comment.