diff --git a/iris-mpc-common/src/helpers/sha256.rs b/iris-mpc-common/src/helpers/sha256.rs
index 1f36637fd..78cc93055 100644
--- a/iris-mpc-common/src/helpers/sha256.rs
+++ b/iris-mpc-common/src/helpers/sha256.rs
@@ -1,5 +1,9 @@
 use sha2::{Digest, Sha256};
 
-pub fn calculate_sha256<T: AsRef<[u8]>>(data: T) -> String {
-    hex::encode(Sha256::digest(data.as_ref()))
+pub fn sha256_as_hex_string<T: AsRef<[u8]>>(data: T) -> String {
+    hex::encode(sha256_bytes(data))
+}
+
+pub fn sha256_bytes<T: AsRef<[u8]>>(data: T) -> [u8; 32] {
+    Sha256::digest(data.as_ref()).into()
 }
diff --git a/iris-mpc-common/src/helpers/smpc_request.rs b/iris-mpc-common/src/helpers/smpc_request.rs
index 04863df66..9391688ab 100644
--- a/iris-mpc-common/src/helpers/smpc_request.rs
+++ b/iris-mpc-common/src/helpers/smpc_request.rs
@@ -1,4 +1,4 @@
-use super::{key_pair::SharesDecodingError, sha256::calculate_sha256};
+use super::{key_pair::SharesDecodingError, sha256::sha256_as_hex_string};
 use crate::helpers::key_pair::SharesEncryptionKeyPairs;
 use aws_sdk_s3::Client as S3Client;
 use aws_sdk_sns::types::MessageAttributeValue;
@@ -285,6 +285,6 @@ impl UniquenessRequest {
             .map_err(SharesDecodingError::SerdeError)?
             .into_bytes();
 
-        Ok(self.iris_shares_file_hashes[party_id] == calculate_sha256(stringified_share))
+        Ok(self.iris_shares_file_hashes[party_id] == sha256_as_hex_string(stringified_share))
     }
 }
diff --git a/iris-mpc-common/tests/sha256.rs b/iris-mpc-common/tests/sha256.rs
index eb922e0e7..759656e15 100644
--- a/iris-mpc-common/tests/sha256.rs
+++ b/iris-mpc-common/tests/sha256.rs
@@ -1,5 +1,5 @@
 mod tests {
-    use iris_mpc_common::helpers::sha256::calculate_sha256;
+    use iris_mpc_common::helpers::sha256::sha256_as_hex_string;
 
     #[test]
     fn test_calculate_sha256() {
@@ -8,7 +8,7 @@
         let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
 
         // Act
-        let calculated_hash = calculate_sha256(data);
+        let calculated_hash = sha256_as_hex_string(data);
 
         // Assert
         assert_eq!(
@@ -24,7 +24,7 @@
         let expected_hash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
 
         // Act
-        let calculated_hash = calculate_sha256(data);
+        let calculated_hash = sha256_as_hex_string(data);
 
         // Assert
         assert_eq!(
@@ -38,8 +38,8 @@
         // Arrange
         let data_1 = "Data 1";
         let data_2 = "Data 2";
-        let hash_1 = calculate_sha256(data_1);
-        let hash_2 = calculate_sha256(data_2);
+        let hash_1 = sha256_as_hex_string(data_1);
+        let hash_2 = sha256_as_hex_string(data_2);
 
         // Act & Assert
         assert_ne!(
diff --git a/iris-mpc-common/tests/smpc_request.rs b/iris-mpc-common/tests/smpc_request.rs
index 1c2e7d5fd..1df6149d1 100644
--- a/iris-mpc-common/tests/smpc_request.rs
+++ b/iris-mpc-common/tests/smpc_request.rs
@@ -4,7 +4,7 @@ mod tests {
     use base64::{engine::general_purpose::STANDARD, Engine};
     use iris_mpc_common::helpers::{
         key_pair::{SharesDecodingError, SharesEncryptionKeyPairs},
-        sha256::calculate_sha256,
+        sha256::sha256_as_hex_string,
         smpc_request::{IrisCodesJSON, UniquenessRequest},
     };
     use serde_json::json;
@@ -271,7 +271,7 @@
     async fn test_validate_iris_share() {
         let mock_iris_codes_json = mock_iris_codes_json();
         let mock_serialized_iris = serde_json::to_string(&mock_iris_codes_json).unwrap();
-        let mock_hash = calculate_sha256(mock_serialized_iris.into_bytes());
+        let mock_hash = sha256_as_hex_string(mock_serialized_iris.into_bytes());
 
         let smpc_request = get_mock_smpc_request_with_hashes([
             mock_hash.clone(),
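The rename above splits the old `calculate_sha256` into a byte-returning core plus a hex wrapper. For reference, a minimal standalone sketch of how the two helpers compose, using the same `sha2`/`hex` crates and the empty-input vector asserted in the tests above (the `main` wrapper is illustrative only, not part of the change):

```rust
use sha2::{Digest, Sha256};

// Same shape as the helpers above: the hex form is just the byte form, encoded.
fn sha256_bytes<T: AsRef<[u8]>>(data: T) -> [u8; 32] {
    Sha256::digest(data.as_ref()).into()
}

fn sha256_as_hex_string<T: AsRef<[u8]>>(data: T) -> String {
    hex::encode(sha256_bytes(data))
}

fn main() {
    // Well-known SHA-256 digest of empty input, as asserted in the tests above.
    assert_eq!(
        sha256_as_hex_string(""),
        "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
    );
    // Callers that need raw bytes can use sha256_bytes directly,
    // and the two helpers stay consistent by construction.
    assert_eq!(hex::encode(sha256_bytes("abc")), sha256_as_hex_string("abc"));
}
```

The split lets the GPU actor below gather a fixed-size 32-byte digest over the network while string-facing callers keep the hex form.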
diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs
index 6d5a7f943..6ee983c5c 100644
--- a/iris-mpc-gpu/src/server/actor.rs
+++ b/iris-mpc-gpu/src/server/actor.rs
@@ -27,6 +27,7 @@ use eyre::eyre;
 use futures::{Future, FutureExt};
 use iris_mpc_common::{
     galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare},
+    helpers::sha256::sha256_bytes,
     iris_db::iris::IrisCode,
     IrisCodeDbSlice,
 };
@@ -652,7 +653,13 @@
         ///////////////////////////////////////////////////////////////////
         let tmp_now = Instant::now();
         tracing::info!("Syncing batch entries");
-        let valid_entries = self.sync_batch_entries(&batch.valid_entries)?;
+
+        // Compute the hash of the concatenated request ids
+        let batch_hash = sha256_bytes(batch.request_ids.join(""));
+        tracing::info!("Current batch hash: {}", hex::encode(&batch_hash[0..4]));
+
+        let valid_entries =
+            self.sync_batch_entries(&batch.valid_entries, self.max_batch_size, &batch_hash)?;
         let valid_entry_idxs = valid_entries.iter().positions(|&x| x).collect::<Vec<_>>();
         batch_size = valid_entry_idxs.len();
         batch.retain(&valid_entry_idxs);
@@ -1458,53 +1465,60 @@
         }
     }
 
-    fn sync_batch_entries(&mut self, valid_entries: &[bool]) -> eyre::Result<Vec<bool>> {
-        tracing::info!(
-            party_id = self.party_id,
-            "valid_entries {:?} ({})",
-            valid_entries,
-            valid_entries.len()
-        );
-        tracing::info!(party_id = self.party_id, "sync_batch_entries start");
+    fn sync_batch_entries(
+        &mut self,
+        valid_entries: &[bool],
+        max_batch_size: usize,
+        batch_hash: &[u8],
+    ) -> eyre::Result<Vec<bool>> {
+        assert!(valid_entries.len() <= max_batch_size);
+        let hash_len = batch_hash.len();
         let mut buffer = self
             .device_manager
             .device(0)
-            .alloc_zeros(valid_entries.len() * self.comms[0].world_size())
+            .alloc_zeros((max_batch_size + hash_len) * self.comms[0].world_size())
             .unwrap();
-        tracing::info!(party_id = self.party_id, "htod_copy start");
 
-        let buffer_self = self
-            .device_manager
-            .device(0)
-            .htod_copy(valid_entries.iter().map(|&x| x as u8).collect::<Vec<u8>>())?;
+        let mut host_buffer = vec![0u8; max_batch_size + hash_len];
+        host_buffer[..valid_entries.len()]
+            .copy_from_slice(&valid_entries.iter().map(|&x| x as u8).collect::<Vec<u8>>());
+        host_buffer[max_batch_size..].copy_from_slice(batch_hash);
 
+        let buffer_self = self.device_manager.device(0).htod_copy(host_buffer)?;
+
+        // Use all_gather to sync the buffer across all nodes (only using device 0)
         self.device_manager.device(0).synchronize()?;
-
-        tracing::info!(party_id = self.party_id, "all_gather start");
-
         self.comms[0]
             .all_gather(&buffer_self, &mut buffer)
             .map_err(|e| eyre!(format!("{:?}", e)))?;
-
         self.device_manager.device(0).synchronize()?;
 
-        tracing::info!(party_id = self.party_id, "dtoh_sync_copy start");
-
         let results = self.device_manager.device(0).dtoh_sync_copy(&buffer)?;
         let results: Vec<_> = results
             .chunks_exact(results.len() / self.comms[0].world_size())
             .collect();
 
-        tracing::info!(party_id = self.party_id, "sync_batch_entries end");
+        // Only keep entries that are valid on all nodes
+        let mut valid_merged = vec![true; max_batch_size];
+        for i in 0..self.comms[0].world_size() {
+            for j in 0..max_batch_size {
+                valid_merged[j] &= results[i][j] == 1;
+            }
+        }
 
-        let mut valid_merged = vec![];
-        for i in 0..results[0].len() {
-            valid_merged.push(
-                [results[0][i], results[1][i], results[2][i]]
-                    .iter()
-                    .all(|&x| x == 1),
-            );
+        // Check that the hash is the same on all nodes
+        for i in 0..self.comms[0].world_size() {
+            if &results[i][max_batch_size..] != batch_hash {
+                tracing::error!(
+                    party_id = self.party_id,
+                    "Batch mismatch with node {}. Queues seem to be out of sync.",
+                    i
+                );
+                return Err(eyre!(
+                    "Batch mismatch with node {}. Queues seem to be out of sync.",
+                    i
+                ));
+            }
         }
 
         Ok(valid_merged)
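For intuition, here is a CPU-only sketch of the gather-payload layout and merge rule that `sync_batch_entries` now uses. The NCCL `all_gather` on device 0 is modeled as plain byte concatenation, the demo runs with a world size of 2 rather than the 3 MPC parties in production, and `HASH_LEN`, `build_payload`, and `merge` are illustrative names, not functions from the codebase:

```rust
const HASH_LEN: usize = 32;

/// Each party contributes: [valid flags, zero-padded to max_batch_size | 32-byte batch hash].
fn build_payload(valid_entries: &[bool], max_batch_size: usize, batch_hash: &[u8; HASH_LEN]) -> Vec<u8> {
    assert!(valid_entries.len() <= max_batch_size);
    let mut buf = vec![0u8; max_batch_size + HASH_LEN];
    for (i, &v) in valid_entries.iter().enumerate() {
        buf[i] = v as u8;
    }
    buf[max_batch_size..].copy_from_slice(batch_hash);
    buf
}

/// `gathered` holds `world_size` payloads back to back, as all_gather returns them.
fn merge(
    gathered: &[u8],
    world_size: usize,
    max_batch_size: usize,
    my_hash: &[u8; HASH_LEN],
) -> Result<Vec<bool>, String> {
    let chunks: Vec<&[u8]> = gathered.chunks_exact(gathered.len() / world_size).collect();

    // An entry survives only if every party marked it valid.
    let mut valid = vec![true; max_batch_size];
    for chunk in &chunks {
        for (j, flag) in valid.iter_mut().enumerate() {
            *flag &= chunk[j] == 1;
        }
    }

    // Any hash disagreement means the parties dequeued different batches.
    for (i, chunk) in chunks.iter().enumerate() {
        if chunk[max_batch_size..] != my_hash[..] {
            return Err(format!("Batch mismatch with node {i}. Queues seem to be out of sync."));
        }
    }
    Ok(valid)
}

fn main() {
    let hash = [7u8; HASH_LEN];
    let p0 = build_payload(&[true, true, false], 4, &hash);
    let p1 = build_payload(&[true, false, true], 4, &hash);
    let gathered = [p0, p1].concat();
    // Entry 1 is invalid on party 1 and entry 2 on party 0, so only entry 0 survives.
    assert_eq!(merge(&gathered, 2, 4, &hash).unwrap(), vec![true, false, false, false]);
}
```

Padding every payload to `max_batch_size` keeps the gather buffer the same size on all parties regardless of how many entries each one currently holds, which is why the flags are zero-padded rather than sent at their actual length.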
diff --git a/iris-mpc-gpu/tests/e2e.rs b/iris-mpc-gpu/tests/e2e.rs
index df377fa08..cc2e3e79a 100644
--- a/iris-mpc-gpu/tests/e2e.rs
+++ b/iris-mpc-gpu/tests/e2e.rs
@@ -218,13 +218,13 @@
         let mut rng = StdRng::seed_from_u64(INTERNAL_RNG_SEED);
 
         let mut expected_results: HashMap<String, (Option<u32>, bool)> = HashMap::new();
-        let mut requests: HashMap<String, IrisCode> = HashMap::new();
         let mut responses: HashMap<u32, IrisCode> = HashMap::new();
         let mut deleted_indices_buffer = vec![];
         let mut deleted_indices: HashSet<u32> = HashSet::new();
         let mut disallowed_queries = Vec::new();
 
         for _ in 0..NUM_BATCHES {
+            let mut requests: HashMap<String, IrisCode> = HashMap::new();
             let mut batch0 = BatchQuery::default();
             let mut batch1 = BatchQuery::default();
             let mut batch2 = BatchQuery::default();
@@ -470,8 +470,13 @@
             let res1 = res1_fut.await;
             let res2 = res2_fut.await;
 
-            // go over results and check if correct
-            for res in [res0, res1, res2].iter() {
+            let mut resp_counters = HashMap::new();
+            for req in requests.keys() {
+                resp_counters.insert(req, 0);
+            }
+
+            let results = [&res0, &res1, &res2];
+            for res in results.iter() {
                 let ServerJobResult {
                     request_ids: thread_request_ids,
                     matches,
@@ -496,13 +501,11 @@
                 {
                     assert!(requests.contains_key(req_id));
 
+                    resp_counters.insert(req_id, resp_counters.get(req_id).unwrap() + 1);
+
                    assert_eq!(partial_left, partial_right);
                     assert_eq!(partial_left, match_id);
 
-                    // This was an invalid query, we should not get a response, but they should be
-                    // silently ignored
-                    assert!(requests.contains_key(req_id));
-
                     let (expected_idx, is_batch_match) = expected_results.get(req_id).unwrap();
 
                     if let Some(expected_idx) = expected_idx {
@@ -521,6 +524,11 @@
                     }
                 }
             }
+
+            // Check that we received a response from all actors
+            for (&id, &count) in resp_counters.iter() {
+                assert_eq!(count, 3, "Received {} responses for {}", count, id);
+            }
         }
 
         drop(handle0);
diff --git a/iris-mpc/src/bin/client.rs b/iris-mpc/src/bin/client.rs
index cc0cf0529..44b1e6c41 100644
--- a/iris-mpc/src/bin/client.rs
+++ b/iris-mpc/src/bin/client.rs
@@ -9,7 +9,7 @@ use iris_mpc_common::{
     galois_engine::degree4::GaloisRingIrisCodeShare,
     helpers::{
         key_pair::download_public_key,
-        sha256::calculate_sha256,
+        sha256::sha256_as_hex_string,
         smpc_request::{IrisCodesJSON, UniquenessRequest, UNIQUENESS_MESSAGE_TYPE},
         smpc_response::{create_message_type_attribute_map, UniquenessResult},
         sqs_s3_helper::upload_file_and_generate_presigned_url,
@@ -339,7 +339,7 @@ async fn main() -> eyre::Result<()> {
             .clone();
 
         // calculate hash of the object
-        let hash_string = calculate_sha256(&serialized_iris_codes_json);
+        let hash_string = sha256_as_hex_string(&serialized_iris_codes_json);
 
         // encrypt the object using sealed box and public key
         let encrypted_bytes = sealedbox::seal(
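To close the loop, a sketch of the invariant the rename preserves end to end: the hex hash the client attaches to the request must equal the one `validate_iris_share` recomputes over the decrypted, re-serialized share. `Share` here is a hypothetical stand-in for `IrisCodesJSON` (assuming `serde`, `serde_json`, `sha2`, and `hex` as dependencies), and the round trip relies on serde keeping struct field order stable:

```rust
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};

// Hypothetical stand-in for IrisCodesJSON; the real struct lives in smpc_request.rs.
#[derive(Serialize, Deserialize)]
struct Share {
    iris_codes: String,
    mask_codes: String,
}

fn sha256_as_hex_string<T: AsRef<[u8]>>(data: T) -> String {
    hex::encode(Sha256::digest(data.as_ref()))
}

fn main() -> serde_json::Result<()> {
    let share = Share {
        iris_codes: "...".into(),
        mask_codes: "...".into(),
    };

    // Client side: hash the serialized JSON, attach the hex hash to the request,
    // then encrypt and upload the payload (the sealed-box step is omitted here).
    let serialized = serde_json::to_string(&share)?;
    let attached_hash = sha256_as_hex_string(&serialized);

    // Server side: decrypt, deserialize, re-serialize, and recompute the hash,
    // mirroring what validate_iris_share checks against iris_shares_file_hashes.
    let decoded: Share = serde_json::from_str(&serialized)?;
    let recomputed = sha256_as_hex_string(serde_json::to_string(&decoded)?.into_bytes());
    assert_eq!(attached_hash, recomputed);
    Ok(())
}
```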