From 6dcc548abd0fb28bd565c11eb07b3936be29c043 Mon Sep 17 00:00:00 2001 From: Roman Walch <9820846+rw0x0@users.noreply.github.com> Date: Tue, 14 Jan 2025 13:09:02 +0100 Subject: [PATCH] use the open_bucket function in the testcase --- iris-mpc-gpu/src/server/actor.rs | 2 +- iris-mpc-gpu/src/threshold_ring/protocol.rs | 6 +--- iris-mpc-gpu/tests/buckets.rs | 29 ++----------------- iris-mpc-gpu/tests/one_bucket.rs | 31 ++------------------- 4 files changed, 7 insertions(+), 61 deletions(-) diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 78457ad28..dad3f3280 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -1152,7 +1152,7 @@ impl ServerActor { let buckets = self .phase2_buckets - .open_buckets(&mut self.buckets, batch_streams); + .open_buckets(&self.buckets, batch_streams); tracing::info!("BUCKETRESULT: {:?}", buckets); diff --git a/iris-mpc-gpu/src/threshold_ring/protocol.rs b/iris-mpc-gpu/src/threshold_ring/protocol.rs index 5e126bbe9..90a969d32 100644 --- a/iris-mpc-gpu/src/threshold_ring/protocol.rs +++ b/iris-mpc-gpu/src/threshold_ring/protocol.rs @@ -2525,11 +2525,7 @@ impl Circuits { self.buffers.check_buffers(); } - pub fn open_buckets( - &mut self, - buckets: &mut ChunkShare, - streams: &[CudaStream], - ) -> Vec { + pub fn open_buckets(&mut self, buckets: &ChunkShare, streams: &[CudaStream]) -> Vec { let a = dtoh_on_stream_sync(&buckets.a, &self.devs[0], &streams[0]).unwrap(); let b = dtoh_on_stream_sync(&buckets.b, &self.devs[0], &streams[0]).unwrap(); let mut res = buckets.as_view(); diff --git a/iris-mpc-gpu/tests/buckets.rs b/iris-mpc-gpu/tests/buckets.rs index 0fe6f813f..2647db2c3 100644 --- a/iris-mpc-gpu/tests/buckets.rs +++ b/iris-mpc-gpu/tests/buckets.rs @@ -6,7 +6,7 @@ mod buckets_test { }; use iris_mpc_common::iris_db::iris::IrisCodeArray; use iris_mpc_gpu::{ - helpers::{device_manager::DeviceManager, dtoh_on_stream_sync, htod_on_stream_sync}, + helpers::{device_manager::DeviceManager, htod_on_stream_sync}, threshold_ring::protocol::{ChunkShare, Circuits}, }; use itertools::{izip, Itertools}; @@ -110,31 +110,6 @@ mod buckets_test { result } - fn open(party: &mut Circuits, x: &ChunkShare, streams: &[CudaStream]) -> Vec { - let mut view = x.as_view(); - let dev = party.get_devices()[0].clone(); - - let mut a = dtoh_on_stream_sync(&x.a, &dev, &streams[0]).unwrap(); - cudarc::nccl::result::group_start().unwrap(); - // Result is in bit 0 - party.comms()[0] - .send_view(&view.b, party.next_id(), &streams[0]) - .unwrap(); - - party.comms()[0] - .receive_view(&mut view.a, party.prev_id(), &streams[0]) - .unwrap(); - cudarc::nccl::result::group_end().unwrap(); - let b = dtoh_on_stream_sync(&x.b, &dev, &streams[0]).unwrap(); - let c = dtoh_on_stream_sync(&x.a, &dev, &streams[0]).unwrap(); - - for (a, b, c) in izip!(a.iter_mut(), b, c) { - *a += b + c; - } - - a - } - fn install_tracing() { tracing_subscriber::registry() .with( @@ -195,7 +170,7 @@ mod buckets_test { tracing::info!("id: {}, compute time: {:?}", id, now.elapsed()); let now = Instant::now(); - let result = open(&mut party, &bucket, &streams); + let result = party.open_buckets(&bucket, &streams); party.synchronize_streams(&streams); tracing::info!("id: {}, Starting tests...", id); tracing::info!( diff --git a/iris-mpc-gpu/tests/one_bucket.rs b/iris-mpc-gpu/tests/one_bucket.rs index 5b4004d4e..857056776 100644 --- a/iris-mpc-gpu/tests/one_bucket.rs +++ b/iris-mpc-gpu/tests/one_bucket.rs @@ -1,4 +1,4 @@ -#[cfg(feature = "gpu_dependent")] +// #[cfg(feature = "gpu_dependent")] mod one_bucket_test { use cudarc::{ driver::{CudaDevice, CudaStream}, @@ -6,7 +6,7 @@ mod one_bucket_test { }; use iris_mpc_common::iris_db::iris::{IrisCodeArray, MATCH_THRESHOLD_RATIO}; use iris_mpc_gpu::{ - helpers::{device_manager::DeviceManager, dtoh_on_stream_sync, htod_on_stream_sync}, + helpers::{device_manager::DeviceManager, htod_on_stream_sync}, threshold_ring::protocol::{ChunkShare, Circuits}, }; use itertools::{izip, Itertools}; @@ -101,31 +101,6 @@ mod one_bucket_test { count } - fn open(party: &mut Circuits, x: &ChunkShare, streams: &[CudaStream]) -> u32 { - let mut view = x.as_view(); - let dev = party.get_devices()[0].clone(); - - let mut a = dtoh_on_stream_sync(&x.a, &dev, &streams[0]).unwrap(); - cudarc::nccl::result::group_start().unwrap(); - // Result is in bit 0 - party.comms()[0] - .send_view(&view.b, party.next_id(), &streams[0]) - .unwrap(); - - party.comms()[0] - .receive_view(&mut view.a, party.prev_id(), &streams[0]) - .unwrap(); - cudarc::nccl::result::group_end().unwrap(); - let b = dtoh_on_stream_sync(&x.b, &dev, &streams[0]).unwrap(); - let c = dtoh_on_stream_sync(&x.a, &dev, &streams[0]).unwrap(); - - for (a, b, c) in izip!(a.iter_mut(), b, c) { - *a += b + c; - } - - a[0] - } - fn install_tracing() { tracing_subscriber::registry() .with( @@ -180,7 +155,7 @@ mod one_bucket_test { tracing::info!("id: {}, compute time: {:?}", id, now.elapsed()); let now = Instant::now(); - let result = open(&mut party, &bucket, &streams); + let result = party.open_buckets(&bucket, &streams)[0]; party.synchronize_streams(&streams); tracing::info!("id: {}, Starting tests...", id); tracing::info!(