diff --git a/iris-mpc-common/src/helpers/smpc_request.rs b/iris-mpc-common/src/helpers/smpc_request.rs index 5dc7677e8..e931ece1a 100644 --- a/iris-mpc-common/src/helpers/smpc_request.rs +++ b/iris-mpc-common/src/helpers/smpc_request.rs @@ -1,5 +1,6 @@ use super::{key_pair::SharesDecodingError, sha256::calculate_sha256}; use crate::helpers::key_pair::SharesEncryptionKeyPairs; +use aws_sdk_sns::types::MessageAttributeValue; use aws_sdk_sqs::{ error::SdkError, operation::{delete_message::DeleteMessageError, receive_message::ReceiveMessageError}, @@ -8,7 +9,7 @@ use base64::{engine::general_purpose::STANDARD, Engine}; use eyre::Report; use reqwest::Client; use serde::{Deserialize, Serialize}; -use std::sync::LazyLock; +use std::{collections::HashMap, sync::LazyLock}; use thiserror::Error; use tokio_retry::{ strategy::{jitter, FixedInterval}, @@ -33,9 +34,9 @@ pub struct SQSMessage { pub unsubscribe_url: String, } -pub const SMPC_REQUEST_TYPE_ATTRIBUTE: &str = "message_type"; -pub const IDENTITY_DELETION_REQUEST_TYPE: &str = "identity_deletion"; -pub const UNIQUENESS_REQUEST_TYPE: &str = "uniqueness"; +pub const SMPC_MESSAGE_TYPE_ATTRIBUTE: &str = "message_type"; +pub const IDENTITY_DELETION_MESSAGE_TYPE: &str = "identity_deletion"; +pub const UNIQUENESS_MESSAGE_TYPE: &str = "uniqueness"; #[derive(Serialize, Deserialize, Debug, Clone)] pub struct UniquenessRequest { @@ -221,7 +222,7 @@ impl UniquenessRequest { } #[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ResultEvent { +pub struct UniquenessResult { pub node_id: usize, pub serial_id: Option, pub is_match: bool, @@ -229,7 +230,7 @@ pub struct ResultEvent { pub matched_serial_ids: Option>, } -impl ResultEvent { +impl UniquenessResult { pub fn new( node_id: usize, serial_id: Option, @@ -246,3 +247,33 @@ impl ResultEvent { } } } + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct IdentityDeletionResult { + pub node_id: usize, + pub serial_id: u32, + pub success: bool, +} + +impl IdentityDeletionResult { + pub fn new(node_id: usize, serial_id: u32, success: bool) -> Self { + Self { + node_id, + serial_id, + success, + } + } +} + +pub fn create_message_type_attribute_map( + message_type: &str, +) -> HashMap { + let mut message_attributes_map = HashMap::new(); + let message_type_value = MessageAttributeValue::builder() + .data_type("String") + .string_value(message_type) + .build() + .unwrap(); + message_attributes_map.insert(SMPC_MESSAGE_TYPE_ATTRIBUTE.to_string(), message_type_value); + message_attributes_map +} diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 1ac821923..d699a356b 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -777,6 +777,7 @@ impl ServerActor { match_ids, store_left: query_store_left, store_right: query_store_right, + deleted_ids: batch.deletion_requests, }) .unwrap(); diff --git a/iris-mpc-gpu/src/server/mod.rs b/iris-mpc-gpu/src/server/mod.rs index e2ca23953..07ac82b7a 100644 --- a/iris-mpc-gpu/src/server/mod.rs +++ b/iris-mpc-gpu/src/server/mod.rs @@ -95,6 +95,7 @@ pub struct ServerJobResult { pub match_ids: Vec>, pub store_left: BatchQueryEntries, pub store_right: BatchQueryEntries, + pub deleted_ids: Vec, } enum Eye { diff --git a/iris-mpc-store/src/lib.rs b/iris-mpc-store/src/lib.rs index ef5cc73db..f373ddf4c 100644 --- a/iris-mpc-store/src/lib.rs +++ b/iris-mpc-store/src/lib.rs @@ -398,7 +398,7 @@ mod tests { use super::*; use futures::TryStreamExt; - use iris_mpc_common::helpers::smpc_request::ResultEvent; + use iris_mpc_common::helpers::smpc_request::UniquenessResult; #[tokio::test] #[cfg(feature = "db_dependent")] @@ -493,7 +493,7 @@ mod tests { }; let codes_and_masks = vec![iris; count]; - let result_event = serde_json::to_string(&ResultEvent::new( + let result_event = serde_json::to_string(&UniquenessResult::new( 0, Some(1_000_000_000), false, diff --git a/iris-mpc/src/bin/client.rs b/iris-mpc/src/bin/client.rs index 3e6217a01..d10c2f65e 100644 --- a/iris-mpc/src/bin/client.rs +++ b/iris-mpc/src/bin/client.rs @@ -9,7 +9,7 @@ use iris_mpc_common::{ helpers::{ key_pair::download_public_key, sha256::calculate_sha256, - smpc_request::{IrisCodesJSON, ResultEvent, UniquenessRequest}, + smpc_request::{IrisCodesJSON, UniquenessRequest, UniquenessResult}, sqs_s3_helper::upload_file_and_generate_presigned_url, }, iris_db::{db::IrisDB, iris::IrisCode}, @@ -145,8 +145,9 @@ async fn main() -> eyre::Result<()> { for msg in msg.messages.unwrap_or_default() { counter += 1; - let result: ResultEvent = serde_json::from_str(&msg.body.context("No body found")?) - .context("Failed to parse message body")?; + let result: UniquenessResult = + serde_json::from_str(&msg.body.context("No body found")?) + .context("Failed to parse message body")?; println!("Received result: {:?}", result); diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index f645ddd25..18ffb219b 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -1,6 +1,6 @@ #![allow(clippy::needless_range_loop)] -use aws_sdk_sns::Client as SNSClient; +use aws_sdk_sns::{types::MessageAttributeValue, Client as SNSClient}; use aws_sdk_sqs::{config::Region, Client}; use axum::{routing::get, Router}; use clap::Parser; @@ -14,9 +14,9 @@ use iris_mpc_common::{ key_pair::SharesEncryptionKeyPairs, kms_dh::derive_shared_secret, smpc_request::{ - IdentityDeletionRequest, ReceiveRequestError, ResultEvent, SQSMessage, - UniquenessRequest, IDENTITY_DELETION_REQUEST_TYPE, SMPC_REQUEST_TYPE_ATTRIBUTE, - UNIQUENESS_REQUEST_TYPE, + create_message_type_attribute_map, IdentityDeletionRequest, IdentityDeletionResult, + ReceiveRequestError, SQSMessage, UniquenessRequest, UniquenessResult, + IDENTITY_DELETION_MESSAGE_TYPE, SMPC_MESSAGE_TYPE_ATTRIBUTE, UNIQUENESS_MESSAGE_TYPE, }, sync::SyncState, task_monitor::TaskMonitor, @@ -34,6 +34,7 @@ use iris_mpc_gpu::{ use iris_mpc_store::{Store, StoredIrisRef}; use rand::{rngs::StdRng, SeedableRng}; use std::{ + collections::HashMap, mem, sync::{Arc, LazyLock, Mutex}, time::{Duration, Instant}, @@ -151,13 +152,13 @@ async fn receive_batch( } let request_type = message_attributes - .get(SMPC_REQUEST_TYPE_ATTRIBUTE) + .get(SMPC_MESSAGE_TYPE_ATTRIBUTE) .ok_or(ReceiveRequestError::NoMessageTypeAttribute)? .string_value() .ok_or(ReceiveRequestError::NoMessageTypeAttribute)?; match request_type { - IDENTITY_DELETION_REQUEST_TYPE => { + IDENTITY_DELETION_MESSAGE_TYPE => { // If it's a deletion request, we just store the serial_id and continue. // Deletion will take place when batch process starts. let identity_deletion_request: IdentityDeletionRequest = @@ -179,7 +180,7 @@ async fn receive_batch( .await .map_err(ReceiveRequestError::FailedToDeleteFromSQS)?; } - UNIQUENESS_REQUEST_TYPE => { + UNIQUENESS_MESSAGE_TYPE => { msg_counter += 1; let shares_encryption_key_pairs = shares_encryption_key_pairs.clone(); @@ -517,10 +518,11 @@ async fn initialize_iris_dbs( )) } -async fn send_result_events( +async fn send_results_to_sns( result_events: Vec, sns_client: &SNSClient, config: &Config, + message_attributes: &HashMap, ) -> eyre::Result<()> { for result_event in result_events { sns_client @@ -528,6 +530,7 @@ async fn send_result_events( .topic_arn(&config.results_topic_arn) .message(result_event) .message_group_id(format!("party-id-{}", config.party_id)) + .set_message_attributes(Some(message_attributes.clone())) .send() .await?; } @@ -594,11 +597,15 @@ async fn server_main(config: Config) -> eyre::Result<()> { tracing::info!("Deriving shared secrets"); let chacha_seeds = initialize_chacha_seeds(&config.kms_key_arns, party_id).await?; + let uniqueness_result_attributes = create_message_type_attribute_map(UNIQUENESS_MESSAGE_TYPE); + let identity_deletion_result_attributes = + create_message_type_attribute_map(IDENTITY_DELETION_MESSAGE_TYPE); tracing::info!("Replaying results"); - send_result_events( + send_results_to_sns( store.last_results(max_sync_lookback).await?, &sns_client, &config, + &uniqueness_result_attributes, ) .await?; @@ -719,14 +726,15 @@ async fn server_main(config: Config) -> eyre::Result<()> { match_ids, store_left, store_right, + deleted_ids, }) = rx.recv().await { // returned serial_ids are 0 indexed, but we want them to be 1 indexed - let result_events = merged_results + let uniqueness_results = merged_results .iter() .enumerate() .map(|(i, &idx_result)| { - let result_event = ResultEvent::new( + let result_event = UniquenessResult::new( party_id, match matches[i] { true => None, @@ -765,7 +773,9 @@ async fn server_main(config: Config) -> eyre::Result<()> { let mut tx = store_bg.tx().await?; - store_bg.insert_results(&mut tx, &result_events).await?; + store_bg + .insert_results(&mut tx, &uniqueness_results) + .await?; if !codes_and_masks.is_empty() { store_bg @@ -776,8 +786,36 @@ async fn server_main(config: Config) -> eyre::Result<()> { tx.commit().await?; - tracing::info!("Sending {} results", result_events.len()); - send_result_events(result_events, &sns_client_bg, &config_bg).await?; + tracing::info!("Sending {} uniqueness results", uniqueness_results.len()); + send_results_to_sns( + uniqueness_results, + &sns_client_bg, + &config_bg, + &uniqueness_result_attributes, + ) + .await?; + + // handling identity deletion results + let identity_deletion_results = deleted_ids + .iter() + .map(|serial_id| { + let result_event = IdentityDeletionResult::new(party_id, *serial_id, true); + serde_json::to_string(&result_event) + .wrap_err("failed to serialize identity deletion result") + }) + .collect::>>()?; + + tracing::info!( + "Sending {} identity deletion results", + identity_deletion_results.len() + ); + send_results_to_sns( + identity_deletion_results, + &sns_client_bg, + &config_bg, + &identity_deletion_result_attributes, + ) + .await?; } Ok(())