diff --git a/.env.mpc1.dist b/.env.mpc1.dist index 776f0b47b..d129f7b5e 100644 --- a/.env.mpc1.dist +++ b/.env.mpc1.dist @@ -22,6 +22,7 @@ SMPC__PARTY_ID=0 SMPC__REQUESTS_QUEUE_URL=https://sqs.eu-north-1.amazonaws.com/654654380399/mpc1.fifo SMPC__RESULTS_TOPIC_ARN=arn:aws:sns:eu-north-1:654654380399:mpc-results-topic SMPC__PROCESSING_TIMEOUT_SECS=60 +SMPC__PUBLIC_KEY_BASE_URL=https://d24uxaabh702ht.cloudfront.net # These can be either ARNs or IDs, in production multi account setup they are ARNs SMPC__KMS_KEY_ARNS='["077788e2-9eeb-4044-859b-34496cfd500b", "896353dc-5ea5-42d4-9e4e-f65dd8169dee", "42bb01f5-8380-48b4-b1f1-929463a587fb"]' diff --git a/.env.mpc2.dist b/.env.mpc2.dist index 3c0c676c0..5b25b36d1 100644 --- a/.env.mpc2.dist +++ b/.env.mpc2.dist @@ -22,6 +22,7 @@ SMPC__PARTY_ID=1 SMPC__REQUESTS_QUEUE_URL=https://sqs.eu-north-1.amazonaws.com/654654380399/mpc2.fifo SMPC__RESULTS_TOPIC_ARN=arn:aws:sns:eu-north-1:654654380399:mpc-results-topic SMPC__PROCESSING_TIMEOUT_SECS=60 +SMPC__PUBLIC_KEY_BASE_URL=https://d24uxaabh702ht.cloudfront.net # These can be either ARNs or IDs, in production multi account setup they are ARNs SMPC__KMS_KEY_ARNS='["077788e2-9eeb-4044-859b-34496cfd500b", "896353dc-5ea5-42d4-9e4e-f65dd8169dee", "42bb01f5-8380-48b4-b1f1-929463a587fb"]' diff --git a/.env.mpc3.dist b/.env.mpc3.dist index f7a58312d..4a7ab426a 100644 --- a/.env.mpc3.dist +++ b/.env.mpc3.dist @@ -22,6 +22,7 @@ SMPC__PARTY_ID=2 SMPC__REQUESTS_QUEUE_URL=https://sqs.eu-north-1.amazonaws.com/654654380399/mpc3.fifo SMPC__RESULTS_TOPIC_ARN=arn:aws:sns:eu-north-1:654654380399:mpc-results-topic SMPC__PROCESSING_TIMEOUT_SECS=60 +SMPC__PUBLIC_KEY_BASE_URL=https://d24uxaabh702ht.cloudfront.net # These can be either ARNs or IDs, in production multi account setup they are ARNs SMPC__KMS_KEY_ARNS='["077788e2-9eeb-4044-859b-34496cfd500b", "896353dc-5ea5-42d4-9e4e-f65dd8169dee", "42bb01f5-8380-48b4-b1f1-929463a587fb"]' diff --git a/Cargo.lock b/Cargo.lock index 2f2b1cdee..5dc448a78 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -115,6 +115,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-trait" version = "0.1.81" @@ -1239,6 +1249,24 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" +[[package]] +name = "deadpool" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb84100978c1c7b37f09ed3ce3e5f843af02c2a2c431bae5b19230dad2c1b490" +dependencies = [ + "async-trait", + "deadpool-runtime", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" + [[package]] name = "der" version = "0.6.1" @@ -2209,6 +2237,8 @@ dependencies = [ "ndarray", "rand", "serde_json", + "sha2", + "sodiumoxide", "static_assertions", "telemetry-batteries", "tokio", @@ -2252,6 +2282,8 @@ dependencies = [ "tracing", "tracing-subscriber", "url", + "wiremock", + "zeroize", ] [[package]] @@ -2259,6 +2291,7 @@ name = "iris-mpc-gpu" version = "0.1.0" dependencies = [ "axum", + "base64 0.22.1", "bytemuck", "criterion", "cudarc", @@ -2276,6 +2309,7 @@ dependencies = [ "ring", "serde", "serde_json", + "sodiumoxide", "static_assertions", "tokio", "tracing", @@ -5423,6 +5457,30 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "wiremock" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a59f8ae78a4737fb724f20106fb35ccb7cfe61ff335665d3042b3aa98e34717" +dependencies = [ + "assert-json-diff", + "async-trait", + "base64 0.21.7", + "deadpool", + "futures", + "http 1.1.0", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "log", + "once_cell", + "regex", + "serde", + "serde_json", + "tokio", + "url", +] + [[package]] name = "wyz" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 89e329432..e0de04056 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ tracing = "0.1.40" tracing-subscriber = { version = "0.3.15", features = ["env-filter"] } rand = "0.8" rayon = "1.5.1" -reqwest = { version = "0.12", features = ["blocking"] } +reqwest = { version = "0.12", features = ["blocking", "json"] } static_assertions = "1.1" telemetry-batteries = { git = "https://github.com/worldcoin/telemetry-batteries.git", rev = "802a4f39f358e077b11c8429b4c65f3e45b85959" } tokio = { version = "1.39", features = ["full", "rt-multi-thread"] } diff --git a/Dockerfile.debug b/Dockerfile.debug new file mode 100644 index 000000000..83758e93c --- /dev/null +++ b/Dockerfile.debug @@ -0,0 +1,36 @@ +FROM --platform=linux/amd64 ubuntu:22.04 +ENV DEBIAN_FRONTEND=noninteractive + +WORKDIR /app + +RUN apt-get update && apt-get install -y \ + curl \ + build-essential \ + libssl-dev \ + texinfo \ + libcap2-bin \ + pkg-config \ + git \ + devscripts \ + debhelper \ + ca-certificates \ + wget + +RUN curl https://sh.rustup.rs -sSf | sh -s -- -y +ENV PATH "/root/.cargo/bin:${PATH}" +ENV RUSTUP_HOME "/root/.rustup" +ENV CARGO_HOME "/root/.cargo" +RUN rustup toolchain install nightly-2024-07-10 +RUN rustup default nightly-2024-07-10 +RUN rustup component add cargo +RUN cargo install cargo-build-deps && cargo install cargo-edit + +COPY . . + +RUN apt-get update && apt-get install -y pkg-config wget libssl-dev ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb \ + && dpkg -i cuda-keyring_1.1-1_all.deb \ + && apt-get update \ + && apt-get install -y cuda-toolkit-12-2 libnccl2=2.22.3-1+cuda12.2 libnccl-dev=2.22.3-1+cuda12.2 diff --git a/deploy/README.md b/deploy/README.md index 56cd97f53..d6ad32152 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -1,3 +1,13 @@ +# GPU Iris MPC Deployment in current stage + +The application right now has issues with DB loading. To run the app it is necessary to truncate tables in the dbs in all 3 parties. + +To do so, please deploy the pod in `/deploy/db-cleaner-helper-pod.yaml` in every party and run the following command putting appropriate DB URL and party id in it beforehand: + +```bash +apt update && apt install -y postgresql-client && psql -H -c 'SET search_path TO "SMPC_stage_{0,1,2}"; TRUNCATE irises, results, sync;' +``` + # Application Upgrade Documentation This document provides a step-by-step guide on how to upgrade the application deployed using ArgoCD. The application configuration is located in the `deploy/stage/mpc1-stage`, `deploy/stage/mpc2-stage`, and `deploy/stage/mpc3-stage` directories. Each directory contains a `values-gpu-iris-mpc.yaml` file, which includes the deployment configuration for the respective clusters: `mpc1-stage`, `mpc2-stage`, and `mpc3-stage`, and common value file placed in `deploy/stage/values-common-gpu-iris-mpc.yaml` diff --git a/temp.yaml b/deploy/db-cleaner-helper-pod.yaml similarity index 63% rename from temp.yaml rename to deploy/db-cleaner-helper-pod.yaml index faa5e9c45..b834ece9d 100644 --- a/temp.yaml +++ b/deploy/db-cleaner-helper-pod.yaml @@ -1,7 +1,7 @@ apiVersion: v1 kind: Pod metadata: - name: gpu-iris-mpc + name: db-cleaner namespace: gpu-iris-mpc spec: hostNetwork: true @@ -16,17 +16,15 @@ spec: securityContext: runAsUser: 0 containers: - - name: gpu-iris-mpc - image: ghcr.io/worldcoin/gpu-iris-mpc-debug:2b1a5adf8aebc6a0917f591ec2cc364d5ea5d346 + - name: db-cleaner + image: ghcr.io/worldcoin/gpu-iris-mpc-debug:34b305f6e9acafe9043636fb32fc11870615f34e imagePullPolicy: Always command: [ "/bin/bash" ] args: [ "-c", "while true; do ping localhost; sleep 60; done" ] resources: limits: - nvidia.com/gpu: 1 - vpc.amazonaws.com/efa: 1 + cpu: 1 + memory: 1Gi requests: - nvidia.com/gpu: 1 - vpc.amazonaws.com/efa: 1 - ports: - - containerPort: 3000 + cpu: 1 + memory: 1Gi diff --git a/deploy/orb-stage-helper-pod.yaml b/deploy/orb-stage-helper-pod.yaml new file mode 100644 index 000000000..14fd85541 --- /dev/null +++ b/deploy/orb-stage-helper-pod.yaml @@ -0,0 +1,29 @@ +apiVersion: v1 +kind: Pod +metadata: + name: smpcv2-signup-service-helper + namespace: signup-service +spec: + serviceAccountName: signup-service-worker # Add this line + imagePullSecrets: + - name: github-secret + nodeSelector: + kubernetes.io/arch: amd64 + containers: + - name: smpcv2-signup-service-helper + image: ghcr.io/worldcoin/gpu-iris-mpc-debug:0510757a9d076c206d9a42eedca639787c44a0a8 + securityContext: + runAsUser: 0 + allowPrivilegeEscalation: false + seccompProfile: + type: RuntimeDefault # or Localhost if you have a local profile + imagePullPolicy: Always + command: [ "/bin/bash" ] + args: [ "-c", "while true; do ping localhost; sleep 60; done" ] + resources: + limits: + cpu: 4 + memory: 4Gi + requests: + cpu: 4 + memory: 4Gi diff --git a/deploy/stage/common-values-gpu-iris-mpc.yaml b/deploy/stage/common-values-gpu-iris-mpc.yaml index 83a2511ec..9b54d8b80 100644 --- a/deploy/stage/common-values-gpu-iris-mpc.yaml +++ b/deploy/stage/common-values-gpu-iris-mpc.yaml @@ -1,4 +1,4 @@ -image: "ghcr.io/worldcoin/gpu-iris-mpc:debug-pop-1768-v2" +image: "ghcr.io/worldcoin/gpu-iris-mpc:v0.3.0-alpha" environment: stage replicaCount: 1 diff --git a/deploy/stage/mpc1-stage/values-gpu-iris-mpc.yaml b/deploy/stage/mpc1-stage/values-gpu-iris-mpc.yaml index 550855d4b..8ca116584 100644 --- a/deploy/stage/mpc1-stage/values-gpu-iris-mpc.yaml +++ b/deploy/stage/mpc1-stage/values-gpu-iris-mpc.yaml @@ -62,5 +62,5 @@ env: - name: SMPC__PARTY_ID value: "0" - - name: SMPC__PUBLIC_KEY_BUCKET_NAME - value: "wf-smpcv2-stage-public-keys" + - name: SMPC__PUBLIC_KEY_BASE_URL + value: "https://d24uxaabh702ht.cloudfront.net" diff --git a/deploy/stage/mpc2-stage/values-gpu-iris-mpc.yaml b/deploy/stage/mpc2-stage/values-gpu-iris-mpc.yaml index 03a44f409..2370ca7e2 100644 --- a/deploy/stage/mpc2-stage/values-gpu-iris-mpc.yaml +++ b/deploy/stage/mpc2-stage/values-gpu-iris-mpc.yaml @@ -62,5 +62,5 @@ env: - name: SMPC__PARTY_ID value: "1" - - name: SMPC__PUBLIC_KEY_BUCKET_NAME - value: "wf-smpcv2-stage-public-keys" + - name: SMPC__PUBLIC_KEY_BASE_URL + value: "https://d24uxaabh702ht.cloudfront.net" diff --git a/deploy/stage/mpc3-stage/values-gpu-iris-mpc.yaml b/deploy/stage/mpc3-stage/values-gpu-iris-mpc.yaml index 25eec0aae..26828aaad 100644 --- a/deploy/stage/mpc3-stage/values-gpu-iris-mpc.yaml +++ b/deploy/stage/mpc3-stage/values-gpu-iris-mpc.yaml @@ -62,5 +62,5 @@ env: - name: SMPC__PARTY_ID value: "2" - - name: SMPC__PUBLIC_KEY_BUCKET_NAME - value: "wf-smpcv2-stage-public-keys" + - name: SMPC__PUBLIC_KEY_BASE_URL + value: "https://d24uxaabh702ht.cloudfront.net" diff --git a/iris-mpc-common/Cargo.toml b/iris-mpc-common/Cargo.toml index 398e4bf1b..172399911 100644 --- a/iris-mpc-common/Cargo.toml +++ b/iris-mpc-common/Cargo.toml @@ -24,8 +24,8 @@ config = "0.14.0" tokio.workspace = true tracing.workspace = true tracing-subscriber.workspace = true -reqwest.workspace = true +reqwest = { workspace = true, features = ["blocking", "json"] } sodiumoxide = "0.2.7" hmac = "0.12" http = "1.1.0" @@ -36,6 +36,8 @@ sha2 = "0.10" time = { version = "^0.3.6", features = ["formatting", "macros"] } url = "2" hex.workspace = true +zeroize = "1.8.1" +wiremock = "0.6.1" [dev-dependencies] float_eq = "1" diff --git a/iris-mpc-common/src/bin/key_manager.rs b/iris-mpc-common/src/bin/key_manager.rs index 28a1aebd6..7422b1154 100644 --- a/iris-mpc-common/src/bin/key_manager.rs +++ b/iris-mpc-common/src/bin/key_manager.rs @@ -13,7 +13,7 @@ use rand::{rngs::StdRng, Rng, SeedableRng}; use reqwest::Client; use sodiumoxide::crypto::box_::{curve25519xsalsa20poly1305, PublicKey, SecretKey, Seed}; -const PUBLIC_KEY_S3_BUCKET_NAME: &str = "wf-mpc-vpc-stage-public-smpcv2-keys"; +const PUBLIC_KEY_S3_BUCKET_NAME: &str = "wf-smpcv2-stage-public-keys"; const PUBLIC_KEY_S3_KEY_NAME_PREFIX: &str = "public-key"; const REGION: &str = "eu-north-1"; @@ -318,6 +318,9 @@ mod test { let server_iris_code_plaintext = sealedbox::open(&ciphertext, &server_public_key, &server_private_key).unwrap(); - assert!(client_iris_code_plaintext.as_bytes() == server_iris_code_plaintext.as_slice()); + assert_eq!( + client_iris_code_plaintext.as_bytes(), + server_iris_code_plaintext.as_slice() + ); } } diff --git a/iris-mpc-common/src/config/mod.rs b/iris-mpc-common/src/config/mod.rs index 226f51cec..085d45d03 100644 --- a/iris-mpc-common/src/config/mod.rs +++ b/iris-mpc-common/src/config/mod.rs @@ -1,10 +1,10 @@ -pub mod json_wrapper; - use crate::config::json_wrapper::JsonStrWrapper; use clap::Parser; use serde::{Deserialize, Serialize}; use std::fmt; +pub mod json_wrapper; + #[derive(Debug, Parser)] pub struct Opt { #[structopt(long)] @@ -45,6 +45,9 @@ pub struct Config { #[serde(default = "default_processing_timeout_secs")] pub processing_timeout_secs: u64, + + #[serde(default)] + pub public_key_base_url: String, } fn default_processing_timeout_secs() -> u64 { diff --git a/iris-mpc-common/src/helpers/key_pair.rs b/iris-mpc-common/src/helpers/key_pair.rs new file mode 100644 index 000000000..434be4a75 --- /dev/null +++ b/iris-mpc-common/src/helpers/key_pair.rs @@ -0,0 +1,190 @@ +use crate::config::Config; +use aws_config::Region; +use aws_sdk_secretsmanager::{ + error::SdkError, operation::get_secret_value::GetSecretValueError, + Client as SecretsManagerClient, +}; +use base64::{engine::general_purpose::STANDARD, Engine}; +use sodiumoxide::crypto::{ + box_::{PublicKey, SecretKey}, + sealedbox, +}; +use std::string::FromUtf8Error; +use thiserror::Error; +use zeroize::Zeroize; + +const REGION: &str = "eu-north-1"; +const CURRENT_SECRET_LABEL: &str = "AWSCURRENT"; + +#[derive(Error, Debug)] +pub enum SharesDecodingError { + #[error("Secrets Manager error: {0}")] + SecretsManagerError(#[from] SdkError), + #[error("Secret string not found")] + SecretStringNotFound, + #[error(transparent)] + RequestError(#[from] reqwest::Error), + #[error("Decoding error: {0}")] + DecodingError(#[from] base64::DecodeError), + #[error("Parsing bytes to UTF8 error")] + DecodedShareParsingToUTF8Error(#[from] FromUtf8Error), + #[error("Parsing key error")] + ParsingKeyError, + #[error("Sealed box open error")] + SealedBoxOpenError, + #[error("Public key not found error")] + PublicKeyNotFound, + #[error("Private key not found error")] + PrivateKeyNotFound, + #[error("Base64 decoding error")] + Base64DecodeError, + #[error("Received error message from server: [{}] {}", .status, .message)] + ResponseContent { + status: reqwest::StatusCode, + url: String, + message: String, + }, + #[error(transparent)] + SerdeError(#[from] serde_json::error::Error), + #[error(transparent)] + PresigningConfigError(#[from] aws_sdk_s3::presigning::PresigningConfigError), + #[error(transparent)] + PresignedRequestError( + #[from] aws_sdk_s3::error::SdkError, + ), +} + +#[derive(Clone, Debug)] +pub struct SharesEncryptionKeyPair { + pk: PublicKey, + sk: SecretKey, +} + +impl Zeroize for SharesEncryptionKeyPair { + fn zeroize(&mut self) { + self.pk.0.zeroize(); + self.sk.0.zeroize(); + } +} + +impl Drop for SharesEncryptionKeyPair { + fn drop(&mut self) { + self.pk.0.zeroize(); + self.sk.0.zeroize(); + } +} + +impl SharesEncryptionKeyPair { + pub async fn from_storage(config: Config) -> Result { + let region_provider = Region::new(REGION); + let shared_config = aws_config::from_env().region(region_provider).load().await; + let client = SecretsManagerClient::new(&shared_config); + + let pk_b64_string = match download_public_key( + config.public_key_base_url, + config.party_id.to_string(), + ) + .await + { + Ok(pk) => pk, + Err(e) => return Err(e), + }; + + let sk_b64_string = match download_private_key_from_asm( + &client, + &config.environment, + &config.party_id.to_string(), + CURRENT_SECRET_LABEL, + ) + .await + { + Ok(sk) => sk, + Err(e) => return Err(e), + }; + + match SharesEncryptionKeyPair::from_b64_strings(pk_b64_string, sk_b64_string) { + Ok(key_pair) => Ok(key_pair), + Err(e) => Err(e), + } + } + + pub fn from_b64_strings(pk: String, sk: String) -> Result { + let pk_bytes = match STANDARD.decode(pk) { + Ok(bytes) => bytes, + Err(e) => return Err(SharesDecodingError::DecodingError(e)), + }; + let sk_bytes = match STANDARD.decode(sk) { + Ok(bytes) => bytes, + Err(e) => return Err(SharesDecodingError::DecodingError(e)), + }; + + let pk = match PublicKey::from_slice(&pk_bytes) { + Some(pk) => pk, + None => return Err(SharesDecodingError::ParsingKeyError), + }; + let sk = match SecretKey::from_slice(&sk_bytes) { + Some(sk) => sk, + None => return Err(SharesDecodingError::ParsingKeyError), + }; + + Ok(Self { pk, sk }) + } + + pub fn open_sealed_box(&self, code: Vec) -> Result, SharesDecodingError> { + let decrypted = sealedbox::open(&code, &self.pk, &self.sk); + match decrypted { + Ok(bytes) => Ok(bytes), + Err(_) => Err(SharesDecodingError::SealedBoxOpenError), + } + } +} + +async fn download_private_key_from_asm( + client: &SecretsManagerClient, + env: &str, + node_id: &str, + version_stage: &str, +) -> Result { + let private_key_secret_id: String = + format!("{}/gpu-iris-mpc/ecdh-private-key-{}", env, node_id); + match client + .get_secret_value() + .secret_id(private_key_secret_id) + .version_stage(version_stage) + .send() + .await + { + Ok(secret_key_output) => match secret_key_output.secret_string { + Some(data) => Ok(data), + None => Err(SharesDecodingError::SecretStringNotFound), + }, + Err(e) => Err(e.into()), + } +} + +pub async fn download_public_key( + base_url: String, + node_id: String, +) -> Result { + let client = reqwest::Client::new(); + let url: String = format!("{}/public-key-{}", base_url, node_id); + let response = client.get(url.clone()).send().await; + match response { + Ok(response) => { + if response.status().is_success() { + let body = response.text().await; + match body { + Ok(body) => Ok(body), + Err(e) => Err(SharesDecodingError::RequestError(e)), + } + } else { + Err(SharesDecodingError::ResponseContent { + status: response.status(), + message: response.text().await.unwrap_or_default(), + url, + }) + } + } + Err(e) => Err(SharesDecodingError::RequestError(e)), + } +} diff --git a/iris-mpc-common/src/helpers/mod.rs b/iris-mpc-common/src/helpers/mod.rs index 4ac18f284..d77ee393c 100644 --- a/iris-mpc-common/src/helpers/mod.rs +++ b/iris-mpc-common/src/helpers/mod.rs @@ -1,6 +1,8 @@ pub mod aws; pub mod aws_sigv4; +pub mod key_pair; pub mod kms_dh; -pub mod sqs; +pub mod smpc_request; +pub mod sqs_s3_helper; pub mod sync; pub mod task_monitor; diff --git a/iris-mpc-common/src/helpers/smpc_request.rs b/iris-mpc-common/src/helpers/smpc_request.rs new file mode 100644 index 000000000..732801d00 --- /dev/null +++ b/iris-mpc-common/src/helpers/smpc_request.rs @@ -0,0 +1,161 @@ +use super::key_pair::{SharesDecodingError, SharesEncryptionKeyPair}; +use base64::{engine::general_purpose::STANDARD, Engine}; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct SQSMessage { + #[serde(rename = "Type")] + pub notification_type: String, + #[serde(rename = "MessageId")] + pub message_id: String, + #[serde(rename = "SequenceNumber")] + pub sequence_number: String, + #[serde(rename = "TopicArn")] + pub topic_arn: String, + #[serde(rename = "Message")] + pub message: String, + #[serde(rename = "Timestamp")] + pub timestamp: String, + #[serde(rename = "UnsubscribeURL")] + pub unsubscribe_url: String, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct SMPCRequest { + pub batch_size: Option, + pub signup_id: String, + pub s3_presigned_url: String, + pub iris_shares_file_hashes: [String; 3], +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct SharesS3Object { + pub iris_share_0: String, + pub iris_share_1: String, + pub iris_share_2: String, +} + +#[derive(PartialEq, Serialize, Deserialize, Debug)] +pub struct IrisCodesJSON { + #[serde(rename = "IRIS_version")] + pub iris_version: String, + pub left_iris_code_shares: String, // these are base64 encoded strings + pub right_iris_code_shares: String, // these are base64 encoded strings + pub left_iris_mask_shares: String, // these are base64 encoded strings + pub right_iris_mask_shares: String, // these are base64 encoded strings +} + +impl SharesS3Object { + pub fn get(&self, party_id: usize) -> Option<&String> { + match party_id { + 0 => Some(&self.iris_share_0), + 1 => Some(&self.iris_share_1), + 2 => Some(&self.iris_share_2), + _ => None, + } + } +} + +impl SMPCRequest { + pub async fn get_iris_data_by_party_id( + &self, + party_id: usize, + ) -> Result { + // Send a GET request to the presigned URL + let response = match reqwest::get(self.s3_presigned_url.clone()).await { + Ok(response) => response, + Err(e) => { + tracing::error!("Failed to send request: {}", e); + return Err(SharesDecodingError::RequestError(e)); + } + }; + + // Ensure the request was successful + if response.status().is_success() { + // Parse the JSON response into the SharesS3Object struct + let shares_file: SharesS3Object = match response.json().await { + Ok(file) => file, + Err(e) => { + tracing::error!("Failed to parse JSON: {}", e); + return Err(SharesDecodingError::RequestError(e)); + } + }; + + // Construct the field name dynamically + let field_name = format!("iris_share_{}", party_id); + // Access the field dynamically + if let Some(value) = shares_file.get(party_id) { + Ok(value.to_string()) + } else { + tracing::error!("Failed to find field: {}", field_name); + Err(SharesDecodingError::SecretStringNotFound) + } + } else { + tracing::error!("Failed to download file: {}", response.status()); + Err(SharesDecodingError::ResponseContent { + status: response.status(), + url: self.s3_presigned_url.clone(), + message: response.text().await.unwrap_or_default(), + }) + } + } + + pub fn decrypt_iris_share( + &self, + share: String, + key_pair: SharesEncryptionKeyPair, + ) -> Result { + let share_bytes = STANDARD + .decode(share.as_bytes()) + .map_err(|_| SharesDecodingError::Base64DecodeError)?; + + let decrypted = key_pair.open_sealed_box(share_bytes); + + let iris_share = match decrypted { + Ok(bytes) => { + let json_string = String::from_utf8(bytes) + .map_err(SharesDecodingError::DecodedShareParsingToUTF8Error)?; + + tracing::info!("shares_json_string: {:?}", json_string); + let iris_share: IrisCodesJSON = + serde_json::from_str(&json_string).map_err(SharesDecodingError::SerdeError)?; + iris_share + } + Err(e) => return Err(e), + }; + + Ok(iris_share) + } + + #[allow(dead_code)] // TODO: implement hashes validation + fn validate_hashes(&self, hashes: [String; 3]) -> bool { + self.iris_shares_file_hashes == hashes + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ResultEvent { + pub node_id: usize, + pub serial_id: Option, + pub is_match: bool, + pub signup_id: String, + pub matched_serial_ids: Option>, +} + +impl ResultEvent { + pub fn new( + node_id: usize, + serial_id: Option, + is_match: bool, + signup_id: String, + matched_serial_ids: Option>, + ) -> Self { + Self { + node_id, + serial_id, + is_match, + signup_id, + matched_serial_ids, + } + } +} diff --git a/iris-mpc-common/src/helpers/sqs.rs b/iris-mpc-common/src/helpers/sqs.rs index 881e37410..e69de29bb 100644 --- a/iris-mpc-common/src/helpers/sqs.rs +++ b/iris-mpc-common/src/helpers/sqs.rs @@ -1,72 +0,0 @@ -use crate::iris_db::iris::IrisCodeArray; -use base64::{engine::general_purpose, Engine}; -use serde::{Deserialize, Serialize}; - -#[derive(Serialize, Deserialize, Debug)] -pub struct SQSMessage { - #[serde(rename = "Type")] - pub notification_type: String, - #[serde(rename = "MessageId")] - pub message_id: String, - #[serde(rename = "SequenceNumber")] - pub sequence_number: String, - #[serde(rename = "TopicArn")] - pub topic_arn: String, - #[serde(rename = "Message")] - pub message: String, - #[serde(rename = "Timestamp")] - pub timestamp: String, - #[serde(rename = "UnsubscribeURL")] - pub unsubscribe_url: String, -} - -#[derive(Serialize, Deserialize, Debug, Clone)] -pub struct SMPCRequest { - // TODO: make this a message attribute, but the SQS message will anyways be refactored soon. - pub batch_size: Option, - pub request_id: String, - pub iris_code: String, - pub mask_code: String, -} - -impl SMPCRequest { - fn decode_bytes(bytes: &[u8]) -> [u16; IrisCodeArray::IRIS_CODE_SIZE] { - let code = general_purpose::STANDARD.decode(bytes).unwrap(); - let mut buffer = [0u16; IrisCodeArray::IRIS_CODE_SIZE]; - buffer.copy_from_slice(bytemuck::cast_slice(&code)); - buffer - } - pub fn get_iris_shares(&self) -> [u16; IrisCodeArray::IRIS_CODE_SIZE] { - Self::decode_bytes(self.iris_code.as_bytes()) - } - pub fn get_mask_shares(&self) -> [u16; IrisCodeArray::IRIS_CODE_SIZE] { - Self::decode_bytes(self.mask_code.as_bytes()) - } -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -pub struct ResultEvent { - pub node_id: usize, - pub serial_id: Option, - pub is_match: bool, - pub signup_id: String, - pub matched_serial_ids: Option>, -} - -impl ResultEvent { - pub fn new( - node_id: usize, - serial_id: Option, - is_match: bool, - signup_id: String, - matched_serial_ids: Option>, - ) -> Self { - Self { - node_id, - serial_id, - is_match, - signup_id, - matched_serial_ids, - } - } -} diff --git a/iris-mpc-common/src/helpers/sqs_s3_helper.rs b/iris-mpc-common/src/helpers/sqs_s3_helper.rs new file mode 100644 index 000000000..a1f677e6e --- /dev/null +++ b/iris-mpc-common/src/helpers/sqs_s3_helper.rs @@ -0,0 +1,62 @@ +use crate::helpers::key_pair::SharesDecodingError; +use aws_config::meta::region::RegionProviderChain; +use aws_sdk_s3::{ + presigning::PresigningConfig, + primitives::{ByteStream, SdkBody}, + Client, +}; +use std::time::Duration; + +pub async fn upload_file_and_generate_presigned_url( + bucket: &str, + key: &str, + region: &'static str, + contents: &[u8], +) -> Result { + // Load AWS configuration + let region_provider = RegionProviderChain::first_try(region).or_default_provider(); + let config = aws_config::from_env().region(region_provider).load().await; + + // Create S3 client + let client = Client::new(&config); + let content_bytestream = ByteStream::new(SdkBody::from(contents)); + + // Create a PutObject request + match client + .put_object() + .bucket(bucket) + .key(key) + .body(content_bytestream) + .send() + .await + { + Ok(_) => { + tracing::info!("File uploaded successfully."); + } + Err(e) => { + tracing::error!("Error: Failed to upload file: {:?}", e); + } + } + + tracing::info!("File uploaded successfully."); + + // Create a presigned URL for the uploaded file + let presigning_config = match PresigningConfig::expires_in(Duration::from_secs(36000)) { + Ok(config) => config, + Err(e) => return Err(SharesDecodingError::PresigningConfigError(e)), + }; + + let presigned_req = match client + .get_object() + .bucket(bucket) + .key(key) + .presigned(presigning_config) + .await + { + Ok(req) => req, + Err(e) => return Err(SharesDecodingError::PresignedRequestError(e)), + }; + + // Return the presigned URL + Ok(presigned_req.uri().to_string()) +} diff --git a/iris-mpc-common/tests/smpc_request.rs b/iris-mpc-common/tests/smpc_request.rs new file mode 100644 index 000000000..5df955376 --- /dev/null +++ b/iris-mpc-common/tests/smpc_request.rs @@ -0,0 +1,153 @@ +mod tests { + use base64::{engine::general_purpose::STANDARD, Engine}; + use http::StatusCode; + use iris_mpc_common::helpers::{ + key_pair::{SharesDecodingError, SharesEncryptionKeyPair}, + smpc_request::{IrisCodesJSON, SMPCRequest}, + }; + use serde_json::json; + use sodiumoxide::crypto::{box_::PublicKey, sealedbox}; + use wiremock::{ + matchers::{method, path}, + Mock, MockServer, ResponseTemplate, + }; + + const PUBLIC_KEY: &str = "HDp962tQyZIG9t+GX4JM0i1wgJx/YGpHGsuDSD34KBA="; + const PRIVATE_KEY: &str = "14Z6Zijg3kbFN//R9BRKLeTS/wCiZMfK6AurEr/nAZg="; + + fn get_key_pair() -> SharesEncryptionKeyPair { + SharesEncryptionKeyPair::from_b64_strings( + PUBLIC_KEY.to_string().clone(), + PRIVATE_KEY.to_string().clone(), + ) + .unwrap() + } + + fn get_mock_request() -> SMPCRequest { + SMPCRequest { + batch_size: None, + signup_id: "test_signup_id".to_string(), + s3_presigned_url: "https://example.com/package".to_string(), + iris_shares_file_hashes: [ + "hash_0".to_string(), + "hash_1".to_string(), + "hash_2".to_string(), + ], + } + } + + #[tokio::test] + async fn test_retrieve_iris_shares_from_s3_success() { + let mock_server = MockServer::start().await; + + // Simulate a successful response from the presigned URL + let response_body = json!({ + "iris_share_0": "share_0_data", + "iris_share_1": "share_1_data", + "iris_share_2": "share_2_data" + }); + + let template = ResponseTemplate::new(StatusCode::OK).set_body_json(response_body.clone()); + + Mock::given(method("GET")) + .and(path("/test_presign_url")) + .respond_with(template) + .mount(&mock_server) + .await; + + let smpc_request = SMPCRequest { + batch_size: None, + signup_id: "test_signup_id".to_string(), + s3_presigned_url: mock_server.uri().clone() + "/test_presign_url", + iris_shares_file_hashes: [ + "hash_0".to_string(), + "hash_1".to_string(), + "hash_2".to_string(), + ], + }; + + let result = smpc_request.get_iris_data_by_party_id(0).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "share_0_data".to_string()); + } + + #[tokio::test] + async fn test_decrypt_iris_share_success() { + // Mocked base64 encoded JSON string + let iris_codes_json = IrisCodesJSON { + iris_version: "1.0".to_string(), + left_iris_code_shares: "left_code".to_string(), + right_iris_code_shares: "right_code".to_string(), + left_iris_mask_shares: "left_mask".to_string(), + right_iris_mask_shares: "right_mask".to_string(), + }; + + let decoded_public_key = STANDARD.decode(PUBLIC_KEY.as_bytes()).unwrap(); + let shares_encryption_public_key = PublicKey::from_slice(&decoded_public_key).unwrap(); + + // convert iris code to JSON string, sealbox and encode as BASE64 + let json_string = serde_json::to_string(&iris_codes_json).unwrap(); + let sealed_box = sealedbox::seal(json_string.as_bytes(), &shares_encryption_public_key); + let encoded_share = STANDARD.encode(sealed_box); + + let smpc_request = get_mock_request(); + let key_pair = get_key_pair(); + + let result = smpc_request.decrypt_iris_share(encoded_share, key_pair); + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), iris_codes_json); + } + + #[tokio::test] + async fn test_decrypt_iris_share_invalid_base64() { + let invalid_base64 = "InvalidBase64String"; + let key_pair = get_key_pair(); + let smpc_request = get_mock_request(); + + let result = smpc_request.decrypt_iris_share(invalid_base64.to_string(), key_pair); + + assert!(matches!( + result, + Err(SharesDecodingError::Base64DecodeError) + )); + } + + #[tokio::test] + async fn test_decrypt_iris_share_invalid_utf8() { + let invalid_utf8 = vec![0, 159, 146, 150]; // Not valid UTF-8 + + let decoded_public_key = STANDARD.decode(PUBLIC_KEY.as_bytes()).unwrap(); + let shares_encryption_public_key = PublicKey::from_slice(&decoded_public_key).unwrap(); + let sealed_box = sealedbox::seal(&invalid_utf8, &shares_encryption_public_key); + let encoded_share = STANDARD.encode(&sealed_box); + + let key_pair = get_key_pair(); + let smpc_request = get_mock_request(); + + let result = smpc_request.decrypt_iris_share(encoded_share, key_pair); + + assert!(matches!( + result, + Err(SharesDecodingError::DecodedShareParsingToUTF8Error(_)) + )); + } + + #[tokio::test] + async fn test_decrypt_iris_share_invalid_json() { + let invalid_json = "totally-not-a-json-string"; + + let decoded_public_key = STANDARD.decode(PUBLIC_KEY.as_bytes()).unwrap(); + let shares_encryption_public_key = PublicKey::from_slice(&decoded_public_key).unwrap(); + let sealed_box = sealedbox::seal(invalid_json.as_bytes(), &shares_encryption_public_key); + let encoded_share = STANDARD.encode(&sealed_box); + + let key_pair = get_key_pair(); + let smpc_request = get_mock_request(); + + let result = smpc_request.decrypt_iris_share(encoded_share, key_pair); + + assert!(matches!(result, Err(SharesDecodingError::SerdeError(_)))); + } +} diff --git a/iris-mpc-gpu/Cargo.toml b/iris-mpc-gpu/Cargo.toml index 19354253c..c72aa4b6b 100644 --- a/iris-mpc-gpu/Cargo.toml +++ b/iris-mpc-gpu/Cargo.toml @@ -24,8 +24,9 @@ rand.workspace = true static_assertions.workspace = true serde.workspace = true serde_json.workspace = true - +sodiumoxide = "0.2.7" iris-mpc-common = { path = "../iris-mpc-common" } +base64 = "0.22.1" [dev-dependencies] criterion = "0.5" diff --git a/iris-mpc-gpu/src/dot/share_db.rs b/iris-mpc-gpu/src/dot/share_db.rs index dec69d041..7f763cec0 100644 --- a/iris-mpc-gpu/src/dot/share_db.rs +++ b/iris-mpc-gpu/src/dot/share_db.rs @@ -1155,7 +1155,7 @@ mod tests { let mask = results_masks[0][i] + results_masks[1][i] + results_masks[2][i]; if i == 0 { - println!("Code: {}, Mask: {}", code, mask); + tracing::info!("Code: {}, Mask: {}", code, mask); } reconstructed_codes.push(code); diff --git a/iris-mpc-gpu/src/helpers/device_manager.rs b/iris-mpc-gpu/src/helpers/device_manager.rs index 3c38a86b2..496be2cf0 100644 --- a/iris-mpc-gpu/src/helpers/device_manager.rs +++ b/iris-mpc-gpu/src/helpers/device_manager.rs @@ -26,7 +26,7 @@ impl DeviceManager { devices.push(CudaDevice::new(i as usize).unwrap()); } - println!("Found {} devices", devices.len()); + tracing::info!("Found {} devices", devices.len()); Self { devices } } diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index b3c381495..9da203db9 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -278,12 +278,14 @@ impl ServerActor { // Not divided by GPU_COUNT since we do the work on all GPUs for simplicity, // also not padded to 2048 since we only require it to be a multiple of 64 let phase2_batch_chunk_size = QUERIES * QUERIES; - assert!( - phase2_batch_chunk_size % 64 == 0, + assert_eq!( + phase2_batch_chunk_size % 64, + 0, "Phase2 batch chunk size must be a multiple of 64" ); - assert!( - phase2_chunk_size % 64 == 0, + assert_eq!( + phase2_chunk_size % 64, + 0, "Phase2 chunk size must be a multiple of 64" ); diff --git a/iris-mpc-store/src/lib.rs b/iris-mpc-store/src/lib.rs index 9ba8a5aa2..928200d96 100644 --- a/iris-mpc-store/src/lib.rs +++ b/iris-mpc-store/src/lib.rs @@ -306,7 +306,7 @@ mod tests { use super::*; use futures::TryStreamExt; - use iris_mpc_common::helpers::sqs::ResultEvent; + use iris_mpc_common::helpers::smpc_request::ResultEvent; #[tokio::test] async fn test_store() -> Result<()> { diff --git a/iris-mpc/Cargo.toml b/iris-mpc/Cargo.toml index 90bfc4f33..a3fa10517 100644 --- a/iris-mpc/Cargo.toml +++ b/iris-mpc/Cargo.toml @@ -24,10 +24,11 @@ rand.workspace = true base64.workspace = true uuid.workspace = true - +sodiumoxide = "0.2.7" iris-mpc-gpu = { path = "../iris-mpc-gpu" } iris-mpc-common = { path = "../iris-mpc-common" } iris-mpc-store = { path = "../iris-mpc-store" } +sha2 = "0.10.8" [dev-dependencies] criterion = "0.5" diff --git a/iris-mpc/src/bin/client.rs b/iris-mpc/src/bin/client.rs index 49c2cfe36..fb190f2d4 100644 --- a/iris-mpc/src/bin/client.rs +++ b/iris-mpc/src/bin/client.rs @@ -1,9 +1,5 @@ #![allow(clippy::needless_range_loop)] -use aws_sdk_sns::{ - config::Region, - types::{MessageAttributeValue, PublishBatchRequestEntry}, - Client, -}; +use aws_sdk_sns::{config::Region, Client}; use aws_sdk_sqs::Client as SqsClient; use base64::{engine::general_purpose, Engine}; use clap::Parser; @@ -11,41 +7,58 @@ use eyre::{Context, ContextCompat}; use iris_mpc_common::{ galois_engine::degree4::GaloisRingIrisCodeShare, helpers::{ - aws::{construct_message_attributes, NODE_ID_MESSAGE_ATTRIBUTE_NAME}, - sqs::{ResultEvent, SMPCRequest}, + key_pair::download_public_key, + smpc_request::{IrisCodesJSON, ResultEvent, SMPCRequest}, + sqs_s3_helper::upload_file_and_generate_presigned_url, }, iris_db::{db::IrisDB, iris::IrisCode}, }; use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng}; use serde_json::to_string; +use sha2::{Digest, Sha256}; +use sodiumoxide::crypto::{box_::PublicKey, sealedbox}; use std::{collections::HashMap, sync::Arc, time::Duration}; use tokio::{spawn, sync::Mutex, time::sleep}; use uuid::Uuid; const N_QUERIES: usize = 64 * 5; -const REGION: &str = "eu-north-1"; const RNG_SEED_SERVER: u64 = 42; const DB_SIZE: usize = 8 * 1_000; const ENROLLMENT_REQUEST_TYPE: &str = "enrollment"; #[derive(Debug, Parser)] struct Opt { - #[arg(short, long, env)] + #[arg(long, env, required = true)] request_topic_arn: String, - #[arg(short, long, env)] + #[arg(long, env, required = true)] + request_topic_region: String, + + #[arg(long, env, required = true)] response_queue_url: String, - #[arg(short, long, env)] + #[arg(long, env, required = true)] + response_queue_region: String, + + #[arg(long, env, required = true)] + requests_bucket_name: String, + + #[arg(long, env, required = true)] + public_key_base_url: String, + + #[arg(long, env, required = true)] + requests_bucket_region: String, + + #[arg(long, env)] db_index: Option, - #[arg(short, long, env)] + #[arg(long, env)] rng_seed: Option, - #[arg(short, long, env)] + #[arg(long, env)] n_repeat: Option, - #[arg(short, long, env)] + #[arg(long, env)] random: Option, } @@ -54,8 +67,17 @@ async fn main() -> eyre::Result<()> { tracing_subscriber::fmt::init(); let Opt { + public_key_base_url, + + requests_bucket_name, + requests_bucket_region, + request_topic_arn, + request_topic_region, + response_queue_url, + response_queue_region, + db_index, rng_seed, n_repeat, @@ -68,11 +90,25 @@ async fn main() -> eyre::Result<()> { StdRng::from_entropy() }; + let mut shares_encryption_public_keys: Vec = vec![]; + + for i in 0..3 { + let public_key_string = + download_public_key(public_key_base_url.to_string(), i.to_string()).await?; + let public_key_bytes = general_purpose::STANDARD + .decode(public_key_string) + .context("Failed to decode public key")?; + let public_key = + PublicKey::from_slice(&public_key_bytes).context("Failed to parse public key")?; + shares_encryption_public_keys.push(public_key); + } + let n_repeat = n_repeat.unwrap_or(0); - let region_provider = Region::new(REGION); - let shared_config = aws_config::from_env().region(region_provider).load().await; - let client = Client::new(&shared_config); + let region_provider = Region::new(request_topic_region); + let requests_sns_config = aws_config::from_env().region(region_provider).load().await; + + let requests_sns_client = Client::new(&requests_sns_config); let db = IrisDB::new_random_par(DB_SIZE, &mut StdRng::seed_from_u64(RNG_SEED_SERVER)); @@ -88,10 +124,15 @@ async fn main() -> eyre::Result<()> { let thread_responses = responses.clone(); let recv_thread = spawn(async move { - let sqs_client = SqsClient::new(&shared_config); - for _ in 0..N_QUERIES * 3 { + // // THIS IS REQUIRED TO USE THE SQS FROM SECONDARY REGION, URL DOES NOT + // SUFFICE + let region_provider = Region::new(response_queue_region); + let results_sqs_config = aws_config::from_env().region(region_provider).load().await; + let results_sqs_client = SqsClient::new(&results_sqs_config); + let mut counter = 0; + while counter < N_QUERIES * 3 { // Receive responses - let msg = sqs_client + let msg = results_sqs_client .receive_message() .max_number_of_messages(1) .queue_url(response_queue_url.clone()) @@ -100,6 +141,7 @@ async fn main() -> eyre::Result<()> { .context("Failed to receive message")?; for msg in msg.messages.unwrap_or_default() { + counter += 1; let result: ResultEvent = serde_json::from_str(&msg.body.context("No body found")?) .context("Failed to parse message body")?; @@ -113,6 +155,15 @@ async fn main() -> eyre::Result<()> { stale, clear the queue", result.signup_id ); + + results_sqs_client + .delete_message() + .queue_url(response_queue_url.clone()) + .receipt_handle(msg.receipt_handle.unwrap()) + .send() + .await + .context("Failed to delete message")?; + continue; } let expected_result = expected_result.unwrap(); @@ -140,7 +191,7 @@ async fn main() -> eyre::Result<()> { assert_eq!(result.serial_id.unwrap(), expected_result.unwrap()); } - sqs_client + results_sqs_client .delete_message() .queue_url(response_queue_url.clone()) .receipt_handle(msg.receipt_handle.unwrap()) @@ -225,46 +276,71 @@ async fn main() -> eyre::Result<()> { &mut StdRng::seed_from_u64(RNG_SEED_SERVER), ); - let mut messages = vec![]; + let mut iris_shares_file_hashes: [String; 3] = Default::default(); + let mut iris_codes_shares_base64: [String; 3] = Default::default(); + for i in 0..3 { - let sns_id = Uuid::new_v4(); - let iris_code = + let iris_code_coefs_base64 = general_purpose::STANDARD.encode(bytemuck::cast_slice(&shared_code[i].coefs)); - let mask_code = + let mask_code_coefs_base64 = general_purpose::STANDARD.encode(bytemuck::cast_slice(&shared_mask[i].coefs)); - let request_message = SMPCRequest { - batch_size: None, - request_id: request_id.to_string(), - iris_code, - mask_code, + let iris_codes_json = IrisCodesJSON { + iris_version: "1.0".to_string(), + right_iris_code_shares: iris_code_coefs_base64, + right_iris_mask_shares: mask_code_coefs_base64, + left_iris_code_shares: "nan".to_string(), + left_iris_mask_shares: "nan".to_string(), }; - - let mut message_attributes = construct_message_attributes()?; - message_attributes.insert( - NODE_ID_MESSAGE_ATTRIBUTE_NAME.to_string(), - MessageAttributeValue::builder() - .data_type("String") - .string_value(i.to_string()) - .build()?, + let serialized_iris_codes_json = to_string(&iris_codes_json) + .expect("Serialization failed") + .clone(); + + // calculate hash of the object + let mut hasher = Sha256::new(); + hasher.update(serialized_iris_codes_json.clone()); + let result = hasher.finalize(); + let hash_string = format!("{:x}", result); + + // encrypt the object using sealed box and public key + let encrypted_bytes = sealedbox::seal( + serialized_iris_codes_json.as_bytes(), + &shares_encryption_public_keys[i], ); - messages.push( - PublishBatchRequestEntry::builder() - .message(to_string(&request_message)?) - .id(sns_id.to_string()) - .message_group_id(ENROLLMENT_REQUEST_TYPE) - .set_message_attributes(Some(message_attributes)) - .build() - .unwrap(), - ); + iris_codes_shares_base64[i] = general_purpose::STANDARD.encode(&encrypted_bytes); + iris_shares_file_hashes[i] = hash_string; } + let contents = serde_json::to_vec(&iris_codes_shares_base64)?; + let presigned_url = match upload_file_and_generate_presigned_url( + &requests_bucket_name, + &request_id.to_string(), + Box::leak(requests_bucket_region.clone().into_boxed_str()), + &contents, + ) + .await + { + Ok(url) => url, + Err(e) => { + eprintln!("Failed to upload file: {}", e); + continue; + } + }; + + let request_message = SMPCRequest { + batch_size: None, + signup_id: request_id.to_string(), + s3_presigned_url: presigned_url, + iris_shares_file_hashes, + }; + // Send all messages in batch - client - .publish_batch() + requests_sns_client + .publish() .topic_arn(request_topic_arn.clone()) - .set_publish_batch_request_entries(Some(messages)) + .message_group_id(ENROLLMENT_REQUEST_TYPE) + .message(to_string(&request_message)?) .send() .await?; diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index 1156c2190..192eaaa70 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -10,12 +10,10 @@ use iris_mpc_common::{ config::{json_wrapper::JsonStrWrapper, Config, Opt}, galois_engine::degree4::GaloisRingIrisCodeShare, helpers::{ - aws::{ - NODE_ID_MESSAGE_ATTRIBUTE_NAME, SPAN_ID_MESSAGE_ATTRIBUTE_NAME, - TRACE_ID_MESSAGE_ATTRIBUTE_NAME, - }, + aws::{SPAN_ID_MESSAGE_ATTRIBUTE_NAME, TRACE_ID_MESSAGE_ATTRIBUTE_NAME}, + key_pair::SharesEncryptionKeyPair, kms_dh::derive_shared_secret, - sqs::{ResultEvent, SMPCRequest, SQSMessage}, + smpc_request::{ResultEvent, SMPCRequest, SQSMessage}, sync::SyncState, task_monitor::TaskMonitor, }, @@ -64,6 +62,7 @@ async fn receive_batch( queue_url: &String, store: &Store, skip_request_ids: &[String], + shares_encryption_key_pair: SharesEncryptionKeyPair, ) -> eyre::Result { let mut batch_query = BatchQuery::default(); @@ -78,13 +77,15 @@ async fn receive_batch( if let Some(messages) = rcv_message_output.messages { for sqs_message in messages { + let shares_encryption_key_pair = shares_encryption_key_pair.clone(); let message: SQSMessage = serde_json::from_str(sqs_message.body().unwrap()) .context("while trying to parse SQSMessage")?; - let message: SMPCRequest = serde_json::from_str(&message.message) + + let smpc_request: SMPCRequest = serde_json::from_str(&message.message) .context("while trying to parse SMPCRequest")?; store - .mark_requests_deleted(&[message.request_id.clone()]) + .mark_requests_deleted(&[smpc_request.signup_id.clone()]) .await .context("while marking requests as deleted")?; @@ -96,12 +97,12 @@ async fn receive_batch( .await .context("while calling `delete_message` on SQS client")?; - if skip_request_ids.contains(&message.request_id) { + if skip_request_ids.contains(&smpc_request.signup_id) { // Some party (maybe us) already meant to delete this request, so we skip it. continue; } - if let Some(batch_size) = message.batch_size { + if let Some(batch_size) = smpc_request.batch_size { // Updating the batch size instantly makes it a bit unpredictable, since // if we're already above the new limit, we'll still process the current batch // at the higher limit. On the other hand, updating it after the batch is @@ -115,10 +116,6 @@ async fn receive_batch( let mut batch_metadata = BatchMetadata::default(); - if let Some(node_id) = message_attributes.get(NODE_ID_MESSAGE_ATTRIBUTE_NAME) { - let node_id = node_id.string_value().unwrap(); - batch_metadata.node_id = node_id.to_string(); - } if let Some(trace_id) = message_attributes.get(TRACE_ID_MESSAGE_ATTRIBUTE_NAME) { let trace_id = trace_id.string_value().unwrap(); batch_metadata.trace_id = trace_id.to_string(); @@ -128,9 +125,29 @@ async fn receive_batch( batch_metadata.span_id = span_id.to_string(); } - batch_query.request_ids.push(message.request_id.clone()); + batch_query.request_ids.push(smpc_request.signup_id.clone()); batch_query.metadata.push(batch_metadata); + let base_64_encoded_message_payload = + match smpc_request.get_iris_data_by_party_id(party_id).await { + Ok(iris_message_share) => iris_message_share, + Err(e) => { + tracing::error!("Failed to get iris shares: {:?}", e); + continue; + } + }; + + let iris_message_share = match smpc_request.decrypt_iris_share( + base_64_encoded_message_payload, + shares_encryption_key_pair.clone(), + ) { + Ok(iris_data) => iris_data, + Err(e) => { + tracing::error!("Failed to decrypt iris shares: {:?}", e); + continue; + } + }; + let ( store_iris_shares, store_mask_shares, @@ -139,10 +156,26 @@ async fn receive_batch( iris_shares, mask_shares, ) = spawn_blocking(move || { - let mut iris_share = - GaloisRingIrisCodeShare::new(party_id + 1, message.get_iris_shares()); - let mut mask_share = - GaloisRingIrisCodeShare::new(party_id + 1, message.get_mask_shares()); + let mut iris_share = match GaloisRingIrisCodeShare::from_base64( + party_id + 1, + iris_message_share.right_iris_code_shares.as_ref(), + ) { + Ok(iris_share) => iris_share, + Err(e) => { + tracing::error!("Failed to parse iris share: {:?}", e); + return Err(e); + } + }; + let mut mask_share = match GaloisRingIrisCodeShare::from_base64( + party_id + 1, + iris_message_share.right_iris_mask_shares.as_ref(), + ) { + Ok(iris_share) => iris_share, + Err(e) => { + tracing::error!("Failed to parse iris mask: {:?}", e); + return Err(e); + } + }; // Original for storage. let store_iris_shares = iris_share.clone(); @@ -156,17 +189,17 @@ async fn receive_batch( GaloisRingIrisCodeShare::preprocess_iris_code_query_share(&mut iris_share); GaloisRingIrisCodeShare::preprocess_iris_code_query_share(&mut mask_share); - ( + Ok(( store_iris_shares, store_mask_shares, db_iris_shares, db_mask_shares, iris_share.all_rotations(), mask_share.all_rotations(), - ) + )) }) .await - .context("while pre-processing iris code query")?; + .context("while pre-processing iris code query")??; batch_query.store_left.code.push(store_iris_shares); batch_query.store_left.mask.push(store_mask_shares); @@ -393,6 +426,14 @@ async fn server_main(config: Config) -> eyre::Result<()> { let shared_config = aws_config::from_env().region(region_provider).load().await; let sqs_client = Client::new(&shared_config); let sns_client = SNSClient::new(&shared_config); + let shares_encryption_key_pair = + match SharesEncryptionKeyPair::from_storage(config.clone()).await { + Ok(key_pair) => key_pair, + Err(e) => { + tracing::error!("Failed to initialize shares encryption key pair: {:?}", e); + return Ok(()); + } + }; let party_id = config.party_id; tracing::info!("Deriving shared secrets"); @@ -604,18 +645,19 @@ async fn server_main(config: Config) -> eyre::Result<()> { // - The outer Vec is the dimension of the Galois Ring (2): // - A decomposition of each iris bit into two u8 limbs. - // This batch can consist of N sets of iris_share + mask - // It also includes a vector of request ids, mapping to the sets above // Skip requests based on the startup sync, only in the first iteration. let skip_request_ids = mem::take(&mut skip_request_ids); + let shares_encryption_key_pair = shares_encryption_key_pair.clone(); + // This batch can consist of N sets of iris_share + mask + // It also includes a vector of request ids, mapping to the sets above let mut next_batch = receive_batch( party_id, &sqs_client, &config.requests_queue_url, &store, &skip_request_ids, + shares_encryption_key_pair.clone(), ); - loop { let now = Instant::now(); @@ -645,6 +687,7 @@ async fn server_main(config: Config) -> eyre::Result<()> { &config.requests_queue_url, &store, &skip_request_ids, + shares_encryption_key_pair.clone(), ); // await the result