sync on batch hash
philsippl committed Jan 14, 2025
1 parent 90819bf commit 7092d44
Showing 6 changed files with 60 additions and 42 deletions.
iris-mpc-common/src/helpers/sha256.rs (6 additions & 2 deletions)

@@ -1,5 +1,9 @@
 use sha2::{Digest, Sha256};
 
-pub fn calculate_sha256<T: AsRef<[u8]>>(data: T) -> String {
-    hex::encode(Sha256::digest(data.as_ref()))
+pub fn sha256_as_hex_string<T: AsRef<[u8]>>(data: T) -> String {
+    hex::encode(sha256_bytes(data))
 }
+
+pub fn sha256_bytes<T: AsRef<[u8]>>(data: T) -> [u8; 32] {
+    Sha256::digest(data.as_ref()).into()
+}
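The split into a byte-level helper and a hex wrapper lets callers pick the representation they need: a raw `[u8; 32]` for fixed-size buffers (used below for the GPU sync buffer) and a hex `String` for the share-file hashes. A minimal usage sketch, assuming the `hex` crate is a dependency as in the helper itself:

use iris_mpc_common::helpers::sha256::{sha256_as_hex_string, sha256_bytes};

fn main() {
    // Raw 32-byte digest, suitable for fixed-size network buffers.
    let digest: [u8; 32] = sha256_bytes("hello world");

    // Hex-encoded digest, suitable for JSON fields and log lines.
    let hex_digest: String = sha256_as_hex_string("hello world");

    // The hex form is just the encoding of the raw form.
    assert_eq!(hex_digest, hex::encode(digest));
    assert_eq!(hex_digest.len(), 64); // 32 bytes -> 64 hex chars
}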
iris-mpc-common/src/helpers/smpc_request.rs (2 additions & 2 deletions)

@@ -1,4 +1,4 @@
-use super::{key_pair::SharesDecodingError, sha256::calculate_sha256};
+use super::{key_pair::SharesDecodingError, sha256::sha256_as_hex_string};
 use crate::helpers::key_pair::SharesEncryptionKeyPairs;
 use aws_sdk_s3::Client as S3Client;
 use aws_sdk_sns::types::MessageAttributeValue;
@@ -285,6 +285,6 @@ impl UniquenessRequest {
            .map_err(SharesDecodingError::SerdeError)?
            .into_bytes();
 
-        Ok(self.iris_shares_file_hashes[party_id] == calculate_sha256(stringified_share))
+        Ok(self.iris_shares_file_hashes[party_id] == sha256_as_hex_string(stringified_share))
     }
 }
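This comparison is the whole integrity check: the party re-serializes the share it decrypted and compares the hex digest with the hash the client supplied for its party id. A standalone sketch of that predicate (`share_matches_hash`, `stored_hash`, and `share_json` are illustrative names, not from the crate):

use iris_mpc_common::helpers::sha256::sha256_as_hex_string;

// Illustrative stand-in for the comparison inside validate_iris_share.
fn share_matches_hash(stored_hash: &str, share_json: &str) -> bool {
    stored_hash == sha256_as_hex_string(share_json)
}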
iris-mpc-common/tests/sha256.rs (5 additions & 5 deletions)

@@ -1,5 +1,5 @@
 mod tests {
-    use iris_mpc_common::helpers::sha256::calculate_sha256;
+    use iris_mpc_common::helpers::sha256::sha256_as_hex_string;
 
     #[test]
     fn test_calculate_sha256() {
@@ -8,7 +8,7 @@ mod tests {
         let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";
 
         // Act
-        let calculated_hash = calculate_sha256(data);
+        let calculated_hash = sha256_as_hex_string(data);
 
         // Assert
         assert_eq!(
@@ -24,7 +24,7 @@ mod tests {
         let expected_hash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
 
         // Act
-        let calculated_hash = calculate_sha256(data);
+        let calculated_hash = sha256_as_hex_string(data);
 
         // Assert
         assert_eq!(
@@ -38,8 +38,8 @@
         // Arrange
         let data_1 = "Data 1";
         let data_2 = "Data 2";
-        let hash_1 = calculate_sha256(data_1);
-        let hash_2 = calculate_sha256(data_2);
+        let hash_1 = sha256_as_hex_string(data_1);
+        let hash_2 = sha256_as_hex_string(data_2);
 
         // Act & Assert
         assert_ne!(
iris-mpc-common/tests/smpc_request.rs (2 additions & 2 deletions)

@@ -4,7 +4,7 @@ mod tests {
     use base64::{engine::general_purpose::STANDARD, Engine};
     use iris_mpc_common::helpers::{
         key_pair::{SharesDecodingError, SharesEncryptionKeyPairs},
-        sha256::calculate_sha256,
+        sha256::sha256_as_hex_string,
         smpc_request::{IrisCodesJSON, UniquenessRequest},
     };
     use serde_json::json;
@@ -271,7 +271,7 @@ mod tests {
     async fn test_validate_iris_share() {
         let mock_iris_codes_json = mock_iris_codes_json();
         let mock_serialized_iris = serde_json::to_string(&mock_iris_codes_json).unwrap();
-        let mock_hash = calculate_sha256(mock_serialized_iris.into_bytes());
+        let mock_hash = sha256_as_hex_string(mock_serialized_iris.into_bytes());
 
         let smpc_request = get_mock_smpc_request_with_hashes([
             mock_hash.clone(),
iris-mpc-gpu/src/server/actor.rs (43 additions & 29 deletions)

@@ -27,6 +27,7 @@ use eyre::eyre;
 use futures::{Future, FutureExt};
 use iris_mpc_common::{
     galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare},
+    helpers::sha256::sha256_bytes,
     iris_db::iris::IrisCode,
     IrisCodeDbSlice,
 };
@@ -652,7 +653,13 @@ impl ServerActor {
         ///////////////////////////////////////////////////////////////////
         let tmp_now = Instant::now();
         tracing::info!("Syncing batch entries");
-        let valid_entries = self.sync_batch_entries(&batch.valid_entries)?;
+
+        // Compute hash of the request ids concatenated
+        let batch_hash = sha256_bytes(&batch.request_ids.join(""));
+        tracing::info!("Current batch hash: 0x{}", hex::encode(batch_hash));
+
+        let valid_entries =
+            self.sync_batch_entries(&batch.valid_entries, self.max_batch_size, &batch_hash)?;
         let valid_entry_idxs = valid_entries.iter().positions(|&x| x).collect::<Vec<_>>();
         batch_size = valid_entry_idxs.len();
         batch.retain(&valid_entry_idxs);
@@ -1458,53 +1465,60 @@ impl ServerActor {
         }
     }
 
-    fn sync_batch_entries(&mut self, valid_entries: &[bool]) -> eyre::Result<Vec<bool>> {
-        tracing::info!(
-            party_id = self.party_id,
-            "valid_entries {:?} ({})",
-            valid_entries,
-            valid_entries.len()
-        );
-        tracing::info!(party_id = self.party_id, "sync_batch_entries start");
+    fn sync_batch_entries(
+        &mut self,
+        valid_entries: &[bool],
+        max_batch_size: usize,
+        batch_hash: &[u8],
+    ) -> eyre::Result<Vec<bool>> {
+        assert!(valid_entries.len() <= max_batch_size);
+        let hash_len = batch_hash.len();
         let mut buffer = self
             .device_manager
             .device(0)
-            .alloc_zeros(valid_entries.len() * self.comms[0].world_size())
+            .alloc_zeros((max_batch_size + hash_len) * self.comms[0].world_size())
             .unwrap();
 
-        tracing::info!(party_id = self.party_id, "htod_copy start");
+        let mut host_buffer = vec![0u8; max_batch_size + hash_len];
+        host_buffer[..valid_entries.len()]
+            .copy_from_slice(&valid_entries.iter().map(|&x| x as u8).collect::<Vec<u8>>());
+        host_buffer[max_batch_size..].copy_from_slice(batch_hash);
 
-        let buffer_self = self
-            .device_manager
-            .device(0)
-            .htod_copy(valid_entries.iter().map(|&x| x as u8).collect::<Vec<_>>())?;
+        let buffer_self = self.device_manager.device(0).htod_copy(host_buffer)?;
 
         // Use all_gather to sync the buffer across all nodes (only using device 0)
         self.device_manager.device(0).synchronize()?;
 
-        tracing::info!(party_id = self.party_id, "all_gather start");
-
         self.comms[0]
             .all_gather(&buffer_self, &mut buffer)
             .map_err(|e| eyre!(format!("{:?}", e)))?;
 
         self.device_manager.device(0).synchronize()?;
 
-        tracing::info!(party_id = self.party_id, "dtoh_sync_copy start");
-
         let results = self.device_manager.device(0).dtoh_sync_copy(&buffer)?;
         let results: Vec<_> = results
             .chunks_exact(results.len() / self.comms[0].world_size())
             .collect();
 
-        tracing::info!(party_id = self.party_id, "sync_batch_entries end");
+        // Only keep entries that are valid on all nodes; start from true so
+        // a node's vote can only clear a flag, never set it.
+        let mut valid_merged = vec![true; max_batch_size];
+        for i in 0..self.comms[0].world_size() {
+            for j in 0..max_batch_size {
+                valid_merged[j] &= results[i][j] == 1;
+            }
+        }
 
-        let mut valid_merged = vec![];
-        for i in 0..results[0].len() {
-            valid_merged.push(
-                [results[0][i], results[1][i], results[2][i]]
-                    .iter()
-                    .all(|&x| x == 1),
-            );
-        }
+        // Check that the hash is the same on all nodes
+        for i in 0..self.comms[0].world_size() {
+            if &results[i][max_batch_size..] != batch_hash {
+                tracing::error!(
+                    party_id = self.party_id,
+                    "Batch mismatch with node {}. Queues seem to be out of sync.",
+                    i
+                );
+                return Err(eyre!(
+                    "Batch mismatch with node {}. Queues seem to be out of sync.",
+                    i
+                ));
+            }
+        }
 
         Ok(valid_merged)
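Each party thus contributes one region of `max_batch_size + 32` bytes to the gathered buffer: its validity flags, zero-padded up to the maximum batch size, followed by its batch hash. A host-only sketch of the merge-and-verify step, with the device transfers and `all_gather` replaced by a plain byte slice (`merge_and_verify` and its parameters are illustrative, not part of the crate; `world_size` mirrors `self.comms[0].world_size()` above):

fn merge_and_verify(
    gathered: &[u8], // world_size regions of (max_batch_size + 32) bytes
    world_size: usize,
    max_batch_size: usize,
    my_hash: &[u8; 32],
) -> Result<Vec<bool>, String> {
    let region = max_batch_size + my_hash.len();
    let results: Vec<&[u8]> = gathered.chunks_exact(region).collect();
    assert_eq!(results.len(), world_size);

    // An entry survives only if every party voted 1 for it; the zero
    // padding past the real batch length is cleared automatically.
    let mut valid_merged = vec![true; max_batch_size];
    for r in &results {
        for (v, &b) in valid_merged.iter_mut().zip(&r[..max_batch_size]) {
            *v &= b == 1;
        }
    }

    // Every party must have hashed the same request-id sequence.
    for (i, r) in results.iter().enumerate() {
        if &r[max_batch_size..] != my_hash.as_slice() {
            return Err(format!("Batch mismatch with node {i}"));
        }
    }
    Ok(valid_merged)
}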
iris-mpc/src/bin/client.rs (2 additions & 2 deletions)

@@ -9,7 +9,7 @@ use iris_mpc_common::{
     galois_engine::degree4::GaloisRingIrisCodeShare,
     helpers::{
         key_pair::download_public_key,
-        sha256::calculate_sha256,
+        sha256::sha256_as_hex_string,
         smpc_request::{IrisCodesJSON, UniquenessRequest, UNIQUENESS_MESSAGE_TYPE},
         smpc_response::{create_message_type_attribute_map, UniquenessResult},
         sqs_s3_helper::upload_file_and_generate_presigned_url,
@@ -339,7 +339,7 @@
             .clone();
 
         // calculate hash of the object
-        let hash_string = calculate_sha256(&serialized_iris_codes_json);
+        let hash_string = sha256_as_hex_string(&serialized_iris_codes_json);
 
         // encrypt the object using sealed box and public key
         let encrypted_bytes = sealedbox::seal(
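Note the ordering on the client: the plaintext JSON is hashed first and only then sealed, so each server can recompute the hash after decrypting its own share (the UniquenessRequest check above). A minimal sketch of that hash-then-encrypt flow (the JSON literal is a made-up stand-in for the serialized IrisCodesJSON; the sealing step is indicated only in a comment):

use iris_mpc_common::helpers::sha256::sha256_as_hex_string;

fn main() {
    // Hypothetical stand-in for the serialized IrisCodesJSON payload.
    let serialized_iris_codes_json =
        r#"{"iris_codes":"...","mask_codes":"..."}"#.to_string();

    // Hash the plaintext; this hex string is what ends up in the
    // request's iris_shares_file_hashes and is re-checked server-side.
    let hash_string = sha256_as_hex_string(&serialized_iris_codes_json);
    println!("share hash: {hash_string}");

    // The same plaintext bytes are then encrypted (sealedbox::seal in the
    // code above) and uploaded; the hash travels separately in the request.
}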
