Skip to content

Commit

Permalink
use the open_bucket function in the testcase
Browse files Browse the repository at this point in the history
  • Loading branch information
rw0x0 committed Jan 14, 2025
1 parent d9c4c7e commit 6dcc548
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 61 deletions.
2 changes: 1 addition & 1 deletion iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,7 @@ impl ServerActor {

let buckets = self
.phase2_buckets
.open_buckets(&mut self.buckets, batch_streams);
.open_buckets(&self.buckets, batch_streams);

tracing::info!("BUCKETRESULT: {:?}", buckets);

Expand Down
6 changes: 1 addition & 5 deletions iris-mpc-gpu/src/threshold_ring/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2525,11 +2525,7 @@ impl Circuits {
self.buffers.check_buffers();
}

pub fn open_buckets(
&mut self,
buckets: &mut ChunkShare<u32>,
streams: &[CudaStream],
) -> Vec<u32> {
pub fn open_buckets(&mut self, buckets: &ChunkShare<u32>, streams: &[CudaStream]) -> Vec<u32> {
let a = dtoh_on_stream_sync(&buckets.a, &self.devs[0], &streams[0]).unwrap();
let b = dtoh_on_stream_sync(&buckets.b, &self.devs[0], &streams[0]).unwrap();
let mut res = buckets.as_view();
Expand Down
29 changes: 2 additions & 27 deletions iris-mpc-gpu/tests/buckets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod buckets_test {
};
use iris_mpc_common::iris_db::iris::IrisCodeArray;
use iris_mpc_gpu::{
helpers::{device_manager::DeviceManager, dtoh_on_stream_sync, htod_on_stream_sync},
helpers::{device_manager::DeviceManager, htod_on_stream_sync},
threshold_ring::protocol::{ChunkShare, Circuits},
};
use itertools::{izip, Itertools};
Expand Down Expand Up @@ -110,31 +110,6 @@ mod buckets_test {
result
}

fn open(party: &mut Circuits, x: &ChunkShare<u32>, streams: &[CudaStream]) -> Vec<u32> {
let mut view = x.as_view();
let dev = party.get_devices()[0].clone();

let mut a = dtoh_on_stream_sync(&x.a, &dev, &streams[0]).unwrap();
cudarc::nccl::result::group_start().unwrap();
// Result is in bit 0
party.comms()[0]
.send_view(&view.b, party.next_id(), &streams[0])
.unwrap();

party.comms()[0]
.receive_view(&mut view.a, party.prev_id(), &streams[0])
.unwrap();
cudarc::nccl::result::group_end().unwrap();
let b = dtoh_on_stream_sync(&x.b, &dev, &streams[0]).unwrap();
let c = dtoh_on_stream_sync(&x.a, &dev, &streams[0]).unwrap();

for (a, b, c) in izip!(a.iter_mut(), b, c) {
*a += b + c;
}

a
}

fn install_tracing() {
tracing_subscriber::registry()
.with(
Expand Down Expand Up @@ -195,7 +170,7 @@ mod buckets_test {
tracing::info!("id: {}, compute time: {:?}", id, now.elapsed());

let now = Instant::now();
let result = open(&mut party, &bucket, &streams);
let result = party.open_buckets(&bucket, &streams);
party.synchronize_streams(&streams);
tracing::info!("id: {}, Starting tests...", id);
tracing::info!(
Expand Down
31 changes: 3 additions & 28 deletions iris-mpc-gpu/tests/one_bucket.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
#[cfg(feature = "gpu_dependent")]
// #[cfg(feature = "gpu_dependent")]
mod one_bucket_test {
use cudarc::{
driver::{CudaDevice, CudaStream},
nccl::Id,
};
use iris_mpc_common::iris_db::iris::{IrisCodeArray, MATCH_THRESHOLD_RATIO};
use iris_mpc_gpu::{
helpers::{device_manager::DeviceManager, dtoh_on_stream_sync, htod_on_stream_sync},
helpers::{device_manager::DeviceManager, htod_on_stream_sync},
threshold_ring::protocol::{ChunkShare, Circuits},
};
use itertools::{izip, Itertools};
Expand Down Expand Up @@ -101,31 +101,6 @@ mod one_bucket_test {
count
}

fn open(party: &mut Circuits, x: &ChunkShare<u32>, streams: &[CudaStream]) -> u32 {
let mut view = x.as_view();
let dev = party.get_devices()[0].clone();

let mut a = dtoh_on_stream_sync(&x.a, &dev, &streams[0]).unwrap();
cudarc::nccl::result::group_start().unwrap();
// Result is in bit 0
party.comms()[0]
.send_view(&view.b, party.next_id(), &streams[0])
.unwrap();

party.comms()[0]
.receive_view(&mut view.a, party.prev_id(), &streams[0])
.unwrap();
cudarc::nccl::result::group_end().unwrap();
let b = dtoh_on_stream_sync(&x.b, &dev, &streams[0]).unwrap();
let c = dtoh_on_stream_sync(&x.a, &dev, &streams[0]).unwrap();

for (a, b, c) in izip!(a.iter_mut(), b, c) {
*a += b + c;
}

a[0]
}

fn install_tracing() {
tracing_subscriber::registry()
.with(
Expand Down Expand Up @@ -180,7 +155,7 @@ mod one_bucket_test {
tracing::info!("id: {}, compute time: {:?}", id, now.elapsed());

let now = Instant::now();
let result = open(&mut party, &bucket, &streams);
let result = party.open_buckets(&bucket, &streams)[0];
party.synchronize_streams(&streams);
tracing::info!("id: {}, Starting tests...", id);
tracing::info!(
Expand Down

0 comments on commit 6dcc548

Please sign in to comment.