Skip to content

Commit

Permalink
Send identity deletion results to sns (#335)
Browse files Browse the repository at this point in the history
* add msg attribute to uniqueness results

* send identity deletion results to sns

* rename message attribute constants

* return deleted ids from actor and send deletion results in result sender thread

* fmt
  • Loading branch information
eaypek-tfh authored Sep 9, 2024
1 parent 29c54fe commit 112121f
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 25 deletions.
43 changes: 37 additions & 6 deletions iris-mpc-common/src/helpers/smpc_request.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{key_pair::SharesDecodingError, sha256::calculate_sha256};
use crate::helpers::key_pair::SharesEncryptionKeyPairs;
use aws_sdk_sns::types::MessageAttributeValue;
use aws_sdk_sqs::{
error::SdkError,
operation::{delete_message::DeleteMessageError, receive_message::ReceiveMessageError},
Expand All @@ -8,7 +9,7 @@ use base64::{engine::general_purpose::STANDARD, Engine};
use eyre::Report;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::sync::LazyLock;
use std::{collections::HashMap, sync::LazyLock};
use thiserror::Error;
use tokio_retry::{
strategy::{jitter, FixedInterval},
Expand All @@ -33,9 +34,9 @@ pub struct SQSMessage {
pub unsubscribe_url: String,
}

pub const SMPC_REQUEST_TYPE_ATTRIBUTE: &str = "message_type";
pub const IDENTITY_DELETION_REQUEST_TYPE: &str = "identity_deletion";
pub const UNIQUENESS_REQUEST_TYPE: &str = "uniqueness";
pub const SMPC_MESSAGE_TYPE_ATTRIBUTE: &str = "message_type";
pub const IDENTITY_DELETION_MESSAGE_TYPE: &str = "identity_deletion";
pub const UNIQUENESS_MESSAGE_TYPE: &str = "uniqueness";

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct UniquenessRequest {
Expand Down Expand Up @@ -221,15 +222,15 @@ impl UniquenessRequest {
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ResultEvent {
pub struct UniquenessResult {
pub node_id: usize,
pub serial_id: Option<u32>,
pub is_match: bool,
pub signup_id: String,
pub matched_serial_ids: Option<Vec<u32>>,
}

impl ResultEvent {
impl UniquenessResult {
pub fn new(
node_id: usize,
serial_id: Option<u32>,
Expand All @@ -246,3 +247,33 @@ impl ResultEvent {
}
}
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct IdentityDeletionResult {
pub node_id: usize,
pub serial_id: u32,
pub success: bool,
}

impl IdentityDeletionResult {
pub fn new(node_id: usize, serial_id: u32, success: bool) -> Self {
Self {
node_id,
serial_id,
success,
}
}
}

pub fn create_message_type_attribute_map(
message_type: &str,
) -> HashMap<String, MessageAttributeValue> {
let mut message_attributes_map = HashMap::new();
let message_type_value = MessageAttributeValue::builder()
.data_type("String")
.string_value(message_type)
.build()
.unwrap();
message_attributes_map.insert(SMPC_MESSAGE_TYPE_ATTRIBUTE.to_string(), message_type_value);
message_attributes_map
}
1 change: 1 addition & 0 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ impl ServerActor {
match_ids,
store_left: query_store_left,
store_right: query_store_right,
deleted_ids: batch.deletion_requests,
})
.unwrap();

Expand Down
1 change: 1 addition & 0 deletions iris-mpc-gpu/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ pub struct ServerJobResult {
pub match_ids: Vec<Vec<u32>>,
pub store_left: BatchQueryEntries,
pub store_right: BatchQueryEntries,
pub deleted_ids: Vec<u32>,
}

enum Eye {
Expand Down
4 changes: 2 additions & 2 deletions iris-mpc-store/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ mod tests {

use super::*;
use futures::TryStreamExt;
use iris_mpc_common::helpers::smpc_request::ResultEvent;
use iris_mpc_common::helpers::smpc_request::UniquenessResult;

#[tokio::test]
#[cfg(feature = "db_dependent")]
Expand Down Expand Up @@ -493,7 +493,7 @@ mod tests {
};
let codes_and_masks = vec![iris; count];

let result_event = serde_json::to_string(&ResultEvent::new(
let result_event = serde_json::to_string(&UniquenessResult::new(
0,
Some(1_000_000_000),
false,
Expand Down
7 changes: 4 additions & 3 deletions iris-mpc/src/bin/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use iris_mpc_common::{
helpers::{
key_pair::download_public_key,
sha256::calculate_sha256,
smpc_request::{IrisCodesJSON, ResultEvent, UniquenessRequest},
smpc_request::{IrisCodesJSON, UniquenessRequest, UniquenessResult},
sqs_s3_helper::upload_file_and_generate_presigned_url,
},
iris_db::{db::IrisDB, iris::IrisCode},
Expand Down Expand Up @@ -145,8 +145,9 @@ async fn main() -> eyre::Result<()> {
for msg in msg.messages.unwrap_or_default() {
counter += 1;

let result: ResultEvent = serde_json::from_str(&msg.body.context("No body found")?)
.context("Failed to parse message body")?;
let result: UniquenessResult =
serde_json::from_str(&msg.body.context("No body found")?)
.context("Failed to parse message body")?;

println!("Received result: {:?}", result);

Expand Down
66 changes: 52 additions & 14 deletions iris-mpc/src/bin/server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#![allow(clippy::needless_range_loop)]

use aws_sdk_sns::Client as SNSClient;
use aws_sdk_sns::{types::MessageAttributeValue, Client as SNSClient};
use aws_sdk_sqs::{config::Region, Client};
use axum::{routing::get, Router};
use clap::Parser;
Expand All @@ -14,9 +14,9 @@ use iris_mpc_common::{
key_pair::SharesEncryptionKeyPairs,
kms_dh::derive_shared_secret,
smpc_request::{
IdentityDeletionRequest, ReceiveRequestError, ResultEvent, SQSMessage,
UniquenessRequest, IDENTITY_DELETION_REQUEST_TYPE, SMPC_REQUEST_TYPE_ATTRIBUTE,
UNIQUENESS_REQUEST_TYPE,
create_message_type_attribute_map, IdentityDeletionRequest, IdentityDeletionResult,
ReceiveRequestError, SQSMessage, UniquenessRequest, UniquenessResult,
IDENTITY_DELETION_MESSAGE_TYPE, SMPC_MESSAGE_TYPE_ATTRIBUTE, UNIQUENESS_MESSAGE_TYPE,
},
sync::SyncState,
task_monitor::TaskMonitor,
Expand All @@ -34,6 +34,7 @@ use iris_mpc_gpu::{
use iris_mpc_store::{Store, StoredIrisRef};
use rand::{rngs::StdRng, SeedableRng};
use std::{
collections::HashMap,
mem,
sync::{Arc, LazyLock, Mutex},
time::{Duration, Instant},
Expand Down Expand Up @@ -151,13 +152,13 @@ async fn receive_batch(
}

let request_type = message_attributes
.get(SMPC_REQUEST_TYPE_ATTRIBUTE)
.get(SMPC_MESSAGE_TYPE_ATTRIBUTE)
.ok_or(ReceiveRequestError::NoMessageTypeAttribute)?
.string_value()
.ok_or(ReceiveRequestError::NoMessageTypeAttribute)?;

match request_type {
IDENTITY_DELETION_REQUEST_TYPE => {
IDENTITY_DELETION_MESSAGE_TYPE => {
// If it's a deletion request, we just store the serial_id and continue.
// Deletion will take place when batch process starts.
let identity_deletion_request: IdentityDeletionRequest =
Expand All @@ -179,7 +180,7 @@ async fn receive_batch(
.await
.map_err(ReceiveRequestError::FailedToDeleteFromSQS)?;
}
UNIQUENESS_REQUEST_TYPE => {
UNIQUENESS_MESSAGE_TYPE => {
msg_counter += 1;

let shares_encryption_key_pairs = shares_encryption_key_pairs.clone();
Expand Down Expand Up @@ -517,17 +518,19 @@ async fn initialize_iris_dbs(
))
}

async fn send_result_events(
async fn send_results_to_sns(
result_events: Vec<String>,
sns_client: &SNSClient,
config: &Config,
message_attributes: &HashMap<String, MessageAttributeValue>,
) -> eyre::Result<()> {
for result_event in result_events {
sns_client
.publish()
.topic_arn(&config.results_topic_arn)
.message(result_event)
.message_group_id(format!("party-id-{}", config.party_id))
.set_message_attributes(Some(message_attributes.clone()))
.send()
.await?;
}
Expand Down Expand Up @@ -594,11 +597,15 @@ async fn server_main(config: Config) -> eyre::Result<()> {
tracing::info!("Deriving shared secrets");
let chacha_seeds = initialize_chacha_seeds(&config.kms_key_arns, party_id).await?;

let uniqueness_result_attributes = create_message_type_attribute_map(UNIQUENESS_MESSAGE_TYPE);
let identity_deletion_result_attributes =
create_message_type_attribute_map(IDENTITY_DELETION_MESSAGE_TYPE);
tracing::info!("Replaying results");
send_result_events(
send_results_to_sns(
store.last_results(max_sync_lookback).await?,
&sns_client,
&config,
&uniqueness_result_attributes,
)
.await?;

Expand Down Expand Up @@ -719,14 +726,15 @@ async fn server_main(config: Config) -> eyre::Result<()> {
match_ids,
store_left,
store_right,
deleted_ids,
}) = rx.recv().await
{
// returned serial_ids are 0 indexed, but we want them to be 1 indexed
let result_events = merged_results
let uniqueness_results = merged_results
.iter()
.enumerate()
.map(|(i, &idx_result)| {
let result_event = ResultEvent::new(
let result_event = UniquenessResult::new(
party_id,
match matches[i] {
true => None,
Expand Down Expand Up @@ -765,7 +773,9 @@ async fn server_main(config: Config) -> eyre::Result<()> {

let mut tx = store_bg.tx().await?;

store_bg.insert_results(&mut tx, &result_events).await?;
store_bg
.insert_results(&mut tx, &uniqueness_results)
.await?;

if !codes_and_masks.is_empty() {
store_bg
Expand All @@ -776,8 +786,36 @@ async fn server_main(config: Config) -> eyre::Result<()> {

tx.commit().await?;

tracing::info!("Sending {} results", result_events.len());
send_result_events(result_events, &sns_client_bg, &config_bg).await?;
tracing::info!("Sending {} uniqueness results", uniqueness_results.len());
send_results_to_sns(
uniqueness_results,
&sns_client_bg,
&config_bg,
&uniqueness_result_attributes,
)
.await?;

// handling identity deletion results
let identity_deletion_results = deleted_ids
.iter()
.map(|serial_id| {
let result_event = IdentityDeletionResult::new(party_id, *serial_id, true);
serde_json::to_string(&result_event)
.wrap_err("failed to serialize identity deletion result")
})
.collect::<eyre::Result<Vec<_>>>()?;

tracing::info!(
"Sending {} identity deletion results",
identity_deletion_results.len()
);
send_results_to_sns(
identity_deletion_results,
&sns_client_bg,
&config_bg,
&identity_deletion_result_attributes,
)
.await?;
}

Ok(())
Expand Down

0 comments on commit 112121f

Please sign in to comment.