add timeout config and avoid wrapping s3 record
eaypek-tfh committed Jan 23, 2025
1 parent 83a7386 commit 57edaa7
Showing 4 changed files with 34 additions and 59 deletions.
deploy/stage/common-values-iris-mpc.yaml (2 changes: 1 addition & 1 deletion)
```diff
@@ -1,4 +1,4 @@
-image: "ghcr.io/worldcoin/iris-mpc:f6b975ac741268aa143c308022ce2a8011c8b95b"
+image: "ghcr.io/worldcoin/iris-mpc:dcc0ab9b31e3e0bcb3207b8a597c3e725336c6db"
 
 environment: stage
 replicaCount: 1
```
iris-mpc-store/src/lib.rs (2 changes: 1 addition & 1 deletion)
```diff
@@ -2,7 +2,6 @@
 
 mod s3_importer;
 
-use crate::s3_importer::S3StoredIris;
 use bytemuck::cast_slice;
 use eyre::{eyre, Result};
 use futures::{
@@ -17,6 +16,7 @@ use iris_mpc_common::{
 use rand::{rngs::StdRng, Rng, SeedableRng};
 pub use s3_importer::{
     fetch_and_parse_chunks, fetch_to_memory, last_snapshot_timestamp, ObjectStore, S3Store,
+    S3StoredIris,
 };
 use sqlx::{
     migrate::Migrator, postgres::PgPoolOptions, Executor, PgPool, Postgres, Row, Transaction,
```
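This change moves `S3StoredIris` from a private `use` to the crate's public re-export list, so consumers such as `server.rs` below can import it from the crate root. A minimal sketch of the `pub use` pattern, with hypothetical module and type names:

```rust
// Hypothetical crate layout illustrating the re-export above.
mod importer {
    pub struct Record {
        pub id: u64,
    }
}

// A plain `use importer::Record;` would only bring the name into this crate;
// `pub use` also exposes it to downstream crates as `mycrate::Record`.
pub use importer::Record;
```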
iris-mpc-store/src/s3_importer.rs (20 changes: 9 additions & 11 deletions)
```diff
@@ -1,9 +1,8 @@
-use crate::StoredIris;
 use async_trait::async_trait;
 use aws_sdk_s3::{primitives::ByteStream, Client};
 use eyre::eyre;
 use iris_mpc_common::{IRIS_CODE_LENGTH, MASK_CODE_LENGTH};
-use std::{mem, sync::Arc};
+use std::{collections::VecDeque, mem, sync::Arc};
 use tokio::{io::AsyncReadExt, sync::mpsc::Sender, task};
 
 const SINGLE_ELEMENT_SIZE: usize = IRIS_CODE_LENGTH * mem::size_of::<u16>() * 2
@@ -227,16 +226,16 @@ pub async fn fetch_and_parse_chunks(
     concurrency: usize,
     prefix_name: String,
     last_snapshot_details: LastSnapshotDetails,
-    tx: Sender<StoredIris>,
+    tx: Sender<S3StoredIris>,
 ) -> eyre::Result<()> {
     tracing::info!("Generating chunk files using: {:?}", last_snapshot_details);
     let range_size = if last_snapshot_details.chunk_size as usize > MAX_RANGE_SIZE {
         MAX_RANGE_SIZE
     } else {
         last_snapshot_details.chunk_size as usize
     };
-    let mut handles: Vec<task::JoinHandle<Result<(), eyre::Error>>> = Vec::new();
-    let mut active_handles = 0;
+    let mut handles: VecDeque<task::JoinHandle<Result<(), eyre::Error>>> =
+        VecDeque::with_capacity(concurrency);
 
     for chunk in (1..=last_snapshot_details.last_serial_id).step_by(range_size) {
         let chunk_id =
@@ -245,13 +244,12 @@ pub async fn fetch_and_parse_chunks(
         let offset_within_chunk = (chunk - chunk_id) as usize;
 
         // Wait if we've hit the concurrency limit
-        if active_handles >= concurrency {
-            let handle = handles.remove(0);
+        if handles.len() >= concurrency {
+            let handle = handles.pop_front().expect("No s3 import handles to pop");
             handle.await??;
-            active_handles -= 1;
         }
 
-        handles.push(task::spawn({
+        handles.push_back(task::spawn({
             let store = Arc::clone(&store);
             let mut slice = vec![0u8; SINGLE_ELEMENT_SIZE];
             let tx = tx.clone();
@@ -271,7 +269,7 @@
                 match result.read_exact(&mut slice).await {
                     Ok(_) => {
                         let iris = S3StoredIris::from_bytes(&slice)?;
-                        tx.send(StoredIris::S3(iris)).await?;
+                        tx.send(iris).await?;
                     }
                     Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
                     Err(e) => return Err(e.into()),
@@ -469,7 +467,7 @@ mod tests {
             last_serial_id: MOCK_ENTRIES as i64,
             chunk_size: MOCK_CHUNK_SIZE as i64,
         };
-        let (tx, mut rx) = mpsc::channel::<StoredIris>(1024);
+        let (tx, mut rx) = mpsc::channel::<S3StoredIris>(1024);
         let store_arc = Arc::new(store);
         let _res =
             fetch_and_parse_chunks(store_arc, 1, "out".to_string(), last_snapshot_details, tx)
```
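The refactor replaces the `Vec` plus a manually maintained `active_handles` counter with a `VecDeque`: once `concurrency` downloads are in flight, the oldest task is awaited (FIFO), and the deque's length itself is the in-flight count. A self-contained sketch of that throttling pattern, with placeholder work standing in for the real per-chunk S3 reads:

```rust
use std::collections::VecDeque;
use std::time::Duration;

use tokio::task;

// Cap the number of concurrently running tasks by awaiting the oldest
// spawned task (FIFO) whenever the deque is full.
async fn run_jobs_with_cap(concurrency: usize) -> eyre::Result<()> {
    let mut handles: VecDeque<task::JoinHandle<Result<(), eyre::Error>>> =
        VecDeque::with_capacity(concurrency);

    for job in 0..100u64 {
        // The deque length doubles as the in-flight counter,
        // replacing the old `active_handles` variable.
        if handles.len() >= concurrency {
            let handle = handles.pop_front().expect("no handles to pop");
            handle.await??; // surfaces both panics (JoinError) and task errors
        }
        handles.push_back(task::spawn(async move {
            // Placeholder for the real per-chunk download + parse.
            tokio::time::sleep(Duration::from_millis(job % 7)).await;
            Ok(())
        }));
    }

    // Drain whatever is still running so no error is silently dropped.
    while let Some(handle) = handles.pop_front() {
        handle.await??;
    }
    Ok(())
}
```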
iris-mpc/src/bin/server.rs (69 changes: 23 additions & 46 deletions)
```diff
@@ -1,6 +1,6 @@
 #![allow(clippy::needless_range_loop)]
 
-use aws_config::retry::RetryConfig;
+use aws_config::{retry::RetryConfig, timeout::TimeoutConfig};
 use aws_sdk_s3::{
     config::{Builder as S3ConfigBuilder, StalledStreamProtectionConfig},
     Client as S3Client,
@@ -49,7 +49,7 @@ use iris_mpc_gpu::{
     },
 };
 use iris_mpc_store::{
-    fetch_and_parse_chunks, fetch_to_memory, last_snapshot_timestamp, S3Store, Store, StoredIris,
+    fetch_and_parse_chunks, fetch_to_memory, last_snapshot_timestamp, S3Store, S3StoredIris, Store,
     StoredIrisRef,
 };
 use metrics_exporter_statsd::StatsdBuilder;
```
```diff
@@ -788,6 +788,9 @@ async fn server_main(config: Config) -> eyre::Result<()> {
 
     // Increase S3 retries to 5
     let retry_config = RetryConfig::standard().with_max_attempts(5);
+    let timeout_config = TimeoutConfig::builder()
+        .connect_timeout(Duration::from_secs(10))
+        .build();
 
     let s3_config = S3ConfigBuilder::from(&shared_config)
         .retry_config(retry_config.clone())
@@ -798,6 +801,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
                 // disable stalled stream protection to avoid panics during s3 import
                 .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
                 .retry_config(retry_config)
+                .timeout_config(timeout_config)
                 .build()
         }
         false => {
@@ -811,6 +815,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
                 // disable stalled stream protection to avoid panics during s3 import
                 .stalled_stream_protection(StalledStreamProtectionConfig::disabled())
                 .retry_config(retry_config)
+                .timeout_config(timeout_config)
                 .http_client(http_client)
                 .build()
         }
```
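`TimeoutConfig` is the `aws_config` re-export of the smithy timeout settings; only `connect_timeout` is set here, so it bounds TCP connection establishment at 10 seconds without capping long-running downloads. A minimal sketch of wiring the retry and timeout settings into an S3 client, assuming a current `aws-config`/`aws-sdk-s3` (the `BehaviorVersion` call is a version-dependent detail):

```rust
use std::time::Duration;

use aws_config::{retry::RetryConfig, timeout::TimeoutConfig, BehaviorVersion};
use aws_sdk_s3::{config::Builder as S3ConfigBuilder, Client as S3Client};

async fn make_s3_client() -> S3Client {
    let shared_config = aws_config::load_defaults(BehaviorVersion::latest()).await;

    // Retry up to 5 times, and fail fast if a TCP connection
    // cannot be established within 10 seconds.
    let retry_config = RetryConfig::standard().with_max_attempts(5);
    let timeout_config = TimeoutConfig::builder()
        .connect_timeout(Duration::from_secs(10))
        .build();

    let s3_config = S3ConfigBuilder::from(&shared_config)
        .retry_config(retry_config)
        .timeout_config(timeout_config)
        .build();

    S3Client::from_conf(s3_config)
}
```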
```diff
@@ -1203,7 +1208,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
             let s3_arc = Arc::new(s3_store);
 
             let (tx, mut rx) =
-                mpsc::channel::<StoredIris>(config.load_chunks_buffer_size);
+                mpsc::channel::<S3StoredIris>(config.load_chunks_buffer_size);
 
             tokio::spawn(async move {
                 fetch_and_parse_chunks(
@@ -1250,52 +1255,28 @@ async fn server_main(config: Config) -> eyre::Result<()> {
             let mut record_counter = 0;
             let mut all_serial_ids: HashSet<i64> =
                 HashSet::from_iter(1..=(store_len as i64));
-            let mut serial_ids_from_db: HashSet<i64> = HashSet::new();
-            let mut n_loaded_from_db = 0;
+            let n_loaded_from_db = 0;
             let mut n_loaded_from_s3 = 0;
-            while let Some(result) = rx.recv().await {
+            while let Some(iris) = rx.recv().await {
                 time_waiting_for_stream += now_load_summary.elapsed();
                 now_load_summary = Instant::now();
-                let index = result.index();
+                let index = iris.index();
                 if index == 0 || index > store_len {
                     tracing::error!("Invalid iris index {}", index);
                     return Err(eyre!("Invalid iris index {}", index));
                 }
-                match result {
-                    StoredIris::DB(iris) => {
-                        n_loaded_from_db += 1;
-                        serial_ids_from_db.insert(iris.id());
-                        actor.load_single_record_from_db(
-                            iris.index() - 1,
-                            iris.left_code(),
-                            iris.left_mask(),
-                            iris.right_code(),
-                            iris.right_mask(),
-                        );
-                    }
-                    StoredIris::S3(iris) => {
-                        if serial_ids_from_db.contains(&iris.id()) {
-                            tracing::warn!(
-                                "Skip overriding record already loaded via DB with S3 \
-                                 record: {}",
-                                iris.index()
-                            );
-                            continue;
-                        }
-                        n_loaded_from_s3 += 1;
-                        actor.load_single_record_from_s3(
-                            iris.index() - 1,
-                            iris.left_code_odd(),
-                            iris.left_code_even(),
-                            iris.right_code_odd(),
-                            iris.right_code_even(),
-                            iris.left_mask_odd(),
-                            iris.left_mask_even(),
-                            iris.right_mask_odd(),
-                            iris.right_mask_even(),
-                        );
-                    }
-                };
+                n_loaded_from_s3 += 1;
+                actor.load_single_record_from_s3(
+                    iris.index() - 1,
+                    iris.left_code_odd(),
+                    iris.left_code_even(),
+                    iris.right_code_odd(),
+                    iris.right_code_even(),
+                    iris.left_mask_odd(),
+                    iris.left_mask_even(),
+                    iris.right_mask_odd(),
+                    iris.right_mask_even(),
+                );
 
                 if record_counter % 100_000 == 0 {
                     let elapsed = now.elapsed();
@@ -1334,10 +1315,6 @@ async fn server_main(config: Config) -> eyre::Result<()> {
                 exit(0);
             }
 
-            // Clear the memory allocated by temp HashSet
-            serial_ids_from_db.clear();
-            serial_ids_from_db.shrink_to_fit();
-
             if !all_serial_ids.is_empty() {
                 tracing::error!("Not all serial_ids were loaded: {:?}", all_serial_ids);
                 return Err(eyre!(
```
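With the `StoredIris` enum wrapper gone, the loading path is a plain bounded-channel pipeline: a spawned producer streams parsed `S3StoredIris` records into the `mpsc` channel, and the receiver loads each one into the actor, with the buffer size providing backpressure. A stripped-down sketch of that shape, with a hypothetical `Record` standing in for `S3StoredIris`:

```rust
use tokio::sync::mpsc;

#[derive(Debug)]
struct Record {
    index: usize, // 1-based serial id, as in the real loader
}

#[tokio::main]
async fn main() -> eyre::Result<()> {
    let store_len = 10usize;
    // Backpressure: the producer blocks once the buffer is full,
    // mirroring `config.load_chunks_buffer_size` in the server.
    let (tx, mut rx) = mpsc::channel::<Record>(4);

    tokio::spawn(async move {
        for index in 1..=store_len {
            if tx.send(Record { index }).await.is_err() {
                break; // receiver dropped
            }
        }
    });

    let mut n_loaded = 0usize;
    while let Some(record) = rx.recv().await {
        // The real loop rejects out-of-range indices before loading.
        if record.index == 0 || record.index > store_len {
            return Err(eyre::eyre!("Invalid iris index {}", record.index));
        }
        n_loaded += 1; // stand-in for actor.load_single_record_from_s3(...)
    }
    println!("loaded {n_loaded} records");
    Ok(())
}
```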
