Skip to content

Commit

Permalink
Add graceful shutdown during DB loading
Browse files Browse the repository at this point in the history
  • Loading branch information
danielle-tfh committed Jan 21, 2025
1 parent f4dbe47 commit b952887
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/temp-branch-build-and-push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Branch - Build and push docker image
on:
push:
branches:
- "add-error-handling-skipped-requests"
- "add-graceful-shutdown-during-s3-sync"

concurrency:
group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}'
Expand Down
66 changes: 62 additions & 4 deletions iris-mpc-store/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ use futures::{
use iris_mpc_common::{
config::Config,
galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare},
helpers::shutdown_handler::ShutdownHandler,
iris_db::iris::IrisCode,
};
use rand::{rngs::StdRng, Rng, SeedableRng};
pub use s3_importer::{fetch_and_parse_chunks, last_snapshot_timestamp, ObjectStore, S3Store};
use sqlx::{
migrate::Migrator, postgres::PgPoolOptions, Executor, PgPool, Postgres, Row, Transaction,
};
use std::ops::DerefMut;
use std::{ops::DerefMut, sync::Arc};

const APP_NAME: &str = "SMPC";
const MAX_CONNECTIONS: u32 = 100;
Expand Down Expand Up @@ -186,12 +187,17 @@ impl Store {
&self,
min_last_modified_at: Option<i64>,
partitions: usize,
shutdown_handler: Arc<ShutdownHandler>,
) -> impl Stream<Item = eyre::Result<StoredIris>> + '_ {
let count = self.count_irises().await.expect("Failed count_irises");
let partition_size = count.div_ceil(partitions).max(1);

