Skip to content

Commit

Permalink
fix sync_batch bug and make e2e enforce receiving results
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl committed Jan 14, 2025
1 parent 7c1b836 commit 1d31218
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
2 changes: 1 addition & 1 deletion iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1499,7 +1499,7 @@ impl ServerActor {
.collect();

// Only keep entries that are valid on all nodes
let mut valid_merged = vec![false; max_batch_size];
let mut valid_merged = vec![true; max_batch_size];
for i in 0..self.comms[0].world_size() {
for j in 0..max_batch_size {
valid_merged[j] &= results[i][j] == 1;
Expand Down
22 changes: 15 additions & 7 deletions iris-mpc-gpu/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,13 @@ mod e2e_test {
let mut rng = StdRng::seed_from_u64(INTERNAL_RNG_SEED);

let mut expected_results: HashMap<String, (Option<u32>, bool)> = HashMap::new();
let mut requests: HashMap<String, IrisCode> = HashMap::new();
let mut responses: HashMap<u32, IrisCode> = HashMap::new();
let mut deleted_indices_buffer = vec![];
let mut deleted_indices: HashSet<u32> = HashSet::new();
let mut disallowed_queries = Vec::new();

for _ in 0..NUM_BATCHES {
let mut requests: HashMap<String, IrisCode> = HashMap::new();
let mut batch0 = BatchQuery::default();
let mut batch1 = BatchQuery::default();
let mut batch2 = BatchQuery::default();
Expand Down Expand Up @@ -470,8 +470,13 @@ mod e2e_test {
let res1 = res1_fut.await;
let res2 = res2_fut.await;

// go over results and check if correct
for res in [res0, res1, res2].iter() {
let mut resp_counters = HashMap::new();
for req in requests.keys() {
resp_counters.insert(req, 0);
}

let results = [&res0, &res1, &res2];
for res in results.iter() {
let ServerJobResult {
request_ids: thread_request_ids,
matches,
Expand All @@ -496,13 +501,11 @@ mod e2e_test {
{
assert!(requests.contains_key(req_id));

resp_counters.insert(req_id, resp_counters.get(req_id).unwrap() + 1);

assert_eq!(partial_left, partial_right);
assert_eq!(partial_left, match_id);

// This was an invalid query, we should not get a response, but they should be
// silently ignored
assert!(requests.contains_key(req_id));

let (expected_idx, is_batch_match) = expected_results.get(req_id).unwrap();

if let Some(expected_idx) = expected_idx {
Expand All @@ -521,6 +524,11 @@ mod e2e_test {
}
}
}

// Check that we received a response from all actors
for (&id, &count) in resp_counters.iter() {
assert_eq!(count, 3, "Received {} responses for {}", count, id);
}
}

drop(handle0);
Expand Down

0 comments on commit 1d31218

Please sign in to comment.