Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch hash #932

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions iris-mpc-common/src/helpers/sha256.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use sha2::{Digest, Sha256};

pub fn calculate_sha256<T: AsRef<[u8]>>(data: T) -> String {
hex::encode(Sha256::digest(data.as_ref()))
pub fn sha256_as_hex_string<T: AsRef<[u8]>>(data: T) -> String {
hex::encode(sha256_bytes(data))
}

pub fn sha256_bytes<T: AsRef<[u8]>>(data: T) -> [u8; 32] {
Sha256::digest(data.as_ref()).into()
}
4 changes: 2 additions & 2 deletions iris-mpc-common/src/helpers/smpc_request.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{key_pair::SharesDecodingError, sha256::calculate_sha256};
use super::{key_pair::SharesDecodingError, sha256::sha256_as_hex_string};
use crate::helpers::key_pair::SharesEncryptionKeyPairs;
use aws_sdk_s3::Client as S3Client;
use aws_sdk_sns::types::MessageAttributeValue;
Expand Down Expand Up @@ -285,6 +285,6 @@ impl UniquenessRequest {
.map_err(SharesDecodingError::SerdeError)?
.into_bytes();

Ok(self.iris_shares_file_hashes[party_id] == calculate_sha256(stringified_share))
Ok(self.iris_shares_file_hashes[party_id] == sha256_as_hex_string(stringified_share))
}
}
10 changes: 5 additions & 5 deletions iris-mpc-common/tests/sha256.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod tests {
use iris_mpc_common::helpers::sha256::calculate_sha256;
use iris_mpc_common::helpers::sha256::sha256_as_hex_string;

#[test]
fn test_calculate_sha256() {
Expand All @@ -8,7 +8,7 @@ mod tests {
let expected_hash = "315f5bdb76d078c43b8ac0064e4a0164612b1fce77c869345bfc94c75894edd3";

// Act
let calculated_hash = calculate_sha256(data);
let calculated_hash = sha256_as_hex_string(data);

// Assert
assert_eq!(
Expand All @@ -24,7 +24,7 @@ mod tests {
let expected_hash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";

// Act
let calculated_hash = calculate_sha256(data);
let calculated_hash = sha256_as_hex_string(data);

// Assert
assert_eq!(
Expand All @@ -38,8 +38,8 @@ mod tests {
// Arrange
let data_1 = "Data 1";
let data_2 = "Data 2";
let hash_1 = calculate_sha256(data_1);
let hash_2 = calculate_sha256(data_2);
let hash_1 = sha256_as_hex_string(data_1);
let hash_2 = sha256_as_hex_string(data_2);

// Act & Assert
assert_ne!(
Expand Down
4 changes: 2 additions & 2 deletions iris-mpc-common/tests/smpc_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod tests {
use base64::{engine::general_purpose::STANDARD, Engine};
use iris_mpc_common::helpers::{
key_pair::{SharesDecodingError, SharesEncryptionKeyPairs},
sha256::calculate_sha256,
sha256::sha256_as_hex_string,
smpc_request::{IrisCodesJSON, UniquenessRequest},
};
use serde_json::json;
Expand Down Expand Up @@ -271,7 +271,7 @@ mod tests {
async fn test_validate_iris_share() {
let mock_iris_codes_json = mock_iris_codes_json();
let mock_serialized_iris = serde_json::to_string(&mock_iris_codes_json).unwrap();
let mock_hash = calculate_sha256(mock_serialized_iris.into_bytes());
let mock_hash = sha256_as_hex_string(mock_serialized_iris.into_bytes());

let smpc_request = get_mock_smpc_request_with_hashes([
mock_hash.clone(),
Expand Down
72 changes: 43 additions & 29 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use eyre::eyre;
use futures::{Future, FutureExt};
use iris_mpc_common::{
galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare},
helpers::sha256::sha256_bytes,
iris_db::iris::IrisCode,
IrisCodeDbSlice,
};
Expand Down Expand Up @@ -652,7 +653,13 @@ impl ServerActor {
///////////////////////////////////////////////////////////////////
let tmp_now = Instant::now();
tracing::info!("Syncing batch entries");
let valid_entries = self.sync_batch_entries(&batch.valid_entries)?;

// Compute hash of the request ids concatenated
let batch_hash = sha256_bytes(batch.request_ids.join(""));
tracing::info!("Current batch hash: {}", hex::encode(&batch_hash[0..4]));

let valid_entries =
self.sync_batch_entries(&batch.valid_entries, self.max_batch_size, &batch_hash)?;
let valid_entry_idxs = valid_entries.iter().positions(|&x| x).collect::<Vec<_>>();
batch_size = valid_entry_idxs.len();
batch.retain(&valid_entry_idxs);
Expand Down Expand Up @@ -1458,53 +1465,60 @@ impl ServerActor {
}
}

fn sync_batch_entries(&mut self, valid_entries: &[bool]) -> eyre::Result<Vec<bool>> {
tracing::info!(
party_id = self.party_id,
"valid_entries {:?} ({})",
valid_entries,
valid_entries.len()
);
tracing::info!(party_id = self.party_id, "sync_batch_entries start");
fn sync_batch_entries(
&mut self,
valid_entries: &[bool],
max_batch_size: usize,
batch_hash: &[u8],
) -> eyre::Result<Vec<bool>> {
assert!(valid_entries.len() <= max_batch_size);
let hash_len = batch_hash.len();
let mut buffer = self
.device_manager
.device(0)
.alloc_zeros(valid_entries.len() * self.comms[0].world_size())
.alloc_zeros((max_batch_size + hash_len) * self.comms[0].world_size())
.unwrap();

tracing::info!(party_id = self.party_id, "htod_copy start");
let mut host_buffer = vec![0u8; max_batch_size + hash_len];
host_buffer[..valid_entries.len()]
.copy_from_slice(&valid_entries.iter().map(|&x| x as u8).collect::<Vec<u8>>());
host_buffer[max_batch_size..].copy_from_slice(batch_hash);

let buffer_self = self
.device_manager
.device(0)
.htod_copy(valid_entries.iter().map(|&x| x as u8).collect::<Vec<_>>())?;
let buffer_self = self.device_manager.device(0).htod_copy(host_buffer)?;

// Use all_gather to sync the buffer across all nodes (only using device 0)
self.device_manager.device(0).synchronize()?;

tracing::info!(party_id = self.party_id, "all_gather start");

self.comms[0]
.all_gather(&buffer_self, &mut buffer)
.map_err(|e| eyre!(format!("{:?}", e)))?;

self.device_manager.device(0).synchronize()?;

tracing::info!(party_id = self.party_id, "dtoh_sync_copy start");

let results = self.device_manager.device(0).dtoh_sync_copy(&buffer)?;
let results: Vec<_> = results
.chunks_exact(results.len() / self.comms[0].world_size())
.collect();

tracing::info!(party_id = self.party_id, "sync_batch_entries end");
// Only keep entries that are valid on all nodes
let mut valid_merged = vec![true; max_batch_size];
for i in 0..self.comms[0].world_size() {
for j in 0..max_batch_size {
valid_merged[j] &= results[i][j] == 1;
}
}

let mut valid_merged = vec![];
for i in 0..results[0].len() {
valid_merged.push(
[results[0][i], results[1][i], results[2][i]]
.iter()
.all(|&x| x == 1),
);
// Check that the hash is the same on nodes
for i in 0..self.comms[0].world_size() {
if &results[i][max_batch_size..] != batch_hash {
tracing::error!(
party_id = self.party_id,
"Batch mismatch with node {}. Queues seem to be out of sync.",
i
);
return Err(eyre!(
"Batch mismatch with node {}. Queues seem to be out of sync.",
i
));
}
}

Ok(valid_merged)
Expand Down
22 changes: 15 additions & 7 deletions iris-mpc-gpu/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,13 +218,13 @@ mod e2e_test {
let mut rng = StdRng::seed_from_u64(INTERNAL_RNG_SEED);

let mut expected_results: HashMap<String, (Option<u32>, bool)> = HashMap::new();
let mut requests: HashMap<String, IrisCode> = HashMap::new();
let mut responses: HashMap<u32, IrisCode> = HashMap::new();
let mut deleted_indices_buffer = vec![];
let mut deleted_indices: HashSet<u32> = HashSet::new();
let mut disallowed_queries = Vec::new();

for _ in 0..NUM_BATCHES {
let mut requests: HashMap<String, IrisCode> = HashMap::new();
let mut batch0 = BatchQuery::default();
let mut batch1 = BatchQuery::default();
let mut batch2 = BatchQuery::default();
Expand Down Expand Up @@ -470,8 +470,13 @@ mod e2e_test {
let res1 = res1_fut.await;
let res2 = res2_fut.await;

// go over results and check if correct
for res in [res0, res1, res2].iter() {
let mut resp_counters = HashMap::new();
for req in requests.keys() {
resp_counters.insert(req, 0);
}

let results = [&res0, &res1, &res2];
for res in results.iter() {
let ServerJobResult {
request_ids: thread_request_ids,
matches,
Expand All @@ -496,13 +501,11 @@ mod e2e_test {
{
assert!(requests.contains_key(req_id));

resp_counters.insert(req_id, resp_counters.get(req_id).unwrap() + 1);

assert_eq!(partial_left, partial_right);
assert_eq!(partial_left, match_id);

// This was an invalid query, we should not get a response, but they should be
// silently ignored
assert!(requests.contains_key(req_id));

let (expected_idx, is_batch_match) = expected_results.get(req_id).unwrap();

if let Some(expected_idx) = expected_idx {
Expand All @@ -521,6 +524,11 @@ mod e2e_test {
}
}
}

// Check that we received a response from all actors
for (&id, &count) in resp_counters.iter() {
assert_eq!(count, 3, "Received {} responses for {}", count, id);
}
}

drop(handle0);
Expand Down
4 changes: 2 additions & 2 deletions iris-mpc/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use iris_mpc_common::{
galois_engine::degree4::GaloisRingIrisCodeShare,
helpers::{
key_pair::download_public_key,
sha256::calculate_sha256,
sha256::sha256_as_hex_string,
smpc_request::{IrisCodesJSON, UniquenessRequest, UNIQUENESS_MESSAGE_TYPE},
smpc_response::{create_message_type_attribute_map, UniquenessResult},
sqs_s3_helper::upload_file_and_generate_presigned_url,
Expand Down Expand Up @@ -339,7 +339,7 @@ async fn main() -> eyre::Result<()> {
.clone();

// calculate hash of the object
let hash_string = calculate_sha256(&serialized_iris_codes_json);
let hash_string = sha256_as_hex_string(&serialized_iris_codes_json);

// encrypt the object using sealed box and public key
let encrypted_bytes = sealedbox::seal(
Expand Down
Loading