Commit
use channel instead of stream
eaypek-tfh committed Jan 21, 2025
1 parent ffbd418 commit 7eb2f66
Showing 6 changed files with 90 additions and 120 deletions.
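
The core of the change: fetch_and_parse_chunks no longer returns a boxed futures::Stream for the caller to poll; it now takes a tokio::sync::mpsc `Sender<StoredIris>` and pushes each parsed element into the channel, while server.rs consumes from the matching Receiver. The sketch below is a minimal, self-contained illustration of that producer/consumer shape, not the repository's code: `Item` stands in for `StoredIris`, `produce` for `fetch_and_parse_chunks`, and it assumes the `tokio` (rt + macros features) and `eyre` crates.

```rust
use tokio::sync::mpsc;

// Stand-in for StoredIris; the real type carries iris codes and masks.
struct Item(u64);

// Producer side: push each parsed element into the channel instead of
// returning a boxed Stream for the caller to poll.
async fn produce(tx: mpsc::Sender<Item>) -> eyre::Result<()> {
    for i in 0..10 {
        // With a bounded channel this await provides backpressure: it suspends
        // while the buffer is full, until the consumer catches up.
        tx.send(Item(i)).await?;
    }
    Ok(()) // `tx` is dropped here; once all clones are gone the channel closes
}

#[tokio::main]
async fn main() -> eyre::Result<()> {
    // Same capacity the updated server and test code use for the real channel.
    let (tx, mut rx) = mpsc::channel::<Item>(1024);
    let producer = tokio::spawn(produce(tx));

    // Consumer side: recv() yields items until every Sender has been dropped.
    while let Some(Item(i)) = rx.recv().await {
        let _ = i; // the server loads each StoredIris into host memory here
    }
    producer.await??;
    Ok(())
}
```
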
2 changes: 1 addition & 1 deletion deploy/stage/common-values-iris-mpc.yaml
@@ -1,4 +1,4 @@
image: "ghcr.io/worldcoin/iris-mpc:982558eac4514ed755b7d7e227bf8dfde334f5ae"
image: "ghcr.io/worldcoin/iris-mpc:258847dd18df529f5e7043b2377af55432a553a3"

environment: stage
replicaCount: 1
2 changes: 1 addition & 1 deletion deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml
@@ -78,7 +78,7 @@ env:
value: "even_odd_binary_output_16k"

- name: SMPC__LOAD_CHUNKS_PARALLELISM
value: "256"
value: "64"

- name: SMPC__CLEAR_DB_BEFORE_INIT
value: "true"
2 changes: 1 addition & 1 deletion deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml
@@ -78,7 +78,7 @@ env:
value: "even_odd_binary_output_16k"

- name: SMPC__LOAD_CHUNKS_PARALLELISM
value: "256"
value: "64"

- name: SMPC__CLEAR_DB_BEFORE_INIT
value: "true"
2 changes: 1 addition & 1 deletion deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml
@@ -78,7 +78,7 @@ env:
value: "even_odd_binary_output_16k"

- name: SMPC__LOAD_CHUNKS_PARALLELISM
value: "256"
value: "64"

- name: SMPC__CLEAR_DB_BEFORE_INIT
value: "true"
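
All three stage values files above drop SMPC__LOAD_CHUNKS_PARALLELISM from 256 to 64. In server.rs (further down) this value reaches fetch_and_parse_chunks as its `concurrency` argument, where it now caps the number of concurrently spawned download tasks rather than a buffer_unordered width. A rough standalone sketch of reading such a knob — illustrative only, since the real service resolves it through its SMPC__-prefixed config into `config.load_chunks_parallelism`, not via std::env:

```rust
// Hypothetical helper, not part of the repository.
fn load_chunks_parallelism() -> usize {
    std::env::var("SMPC__LOAD_CHUNKS_PARALLELISM")
        .ok()
        .and_then(|v| v.parse().ok())
        .unwrap_or(64) // falls back to the value now set in the stage values files
}

fn main() {
    println!("download concurrency: {}", load_chunks_parallelism());
}
```
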
143 changes: 63 additions & 80 deletions iris-mpc-store/src/s3_importer.rs
@@ -2,18 +2,9 @@ use crate::StoredIris;
use async_trait::async_trait;
use aws_sdk_s3::{primitives::ByteStream, Client};
use eyre::eyre;
use futures::{stream, Stream, StreamExt};
use iris_mpc_common::{IRIS_CODE_LENGTH, MASK_CODE_LENGTH};
use std::{
mem,
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Instant,
};
use tokio::{io::AsyncReadExt, task};
use std::{mem, sync::Arc};
use tokio::{io::AsyncReadExt, sync::mpsc::Sender, task};

const SINGLE_ELEMENT_SIZE: usize = IRIS_CODE_LENGTH * mem::size_of::<u16>() * 2
+ MASK_CODE_LENGTH * mem::size_of::<u16>() * 2
@@ -232,85 +223,74 @@ pub async fn last_snapshot_timestamp(
}

pub async fn fetch_and_parse_chunks(
store: &impl ObjectStore,
store: Arc<impl ObjectStore>,
concurrency: usize,
prefix_name: String,
last_snapshot_details: LastSnapshotDetails,
) -> Pin<Box<dyn Stream<Item = eyre::Result<StoredIris>> + Send + '_>> {
tx: Sender<StoredIris>,
) -> eyre::Result<()> {
tracing::info!("Generating chunk files using: {:?}", last_snapshot_details);
let range_size = if last_snapshot_details.chunk_size as usize > MAX_RANGE_SIZE {
MAX_RANGE_SIZE
} else {
last_snapshot_details.chunk_size as usize
};
let total_bytes = Arc::new(AtomicUsize::new(0));
let now = Instant::now();

let result_stream =
stream::iter((1..=last_snapshot_details.last_serial_id).step_by(range_size))
.map({
let total_bytes_clone = total_bytes.clone();
move |chunk| {
let counter = total_bytes_clone.clone();
let prefix_name = prefix_name.clone();
async move {
let chunk_id = (chunk / last_snapshot_details.chunk_size)
* last_snapshot_details.chunk_size
+ 1;
let offset_within_chunk = (chunk - chunk_id) as usize;
let mut object_stream = store
.get_object(
&format!("{}/{}.bin", prefix_name, chunk_id),
(
offset_within_chunk * SINGLE_ELEMENT_SIZE,
(offset_within_chunk + range_size) * SINGLE_ELEMENT_SIZE,
),
)
.await?
.into_async_read();
let mut records = Vec::with_capacity(range_size);
let mut buf = vec![0u8; SINGLE_ELEMENT_SIZE];
loop {
match object_stream.read_exact(&mut buf).await {
Ok(_) => {
let iris = S3StoredIris::from_bytes(&buf);
records.push(iris);
counter.fetch_add(SINGLE_ELEMENT_SIZE, Ordering::Relaxed);
}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e.into()),
}
}
let stream_of_stored_iris =
stream::iter(records).map(|res_s3| res_s3.map(StoredIris::S3));
let mut handles: Vec<task::JoinHandle<Result<(), eyre::Error>>> = Vec::new();
let mut active_handles = 0;

Ok::<_, eyre::Error>(stream_of_stored_iris)
}
}
})
.buffer_unordered(concurrency)
.flat_map(|result| match result {
Ok(stream) => stream.boxed(),
Err(e) => stream::once(async move { Err(e) }).boxed(),
})
.inspect({
let counter = Arc::new(AtomicUsize::new(0));
move |_| {
if counter.fetch_add(1, Ordering::Relaxed) % 1_000_000 == 0 {
let elapsed = now.elapsed().as_secs_f32();
if elapsed > 0.0 {
let bytes = total_bytes.load(Ordering::Relaxed);
tracing::info!(
"Current download throughput: {:.2} Gbps",
bytes as f32 * 8.0 / 1e9 / elapsed
);
for chunk in (1..=last_snapshot_details.last_serial_id).step_by(range_size) {
let chunk_id =
(chunk / last_snapshot_details.chunk_size) * last_snapshot_details.chunk_size + 1;
let prefix_name = prefix_name.clone();
let offset_within_chunk = (chunk - chunk_id) as usize;

// Wait if we've hit the concurrency limit
if active_handles >= concurrency {
let handle = handles.remove(0);
handle.await??;
active_handles -= 1;
}

handles.push(task::spawn({
let store = Arc::clone(&store);
let mut slice = vec![0u8; SINGLE_ELEMENT_SIZE];
let tx = tx.clone();
async move {
let mut result = store
.get_object(
&format!("{}/{}.bin", prefix_name, chunk_id),
(
offset_within_chunk * SINGLE_ELEMENT_SIZE,
(offset_within_chunk + range_size) * SINGLE_ELEMENT_SIZE,
),
)
.await?
.into_async_read();

loop {
match result.read_exact(&mut slice).await {
Ok(_) => {
let iris = S3StoredIris::from_bytes(&slice)?;
tx.send(StoredIris::S3(iris)).await?;
}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(e.into()),
}
}
})
.boxed();

result_stream
Ok(())
}
}));
}

drop(tx);

// Wait for remaining handles
for handle in handles {
handle.await??;
}

Ok(())
}

pub async fn fetch_to_memory(
@@ -385,6 +365,7 @@ mod tests {
use aws_sdk_s3::primitives::SdkBody;
use rand::Rng;
use std::{cmp::min, collections::HashSet};
use tokio::sync::mpsc;

#[derive(Default, Clone)]
pub struct MockStore {
@@ -488,12 +469,14 @@
last_serial_id: MOCK_ENTRIES as i64,
chunk_size: MOCK_CHUNK_SIZE as i64,
};
let mut chunks =
fetch_and_parse_chunks(&store, 1, "out".to_string(), last_snapshot_details).await;
let (tx, mut rx) = mpsc::channel::<StoredIris>(1024);
let store_arc = Arc::new(store);
let _res =
fetch_and_parse_chunks(store_arc, 1, "out".to_string(), last_snapshot_details, tx)
.await;
let mut count = 0;
let mut ids: HashSet<usize> = HashSet::from_iter(1..MOCK_ENTRIES);
while let Some(chunk) = chunks.next().await {
let chunk = chunk.unwrap();
while let Some(chunk) = rx.recv().await {
ids.remove(&(chunk.index()));
count += 1;
}
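
One behavioural consequence of the bounded channel (capacity 1024 in the updated server and test code) is backpressure: when the consumer falls behind, `tx.send(...).await` parks the download tasks, so at most capacity-many parsed irises wait in the channel at any time. Below is a small generic demo of that tokio behaviour — not code from this commit — assuming `tokio` with the rt, macros, time, and sync features:

```rust
use std::time::Duration;
use tokio::{sync::mpsc, time::sleep};

#[tokio::main]
async fn main() {
    // Tiny capacity so the effect is visible: the producer can only run
    // two items ahead of the consumer before send().await suspends.
    let (tx, mut rx) = mpsc::channel::<u32>(2);

    let producer = tokio::spawn(async move {
        for i in 0..6 {
            tx.send(i).await.expect("receiver dropped");
            println!("sent {i}");
        }
        // `tx` dropped here: recv() returns None once the buffer drains.
    });

    while let Some(i) = rx.recv().await {
        sleep(Duration::from_millis(50)).await; // simulate a slow consumer
        println!("received {i}");
    }
    producer.await.unwrap();
}
```
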
59 changes: 23 additions & 36 deletions iris-mpc/src/bin/server.rs
@@ -12,7 +12,6 @@ use aws_smithy_runtime_api::client::dns::{DnsFuture, ResolveDns};
use axum::{response::IntoResponse, routing::get, Router};
use clap::Parser;
use eyre::{eyre, Context};
use futures::{stream::select_all, StreamExt, TryStreamExt};
use hickory_resolver::{
config::{ResolverConfig, ResolverOpts},
TokioAsyncResolver,
@@ -63,6 +62,7 @@ use std::{
mem,
net::IpAddr,
panic,
process::exit,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, LazyLock, Mutex,
@@ -1080,6 +1080,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
let load_chunks_parallelism = config.load_chunks_parallelism;
let db_chunks_bucket_name = config.db_chunks_bucket_name.clone();
let db_chunks_folder_name = config.db_chunks_folder_name.clone();
let env = config.environment.clone();

let (tx, rx) = oneshot::channel();
background_tasks.spawn_blocking(move || {
@@ -1185,40 +1186,21 @@ async fn server_main(config: Config) -> eyre::Result<()> {
);

let s3_store = S3Store::new(db_chunks_s3_client, db_chunks_bucket_name);

let mut stream = match config.enable_s3_importer {
true => {
tracing::info!("S3 importer enabled. Fetching from s3 + db");
let min_last_modified_at = last_snapshot_details.timestamp
- config.db_load_safety_overlap_seconds;
tracing::info!(
"Last snapshot timestamp: {}, min_last_modified_at: {}",
last_snapshot_details.timestamp,
min_last_modified_at
);
let stream_s3 = fetch_and_parse_chunks(
&s3_store,
load_chunks_parallelism,
db_chunks_folder_name,
last_snapshot_details,
)
.await
.boxed();

let stream_db = store
.stream_irises_par(Some(min_last_modified_at), parallelism)
.await
.boxed();

select_all(vec![stream_s3, stream_db])
}
false => {
tracing::info!("S3 importer disabled. Fetching only from db");
let stream_db =
store.stream_irises_par(None, parallelism).await.boxed();
select_all(vec![stream_db])
}
};
let s3_arc = Arc::new(s3_store);

let (tx, mut rx) = mpsc::channel::<StoredIris>(1024);

tokio::spawn(async move {
fetch_and_parse_chunks(
s3_arc,
load_chunks_parallelism,
db_chunks_folder_name,
last_snapshot_details,
tx.clone(),
)
.await
.expect("Couldn't fetch and parse chunks");
});

tracing::info!("Page-lock host memory");
let left_codes = actor.left_code_db_slices.code_gr.clone();
@@ -1256,7 +1238,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
let mut serial_ids_from_db: HashSet<i64> = HashSet::new();
let mut n_loaded_from_db = 0;
let mut n_loaded_from_s3 = 0;
while let Some(result) = stream.try_next().await? {
while let Some(result) = rx.recv().await {
time_waiting_for_stream += now_load_summary.elapsed();
now_load_summary = Instant::now();
let index = result.index();
@@ -1332,6 +1314,11 @@ async fn server_main(config: Config) -> eyre::Result<()> {
time_loading_into_memory,
);

if env.eq("stage") {
tracing::info!("Test environment detected, exiting");
exit(0);
}

// Clear the memory allocated by temp HashSet
serial_ids_from_db.clear();
serial_ids_from_db.shrink_to_fit();