let mut partition_streams = Vec::new();
for i in 0..partitions {
if shutdown_handler.is_shutting_down() {
tracing::info!("Shutdown triggered before processing chunk {}", i);
break;
}
// we start from ID 1
let start_id = 1 + partition_size * i;
let end_id = start_id + partition_size - 1;
Expand Down Expand Up @@ -539,14 +545,17 @@ mod tests {
#[tokio::test]
async fn test_store() -> Result<()> {
// Create a unique schema for this test.
let shutdown_handler = Arc::new(ShutdownHandler::new(60));
shutdown_handler.wait_for_shutdown_signal().await;
let shutdown_handler_2 = Arc::clone(&shutdown_handler);
let schema_name = temporary_name();
let store = Store::new(&test_db_url()?, &schema_name).await?;

let got: Vec<DbStoredIris> = store.stream_irises().await.try_collect().await?;
assert_eq!(got.len(), 0);

let got: Vec<DbStoredIris> = store
.stream_irises_par(Some(0), 2)
.stream_irises_par(Some(0), 2, shutdown_handler)
.await
.map_ok(|stored_iris| match stored_iris {
StoredIris::DB(db_iris) => db_iris,
Expand Down Expand Up @@ -589,7 +598,7 @@ mod tests {
let got: Vec<DbStoredIris> = store.stream_irises().await.try_collect().await?;

let mut got_par: Vec<DbStoredIris> = store
.stream_irises_par(Some(0), 2)
.stream_irises_par(Some(0), 2, shutdown_handler_2)
.await
.map_ok(|stored_iris| match stored_iris {
StoredIris::DB(db_iris) => db_iris,
Expand All @@ -615,6 +624,53 @@ mod tests {
Ok(())
}

/// Checks that `stream_irises_par` returns fewer rows than are stored when a
/// manual shutdown is triggered before the stream future is awaited.
#[tokio::test]
async fn test_store_shutdown_handler() -> Result<()> {
    let count = 50;

    // Two handles to the same handler: one is moved into the store call, the
    // other stays here so the test can trigger the shutdown itself.
    let shutdown_handler = Arc::new(ShutdownHandler::new(60));
    shutdown_handler.wait_for_shutdown_signal().await;
    let shutdown_handler_2 = Arc::clone(&shutdown_handler);

    // Create a unique schema for this test.
    let schema_name = temporary_name();
    let store = Store::new(&test_db_url()?, &schema_name).await?;

    // `count` rows with ids 1..=count, all sharing the same constant shares.
    let codes_and_masks: Vec<_> = (0..count)
        .map(|i| StoredIrisRef {
            id: (i + 1) as i64,
            left_code: &[123_u16; 12800],
            left_mask: &[456_u16; 12800],
            right_code: &[789_u16; 12800],
            right_mask: &[101_u16; 12800],
        })
        .collect();

    let mut tx = store.tx().await?;
    store.insert_irises(&mut tx, &codes_and_masks).await?;
    tx.commit().await?;

    let got_len = store.count_irises().await?;

    // Build the future without awaiting it, so the shutdown below is already
    // in effect by the time the call starts running.
    let pending_stream = store.stream_irises_par(Some(0), 25, shutdown_handler);
    shutdown_handler_2.trigger_manual_shutdown();

    let rows = pending_stream.await.map_ok(|stored_iris| match stored_iris {
        StoredIris::DB(db_iris) => db_iris,
        StoredIris::S3(_) => panic!("Unexpected S3 variant in this test!"),
    });
    let got_par: Vec<DbStoredIris> = rows.try_collect().await?;

    // Everything was persisted, but the parallel stream stopped early.
    assert_eq!(got_len, count);
    assert!(got_par.len() < count);

    // Clean up on success.
    cleanup(&store, &schema_name).await?;
    Ok(())
}

#[tokio::test]
async fn test_empty_insert() -> Result<()> {
let schema_name = temporary_name();
Expand All @@ -636,6 +692,8 @@ mod tests {

let schema_name = temporary_name();
let store = Store::new(&test_db_url()?, &schema_name).await?;
let shutdown_handler = Arc::new(ShutdownHandler::new(60));
shutdown_handler.wait_for_shutdown_signal().await;

let mut codes_and_masks = vec![];

Expand Down Expand Up @@ -674,7 +732,7 @@ mod tests {
// Compare with the parallel version with several edge-cases.
for parallelism in [1, 5, MAX_CONNECTIONS as usize + 1] {
let mut got_par: Vec<DbStoredIris> = store
.stream_irises_par(Some(0), parallelism)
.stream_irises_par(Some(0), parallelism, shutdown_handler.clone())
.await
.map_ok(|stored_iris| match stored_iris {
StoredIris::DB(db_iris) => db_iris,
Expand Down
81 changes: 78 additions & 3 deletions iris-mpc-store/src/s3_importer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use async_trait::async_trait;
use aws_sdk_s3::{primitives::ByteStream, Client};
use eyre::eyre;
use futures::{stream, Stream, StreamExt};
use iris_mpc_common::{IRIS_CODE_LENGTH, MASK_CODE_LENGTH};
use iris_mpc_common::{
helpers::shutdown_handler::ShutdownHandler, IRIS_CODE_LENGTH, MASK_CODE_LENGTH,
};
use std::{
mem,
pin::Pin,
Expand Down Expand Up @@ -236,6 +238,7 @@ pub async fn fetch_and_parse_chunks(
concurrency: usize,
prefix_name: String,
last_snapshot_details: LastSnapshotDetails,
shutdown_handler: Arc<ShutdownHandler>,
) -> Pin<Box<dyn Stream<Item = eyre::Result<StoredIris>> + Send + '_>> {
tracing::info!("Generating chunk files using: {:?}", last_snapshot_details);
let range_size = if last_snapshot_details.chunk_size as usize > MAX_RANGE_SIZE {
Expand All @@ -253,7 +256,12 @@ pub async fn fetch_and_parse_chunks(
move |chunk| {
let counter = total_bytes_clone.clone();
let prefix_name = prefix_name.clone();
let shutdown_handler_clone = shutdown_handler.clone();
async move {
if shutdown_handler_clone.is_shutting_down() {
tracing::info!("Shutdown triggered before processing chunk {}", chunk);
return Err(eyre::eyre!("Shutdown triggered"));
}
let chunk_id = (chunk / last_snapshot_details.chunk_size)
* last_snapshot_details.chunk_size
+ 1;
Expand Down Expand Up @@ -407,6 +415,8 @@ mod tests {
const MOCK_ENTRIES: usize = 107;
const MOCK_CHUNK_SIZE: usize = 10;
let mut store = MockStore::new();
let shutdown_handler = Arc::new(ShutdownHandler::new(60));
shutdown_handler.wait_for_shutdown_signal().await;
let n_chunks = MOCK_ENTRIES.div_ceil(MOCK_CHUNK_SIZE);
for i in 0..n_chunks {
let start_serial_id = i * MOCK_CHUNK_SIZE + 1;
Expand All @@ -423,8 +433,14 @@ mod tests {
last_serial_id: MOCK_ENTRIES as i64,
chunk_size: MOCK_CHUNK_SIZE as i64,
};
let mut chunks =
fetch_and_parse_chunks(&store, 1, "out".to_string(), last_snapshot_details).await;
let mut chunks = fetch_and_parse_chunks(
&store,
1,
"out".to_string(),
last_snapshot_details,
shutdown_handler,
)
.await;
let mut count = 0;
let mut ids: HashSet<usize> = HashSet::from_iter(1..MOCK_ENTRIES);
while let Some(chunk) = chunks.next().await {
Expand All @@ -435,4 +451,63 @@ mod tests {
assert_eq!(count, MOCK_ENTRIES);
assert!(ids.is_empty());
}

/// Checks that the stream from `fetch_and_parse_chunks` yields an error, and
/// fewer than `MOCK_ENTRIES` successful items, when a manual shutdown is
/// triggered before the fetch future is awaited.
#[tokio::test]
async fn test_fetch_and_parse_chunks_shutdown_handler() {
    const MOCK_ENTRIES: usize = 107;
    const MOCK_CHUNK_SIZE: usize = 10;

    let mut store = MockStore::new();

    // Two handles to the same handler: one is moved into the fetch call, the
    // other stays here so the test can trigger the shutdown itself.
    let shutdown_handler = Arc::new(ShutdownHandler::new(60));
    shutdown_handler.wait_for_shutdown_signal().await;
    let shutdown_handler_2 = Arc::clone(&shutdown_handler);

    // Seed one object per chunk, named after the first serial id it holds
    // (1, 11, 21, ...); the final chunk is clipped to MOCK_ENTRIES.
    let n_chunks = MOCK_ENTRIES.div_ceil(MOCK_CHUNK_SIZE);
    for first_id in (1..=MOCK_ENTRIES).step_by(MOCK_CHUNK_SIZE) {
        let last_id = min(first_id + MOCK_CHUNK_SIZE - 1, MOCK_ENTRIES);
        let entries = (first_id..=last_id).map(dummy_entry).collect();
        store.add_test_data(&format!("out/{first_id}.bin"), entries);
    }
    assert_eq!(store.list_objects("").await.unwrap().len(), n_chunks);

    let last_snapshot_details = LastSnapshotDetails {
        timestamp: 0,
        last_serial_id: MOCK_ENTRIES as i64,
        chunk_size: MOCK_CHUNK_SIZE as i64,
    };

    // Build the future without awaiting it, so the shutdown below is already
    // in effect by the time the call starts running.
    let pending_chunks = fetch_and_parse_chunks(
        &store,
        1,
        "out".to_string(),
        last_snapshot_details,
        shutdown_handler,
    );
    shutdown_handler_2.trigger_manual_shutdown();

    let mut chunks = pending_chunks.await;
    let mut ok_items = 0;
    let mut saw_error = false;

    // Drain until the first error; successful items are only counted.
    while let Some(item) = chunks.next().await {
        if let Err(e) = item {
            // We got an error - test passes
            saw_error = true;
            println!("Received error as expected: {e:?}");
            break;
        }
        ok_items += 1;
    }

    assert!(
        saw_error,
        "Expected an error from the stream, but it ended or returned only Ok items."
    );
    assert!(ok_items < MOCK_ENTRIES);
}
}
27 changes: 23 additions & 4 deletions iris-mpc/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
let load_chunks_parallelism = config.load_chunks_parallelism;
let db_chunks_bucket_name = config.db_chunks_bucket_name.clone();
let db_chunks_folder_name = config.db_chunks_folder_name.clone();
let download_shutdown_handler = Arc::clone(&shutdown_handler);

let (tx, rx) = oneshot::channel();
background_tasks.spawn_blocking(move || {
Expand Down Expand Up @@ -1068,6 +1069,8 @@ async fn server_main(config: Config) -> eyre::Result<()> {
"Initialize iris db: Loading from DB (parallelism: {})",
parallelism
);
let s3_shutdown_handler = Arc::clone(&download_shutdown_handler);
let post_download_shutdown_handler = Arc::clone(&download_shutdown_handler);
let s3_store = S3Store::new(s3_client_clone, db_chunks_bucket_name);
tokio::runtime::Handle::current().block_on(async {
let mut stream = match config.enable_s3_importer {
Expand All @@ -1091,25 +1094,37 @@ async fn server_main(config: Config) -> eyre::Result<()> {
load_chunks_parallelism,
db_chunks_folder_name,
last_snapshot_details,
s3_shutdown_handler,
)
.await
.boxed();

let stream_db = store
.stream_irises_par(Some(min_last_modified_at), parallelism)
.stream_irises_par(
Some(min_last_modified_at),
parallelism,
download_shutdown_handler,
)
.await
.boxed();

select_all(vec![stream_s3, stream_db])
}
false => {
tracing::info!("S3 importer disabled. Fetching only from db");
let stream_db =
store.stream_irises_par(None, parallelism).await.boxed();
let stream_db = store
.stream_irises_par(None, parallelism, download_shutdown_handler)
.await
.boxed();
select_all(vec![stream_db])
}
};

if post_download_shutdown_handler.is_shutting_down() {
tracing::warn!("Shutdown requested by shutdown_handler.");
return Err(eyre::eyre!("Shutdown requested"));
}

tracing::info!("Page-lock host memory");
let left_codes = actor.left_code_db_slices.code_gr.clone();
let right_codes = actor.right_code_db_slices.code_gr.clone();
Expand Down Expand Up @@ -1148,6 +1163,10 @@ async fn server_main(config: Config) -> eyre::Result<()> {
let mut n_loaded_from_db = 0;
let mut n_loaded_from_s3 = 0;
while let Some(result) = stream.try_next().await? {
if post_download_shutdown_handler.is_shutting_down() {
tracing::warn!("Shutdown requested by shutdown_handler.");
return Err(eyre::eyre!("Shutdown requested"));
}
time_waiting_for_stream += now_load_summary.elapsed();
now_load_summary = Instant::now();
let index = result.index();
Expand Down Expand Up @@ -1282,7 +1301,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
let sns_client_bg = sns_client.clone();
let config_bg = config.clone();
let store_bg = store.clone();
let shutdown_handler_bg = shutdown_handler.clone();
let shutdown_handler_bg = Arc::clone(&shutdown_handler);
let _result_sender_abort = background_tasks.spawn(async move {
while let Some(ServerJobResult {
merged_results,
Expand Down

0 comments on commit b952887

Please sign in to comment.