Skip to content

Commit

Permalink
Merge branch 'main' into fix/adjusting-configs-for-prod
Browse files Browse the repository at this point in the history
  • Loading branch information
wojciechsromek committed Aug 28, 2024
2 parents 676c7d4 + f7219b6 commit cfb24d4
Show file tree
Hide file tree
Showing 21 changed files with 519 additions and 375 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/build-and-push-debug.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
name: Build and push docker Debug image

on:
workflow_dispatch:

concurrency:
group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}'
cancel-in-progress: true

env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-debug

jobs:
docker:
runs-on:
labels: ubuntu-22.04-64core
permissions:
packages: write
contents: read
attestations: write
id-token: write
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Log in to the Container registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build and Push
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: |
${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.sha }}
platforms: linux/amd64
cache-from: type=gha
cache-to: type=gha,mode=max
file: Dockerfile.debug
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 24 additions & 35 deletions iris-mpc-common/src/bin/shares_encoding.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use base64::prelude::{Engine, BASE64_STANDARD};
use clap::Parser;
use data_encoding::HEXLOWER;
use iris_mpc_common::{
galois_engine::degree4::GaloisRingIrisCodeShare, iris_db::iris::IrisCodeArray,
Expand All @@ -9,7 +8,7 @@ use ring::digest::{digest, SHA256};
use serde::{ser::Error, Serialize, Serializer};
use serde_big_array::BigArray;
use serde_json::Value;
use std::collections::BTreeMap;
use std::{collections::BTreeMap, env};

const RNG_SEED: u64 = 42; // Replace with your seed value
const IRIS_VERSION: &str = "1.1";
Expand Down Expand Up @@ -37,9 +36,9 @@ struct IrisCodeSharesJson {
#[serde(rename = "IRIS_shares_version")]
iris_shares_version: String,
left_iris_code_shares: String,
left_mask_code_shares: String,
left_iris_mask_shares: String,
right_iris_code_shares: String,
right_mask_code_shares: String,
right_iris_mask_shares: String,
}

/// Iris code shares.
Expand Down Expand Up @@ -98,54 +97,44 @@ fn to_array(input: [GaloisRingIrisCodeShare; 3]) -> [IrisCodeShare; 3] {
.expect("Expected exactly 3 elements")
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
#[arg(long)]
iris_b64_left: Option<String>,

#[arg(long)]
mask_b64_left: Option<String>,

#[arg(long)]
iris_b64_right: Option<String>,

#[arg(long)]
mask_b64_right: Option<String>,

#[arg(short, long, env)]
rng_seed: Option<u64>,
}

fn main() {
let args = Args::parse();
let mut rng = if let Some(seed_rng) = args.rng_seed {
StdRng::seed_from_u64(seed_rng)
let mut rng = if let Ok(seed_rng) = env::var("SEED_RNG") {
// env variable passed, use passed seed
StdRng::seed_from_u64(seed_rng.parse().unwrap())
} else {
// no env variable passed, use default seed
StdRng::seed_from_u64(RNG_SEED)
};

let iris_code_left = if let Some(iris_base_64) = args.iris_b64_right {
let iris_code_left = if let Ok(iris_base_64) = env::var("IRIS_B64_LEFT") {
// env variable passed, use passed iris code
IrisCodeArray::from_base64(&iris_base_64).unwrap()
} else {
// no env variable passed, generate random iris code
IrisCodeArray::random_rng(&mut rng)
};

let mask_code_left = if let Some(mask_base_64) = args.mask_b64_right {
let mask_code_left = if let Ok(mask_base_64) = env::var("MASK_B64_LEFT") {
// env variable passed, use passed iris mask
IrisCodeArray::from_base64(&mask_base_64).unwrap()
} else {
// no env variable passed, use default iris mask
IrisCodeArray::default()
};

let iris_code_right = if let Some(iris_base_64) = args.iris_b64_left {
let iris_code_right = if let Ok(iris_base_64) = env::var("IRIS_B64_RIGHT") {
// env variable passed, use passed iris code
IrisCodeArray::from_base64(&iris_base_64).unwrap()
} else {
// no env variable passed, generate random iris code
IrisCodeArray::random_rng(&mut rng)
};

let mask_code_right = if let Some(mask_base_64) = args.mask_b64_left {
let mask_code_right = if let Ok(mask_base_64) = env::var("MASK_B64_RIGHT") {
// env variable passed, use passed iris mask
IrisCodeArray::from_base64(&mask_base_64).unwrap()
} else {
// no env variable passed, use default iris mask
IrisCodeArray::default()
};

Expand Down Expand Up @@ -181,9 +170,9 @@ fn main() {
iris_version: IRIS_VERSION.to_string(),
iris_shares_version: IRIS_MPC_VERSION.to_string(),
left_iris_code_shares: li.into(),
left_mask_code_shares: lm.into(),
left_iris_mask_shares: lm.into(),
right_iris_code_shares: ri.into(),
right_mask_code_shares: rm.into(),
right_iris_mask_shares: rm.into(),
};
let json_u8 = serde_json::to_string(&SerializeWithSortedKeys(&iris_code_shares))
.unwrap()
Expand Down Expand Up @@ -216,12 +205,12 @@ mod tests {
iris_version: IRIS_VERSION.to_string(),
iris_shares_version: IRIS_MPC_VERSION.to_string(),
left_iris_code_shares: "left_iris_code_shares".to_string(),
left_mask_code_shares: "left_mask_code_shares".to_string(),
left_iris_mask_shares: "left_iris_mask_shares".to_string(),
right_iris_code_shares: "right_iris_code_shares".to_string(),
right_mask_code_shares: "right_mask_code_shares".to_string(),
right_iris_mask_shares: "right_iris_mask_shares".to_string(),
};

let expected = r#"{"IRIS_shares_version":"1.0","IRIS_version":"1.1","left_iris_code_shares":"left_iris_code_shares","left_mask_code_shares":"left_mask_code_shares","right_iris_code_shares":"right_iris_code_shares","right_mask_code_shares":"right_mask_code_shares"}"#;
let expected = r#"{"IRIS_shares_version":"1.0","IRIS_version":"1.1","left_iris_code_shares":"left_iris_code_shares","left_iris_mask_shares":"left_iris_mask_shares","right_iris_code_shares":"right_iris_code_shares","right_iris_mask_shares":"right_iris_mask_shares"}"#;
assert_eq!(
serde_json::to_string(&SerializeWithSortedKeys(&iris_code_shares)).unwrap(),
expected
Expand Down
4 changes: 1 addition & 3 deletions iris-mpc-gpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
cudarc = { git = "https://github.com/philsippl/cudarc.git", rev = "b824f30", features = [
"cuda-12020",
] }
cudarc = { version = "0.12", features = ["cuda-12020", "nccl"] }
eyre.workspace = true
tracing.workspace = true
bytemuck.workspace = true
Expand Down
14 changes: 4 additions & 10 deletions iris-mpc-gpu/src/bin/nccl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async fn main() -> eyre::Result<()> {

devs.push(dev);
comms.push(comm);
slices.push(slice);
slices.push(Some(slice));
slices1.push(slice1);
slices2.push(slice2);
slices3.push(slice3);
Expand All @@ -85,15 +85,9 @@ async fn main() -> eyre::Result<()> {
for i in 0..n_devices {
devs[i].bind_to_thread().unwrap();

comms[i]
.broadcast(&Some(&slices[i]), &mut slices1[i], 0)
.unwrap();
comms[i]
.broadcast(&Some(&slices[i]), &mut slices2[i], 1)
.unwrap();
comms[i]
.broadcast(&Some(&slices[i]), &mut slices3[i], 2)
.unwrap();
comms[i].broadcast(&slices[i], &mut slices1[i], 0).unwrap();
comms[i].broadcast(&slices[i], &mut slices2[i], 1).unwrap();
comms[i].broadcast(&slices[i], &mut slices3[i], 2).unwrap();
}

for dev in devs.iter() {
Expand Down
69 changes: 12 additions & 57 deletions iris-mpc-gpu/src/dot/share_db.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::IRIS_CODE_LENGTH;
use crate::{
helpers::{
comm::NcclComm,
device_manager::DeviceManager,
device_ptrs_to_shares,
query_processor::{
Expand All @@ -25,7 +26,7 @@ use cudarc::{
sys::{CUdeviceptr, CUmemAttach_flags},
CudaFunction, CudaSlice, CudaStream, DevicePtr, LaunchAsync, LaunchConfig,
},
nccl::{self, result, Comm, NcclType},
nccl,
nvrtc::compile_ptx,
};
#[cfg(feature = "otp_encrypt")]
Expand Down Expand Up @@ -112,44 +113,6 @@ pub fn gemm(
}
}

fn send_stream<T: NcclType>(
sendbuff: &CudaSlice<T>,
len: usize,
peer: usize,
comm: &Comm,
stream: &CudaStream,
) -> Result<result::NcclStatus, result::NcclError> {
unsafe {
result::send(
*sendbuff.device_ptr() as *mut _,
len,
T::as_nccl_type(),
peer as i32,
comm.comm.0,
stream.stream as *mut _,
)
}
}

fn receive_stream<T: NcclType>(
recvbuff: &mut CudaSlice<T>,
len: usize,
peer: usize,
comm: &Comm,
stream: &CudaStream,
) -> Result<result::NcclStatus, result::NcclError> {
unsafe {
result::recv(
*recvbuff.device_ptr() as *mut _,
len,
T::as_nccl_type(),
peer as i32,
comm.comm.0,
stream.stream as *mut _,
)
}
}

fn chunking<T: Clone>(
slice: &[T],
n_chunks: usize,
Expand Down Expand Up @@ -186,7 +149,7 @@ pub struct ShareDB {
#[cfg(feature = "otp_encrypt")]
xor_assign_u8_kernels: Vec<CudaFunction>,
rngs: Vec<(ChaChaCudaRng, ChaChaCudaRng)>,
comms: Vec<Arc<Comm>>,
comms: Vec<Arc<NcclComm>>,
ones: Vec<CudaSlice<u8>>,
intermediate_results: Vec<CudaSlice<i32>>,
pub results: Vec<CudaSlice<u8>>,
Expand All @@ -202,7 +165,7 @@ impl ShareDB {
max_db_length: usize,
query_length: usize,
chacha_seeds: ([u32; 8], [u32; 8]),
comms: Vec<Arc<Comm>>,
comms: Vec<Arc<NcclComm>>,
) -> Self {
let n_devices = device_manager.device_count();
let ptx = compile_ptx(PTX_SRC).unwrap();
Expand Down Expand Up @@ -745,23 +708,15 @@ impl ShareDB {
let send_len = len >> 2;
#[cfg(not(feature = "otp_encrypt"))]
let send_len = len;
send_stream(
&send[idx],
send_len,
next_peer,
&self.comms[idx],
&streams[idx],
)
.unwrap();
let send_view = send[idx].slice(..send_len);
self.comms[idx]
.send_view(&send_view, next_peer, &streams[idx])
.unwrap();

receive_stream(
&mut self.results_peer[idx],
len,
prev_peer,
&self.comms[idx],
&streams[idx],
)
.unwrap();
let mut recv_view = self.results_peer[idx].slice(..len);
self.comms[idx]
.receive_view(&mut recv_view, prev_peer, &streams[idx])
.unwrap();
}
nccl::group_end().unwrap();
#[cfg(feature = "otp_encrypt")]
Expand Down
Loading

0 comments on commit cfb24d4

Please sign in to comment.