Batch hash (#932)
* sync on batch hash

* log hash

* clippy

* clippy

* fix sync_batch bug and make e2e enforce receiving results
philsippl authored Jan 15, 2025
1 parent 872ff05 commit 0b4ba75
Showing 7 changed files with 75 additions and 49 deletions.
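
In short: before processing a batch, each party now hashes the concatenation of the batch's request ids (sha256_bytes(batch.request_ids.join(""))) and all-gathers that hash alongside the per-entry validity flags; if any node reports a different hash, the party errors out instead of silently processing a diverged queue. The e2e test additionally asserts that every request receives a response from all three actors.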
8 changes: 6 additions & 2 deletions iris-mpc-common/src/helpers/sha256.rs
@@ -1,5 +1,9 @@
 use sha2::{Digest, Sha256};

-pub fn calculate_sha256<T: AsRef<[u8]>>(data: T) -> String {
-    hex::encode(Sha256::digest(data.as_ref()))
+pub fn sha256_as_hex_string<T: AsRef<[u8]>>(data: T) -> String {
+    hex::encode(sha256_bytes(data))
 }
+
+pub fn sha256_bytes<T: AsRef<[u8]>>(data: T) -> [u8; 32] {
+    Sha256::digest(data.as_ref()).into()
+}
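
As a reading aid (not part of the commit): calculate_sha256 is renamed to sha256_as_hex_string and now wraps the new sha256_bytes, which exposes the raw 32-byte digest for the GPU actor below. A minimal usage sketch, assuming the crate path from this diff and the well-known SHA-256 test vector for "abc":

    use iris_mpc_common::helpers::sha256::{sha256_as_hex_string, sha256_bytes};

    fn main() {
        // Raw digest, used where bytes are needed (e.g. the batch-hash sync).
        let raw: [u8; 32] = sha256_bytes("abc");
        // Hex string, used where hashes are compared as strings (e.g. share validation).
        let hex_str: String = sha256_as_hex_string("abc");
        assert_eq!(hex_str, hex::encode(raw));
        assert_eq!(
            hex_str,
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }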
4 changes: 2 additions & 2 deletions iris-mpc-common/src/helpers/smpc_request.rs
@@ -1,4 +1,4 @@
-use super::{key_pair::SharesDecodingError, sha256::calculate_sha256};
+use super::{key_pair::SharesDecodingError, sha256::sha256_as_hex_string};
 use crate::helpers::key_pair::SharesEncryptionKeyPairs;
 use aws_sdk_s3::Client as S3Client;
 use aws_sdk_sns::types::MessageAttributeValue;
@@ -285,6 +285,6 @@ impl UniquenessRequest
             .map_err(SharesDecodingError::SerdeError)?
             .into_bytes();

-        Ok(self.iris_shares_file_hashes[party_id] == calculate_sha256(stringified_share))
+        Ok(self.iris_shares_file_hashes[party_id] == sha256_as_hex_string(stringified_share))
     }
 }
10 changes: 5 additions & 5 deletions iris-mpc-common/tests/sha256.rs
@@ -1,5 +1,5 @@
 mod tests {
-    use iris_mpc_common::helpers::sha256::calculate_sha256;
+    use iris_mpc_common::helpers::sha256::sha256_as_hex_string;

     #[test]
     fn test_calculate_sha256() {
@@ -8,7 +8,7 @@ mod tests {
let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";

// Act
let calculated_hash = calculate_sha256(data);
let calculated_hash = sha256_as_hex_string(data);

// Assert
assert_eq!(
@@ -24,7 +24,7 @@ mod tests {
let expected_hash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

// Act
let calculated_hash = calculate_sha256(data);
let calculated_hash = sha256_as_hex_string(data);

// Assert
assert_eq!(
@@ -38,8 +38,8 @@ mod tests {
         // Arrange
         let data_1 = "Data 1";
         let data_2 = "Data 2";
-        let hash_1 = calculate_sha256(data_1);
-        let hash_2 = calculate_sha256(data_2);
+        let hash_1 = sha256_as_hex_string(data_1);
+        let hash_2 = sha256_as_hex_string(data_2);

         // Act & Assert
         assert_ne!(
4 changes: 2 additions & 2 deletions iris-mpc-common/tests/smpc_request.rs
@@ -4,7 +4,7 @@ mod tests {
     use base64::{engine::general_purpose::STANDARD, Engine};
     use iris_mpc_common::helpers::{
         key_pair::{SharesDecodingError, SharesEncryptionKeyPairs},
-        sha256::calculate_sha256,
+        sha256::sha256_as_hex_string,
         smpc_request::{IrisCodesJSON, UniquenessRequest},
     };
     use serde_json::json;
@@ -271,7 +271,7 @@
     async fn test_validate_iris_share() {
         let mock_iris_codes_json = mock_iris_codes_json();
         let mock_serialized_iris = serde_json::to_string(&mock_iris_codes_json).unwrap();
-        let mock_hash = calculate_sha256(mock_serialized_iris.into_bytes());
+        let mock_hash = sha256_as_hex_string(mock_serialized_iris.into_bytes());

         let smpc_request = get_mock_smpc_request_with_hashes([
             mock_hash.clone(),
72 changes: 43 additions & 29 deletions iris-mpc-gpu/src/server/actor.rs
@@ -27,6 +27,7 @@ use eyre::eyre;
 use futures::{Future, FutureExt};
 use iris_mpc_common::{
     galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare},
+    helpers::sha256::sha256_bytes,
     iris_db::iris::IrisCode,
     IrisCodeDbSlice,
 };
@@ -652,7 +653,13 @@ impl ServerActor {
         ///////////////////////////////////////////////////////////////////
         let tmp_now = Instant::now();
         tracing::info!("Syncing batch entries");
-        let valid_entries = self.sync_batch_entries(&batch.valid_entries)?;
+
+        // Compute hash of the request ids concatenated
+        let batch_hash = sha256_bytes(batch.request_ids.join(""));
+        tracing::info!("Current batch hash: {}", hex::encode(&batch_hash[0..4]));
+
+        let valid_entries =
+            self.sync_batch_entries(&batch.valid_entries, self.max_batch_size, &batch_hash)?;
         let valid_entry_idxs = valid_entries.iter().positions(|&x| x).collect::<Vec<_>>();
         batch_size = valid_entry_idxs.len();
         batch.retain(&valid_entry_idxs);
@@ -1458,53 +1465,60 @@ impl ServerActor {
         }
     }

-    fn sync_batch_entries(&mut self, valid_entries: &[bool]) -> eyre::Result<Vec<bool>> {
-        tracing::info!(
-            party_id = self.party_id,
-            "valid_entries {:?} ({})",
-            valid_entries,
-            valid_entries.len()
-        );
-        tracing::info!(party_id = self.party_id, "sync_batch_entries start");
+    fn sync_batch_entries(
+        &mut self,
+        valid_entries: &[bool],
+        max_batch_size: usize,
+        batch_hash: &[u8],
+    ) -> eyre::Result<Vec<bool>> {
+        assert!(valid_entries.len() <= max_batch_size);
+        let hash_len = batch_hash.len();
         let mut buffer = self
             .device_manager
             .device(0)
-            .alloc_zeros(valid_entries.len() * self.comms[0].world_size())
+            .alloc_zeros((max_batch_size + hash_len) * self.comms[0].world_size())
             .unwrap();

-        tracing::info!(party_id = self.party_id, "htod_copy start");
+        let mut host_buffer = vec![0u8; max_batch_size + hash_len];
+        host_buffer[..valid_entries.len()]
+            .copy_from_slice(&valid_entries.iter().map(|&x| x as u8).collect::<Vec<u8>>());
+        host_buffer[max_batch_size..].copy_from_slice(batch_hash);

-        let buffer_self = self
-            .device_manager
-            .device(0)
-            .htod_copy(valid_entries.iter().map(|&x| x as u8).collect::<Vec<_>>())?;
+        let buffer_self = self.device_manager.device(0).htod_copy(host_buffer)?;

         // Use all_gather to sync the buffer across all nodes (only using device 0)
         self.device_manager.device(0).synchronize()?;

-        tracing::info!(party_id = self.party_id, "all_gather start");
-
         self.comms[0]
             .all_gather(&buffer_self, &mut buffer)
             .map_err(|e| eyre!(format!("{:?}", e)))?;

         self.device_manager.device(0).synchronize()?;

-        tracing::info!(party_id = self.party_id, "dtoh_sync_copy start");
-
         let results = self.device_manager.device(0).dtoh_sync_copy(&buffer)?;
         let results: Vec<_> = results
             .chunks_exact(results.len() / self.comms[0].world_size())
             .collect();

-        tracing::info!(party_id = self.party_id, "sync_batch_entries end");
-
-        let mut valid_merged = vec![];
-        for i in 0..results[0].len() {
-            valid_merged.push(
-                [results[0][i], results[1][i], results[2][i]]
-                    .iter()
-                    .all(|&x| x == 1),
-            );
-        }
+        // Only keep entries that are valid on all nodes
+        let mut valid_merged = vec![true; max_batch_size];
+        for i in 0..self.comms[0].world_size() {
+            for j in 0..max_batch_size {
+                valid_merged[j] &= results[i][j] == 1;
+            }
+        }
+
+        // Check that the hash is the same on all nodes
+        for i in 0..self.comms[0].world_size() {
+            if &results[i][max_batch_size..] != batch_hash {
+                tracing::error!(
+                    party_id = self.party_id,
+                    "Batch mismatch with node {}. Queues seem to be out of sync.",
+                    i
+                );
+                return Err(eyre!(
+                    "Batch mismatch with node {}. Queues seem to be out of sync.",
+                    i
+                ));
+            }
+        }

         Ok(valid_merged)
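For orientation, the gathered buffer now holds, per node, max_batch_size validity bytes followed by the batch hash. A CPU-only sketch of the merge-and-verify step above, with the device copies and NCCL all_gather stripped out (the function name, error type, and test values are illustrative, not the crate's API):

    // Sketch: merge validity flags and verify batch hashes across nodes.
    // `results[i]` is node i's gathered buffer: max_batch_size flag bytes
    // followed by the batch hash bytes.
    fn merge_and_verify(
        results: &[Vec<u8>],
        max_batch_size: usize,
        batch_hash: &[u8],
    ) -> Result<Vec<bool>, String> {
        // An entry stays valid only if every node marked it valid (flag == 1).
        let mut valid_merged = vec![true; max_batch_size];
        for node in results {
            for j in 0..max_batch_size {
                valid_merged[j] &= node[j] == 1;
            }
        }
        // Every node must have hashed the same concatenation of request ids;
        // a mismatch means the input queues are out of sync, which is fatal.
        for (i, node) in results.iter().enumerate() {
            if &node[max_batch_size..] != batch_hash {
                return Err(format!("Batch mismatch with node {i}"));
            }
        }
        Ok(valid_merged)
    }

    fn main() {
        let hash = [0xab; 4];
        let mut node0 = vec![1, 1, 0, 0]; // flags, max_batch_size = 4
        node0.extend_from_slice(&hash);
        let mut node1 = vec![1, 0, 0, 0];
        node1.extend_from_slice(&hash);
        let merged = merge_and_verify(&[node0, node1], 4, &hash).unwrap();
        assert_eq!(merged, vec![true, false, false, false]);
    }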
22 changes: 15 additions & 7 deletions iris-mpc-gpu/tests/e2e.rs
@@ -218,13 +218,13 @@ mod e2e_test {
         let mut rng = StdRng::seed_from_u64(INTERNAL_RNG_SEED);

         let mut expected_results: HashMap<String, (Option<u32>, bool)> = HashMap::new();
-        let mut requests: HashMap<String, IrisCode> = HashMap::new();
         let mut responses: HashMap<u32, IrisCode> = HashMap::new();
         let mut deleted_indices_buffer = vec![];
         let mut deleted_indices: HashSet<u32> = HashSet::new();
         let mut disallowed_queries = Vec::new();

         for _ in 0..NUM_BATCHES {
+            let mut requests: HashMap<String, IrisCode> = HashMap::new();
             let mut batch0 = BatchQuery::default();
             let mut batch1 = BatchQuery::default();
             let mut batch2 = BatchQuery::default();
@@ -470,8 +470,13 @@
             let res1 = res1_fut.await;
             let res2 = res2_fut.await;

-            // go over results and check if correct
-            for res in [res0, res1, res2].iter() {
+            let mut resp_counters = HashMap::new();
+            for req in requests.keys() {
+                resp_counters.insert(req, 0);
+            }
+
+            let results = [&res0, &res1, &res2];
+            for res in results.iter() {
                 let ServerJobResult {
                     request_ids: thread_request_ids,
                     matches,
@@ -496,13 +501,11 @@
                 {
                     assert!(requests.contains_key(req_id));

+                    resp_counters.insert(req_id, resp_counters.get(req_id).unwrap() + 1);
+
                     assert_eq!(partial_left, partial_right);
                     assert_eq!(partial_left, match_id);

-                    // This was an invalid query, we should not get a response, but they should be
-                    // silently ignored
-                    assert!(requests.contains_key(req_id));
-
                     let (expected_idx, is_batch_match) = expected_results.get(req_id).unwrap();

                     if let Some(expected_idx) = expected_idx {
@@ -521,6 +524,11 @@
                     }
                 }
             }
+
+            // Check that we received a response from all actors
+            for (&id, &count) in resp_counters.iter() {
+                assert_eq!(count, 3, "Received {} responses for {}", count, id);
+            }
         }

         drop(handle0);
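The test previously iterated only over responses that arrived, so a silently dropped result went unnoticed; counting responses per request id and requiring exactly one from each of the three parties closes that gap. A reduced sketch of the pattern, with hypothetical request ids:

    use std::collections::HashMap;

    fn main() {
        let request_ids = ["req-a", "req-b"]; // hypothetical ids
        // One response stream per party; each id must show up in all three.
        let responses_per_party: [Vec<&str>; 3] = [
            vec!["req-a", "req-b"],
            vec!["req-a", "req-b"],
            vec!["req-a", "req-b"],
        ];

        let mut resp_counters: HashMap<&str, usize> =
            request_ids.iter().map(|&id| (id, 0)).collect();
        for party in &responses_per_party {
            for id in party {
                *resp_counters.get_mut(id).expect("unknown request id") += 1;
            }
        }
        // Enforce that every request got a response from all three actors.
        for (id, count) in &resp_counters {
            assert_eq!(*count, 3, "Received {} responses for {}", count, id);
        }
    }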
4 changes: 2 additions & 2 deletions iris-mpc/src/bin/client.rs
@@ -9,7 +9,7 @@ use iris_mpc_common::{
     galois_engine::degree4::GaloisRingIrisCodeShare,
     helpers::{
         key_pair::download_public_key,
-        sha256::calculate_sha256,
+        sha256::sha256_as_hex_string,
         smpc_request::{IrisCodesJSON, UniquenessRequest, UNIQUENESS_MESSAGE_TYPE},
         smpc_response::{create_message_type_attribute_map, UniquenessResult},
         sqs_s3_helper::upload_file_and_generate_presigned_url,
@@ -339,7 +339,7 @@ async fn main() -> eyre::Result<()> {
         .clone();

     // calculate hash of the object
-    let hash_string = calculate_sha256(&serialized_iris_codes_json);
+    let hash_string = sha256_as_hex_string(&serialized_iris_codes_json);

     // encrypt the object using sealed box and public key
     let encrypted_bytes = sealedbox::seal(
