From 1d31218d781fa13bd79e7254e35895e48cdcff99 Mon Sep 17 00:00:00 2001 From: philsippl Date: Tue, 14 Jan 2025 16:43:52 +0000 Subject: [PATCH] fix sync_batch bug and make e2e enforce receiving results --- iris-mpc-gpu/src/server/actor.rs | 2 +- iris-mpc-gpu/tests/e2e.rs | 22 +++++++++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 7cd5cfdd9..6ee983c5c 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -1499,7 +1499,7 @@ impl ServerActor { .collect(); // Only keep entries that are valid on all nodes - let mut valid_merged = vec![false; max_batch_size]; + let mut valid_merged = vec![true; max_batch_size]; for i in 0..self.comms[0].world_size() { for j in 0..max_batch_size { valid_merged[j] &= results[i][j] == 1; diff --git a/iris-mpc-gpu/tests/e2e.rs b/iris-mpc-gpu/tests/e2e.rs index df377fa08..cc2e3e79a 100644 --- a/iris-mpc-gpu/tests/e2e.rs +++ b/iris-mpc-gpu/tests/e2e.rs @@ -218,13 +218,13 @@ mod e2e_test { let mut rng = StdRng::seed_from_u64(INTERNAL_RNG_SEED); let mut expected_results: HashMap, bool)> = HashMap::new(); - let mut requests: HashMap = HashMap::new(); let mut responses: HashMap = HashMap::new(); let mut deleted_indices_buffer = vec![]; let mut deleted_indices: HashSet = HashSet::new(); let mut disallowed_queries = Vec::new(); for _ in 0..NUM_BATCHES { + let mut requests: HashMap = HashMap::new(); let mut batch0 = BatchQuery::default(); let mut batch1 = BatchQuery::default(); let mut batch2 = BatchQuery::default(); @@ -470,8 +470,13 @@ mod e2e_test { let res1 = res1_fut.await; let res2 = res2_fut.await; - // go over results and check if correct - for res in [res0, res1, res2].iter() { + let mut resp_counters = HashMap::new(); + for req in requests.keys() { + resp_counters.insert(req, 0); + } + + let results = [&res0, &res1, &res2]; + for res in results.iter() { let ServerJobResult { request_ids: thread_request_ids, matches, @@ -496,13 +501,11 @@ mod e2e_test { { assert!(requests.contains_key(req_id)); + resp_counters.insert(req_id, resp_counters.get(req_id).unwrap() + 1); + assert_eq!(partial_left, partial_right); assert_eq!(partial_left, match_id); - // This was an invalid query, we should not get a response, but they should be - // silently ignored - assert!(requests.contains_key(req_id)); - let (expected_idx, is_batch_match) = expected_results.get(req_id).unwrap(); if let Some(expected_idx) = expected_idx { @@ -521,6 +524,11 @@ mod e2e_test { } } } + + // Check that we received a response from all actors + for (&id, &count) in resp_counters.iter() { + assert_eq!(count, 3, "Received {} responses for {}", count, id); + } } drop(handle0);