diff --git a/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu b/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu index 6784089a3..add247064 100644 --- a/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu +++ b/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu @@ -28,11 +28,11 @@ __device__ void arithmetic_xor_inner(T *res_a, T *lhs_a, T *lhs_b, T *rhs_a, T lhs_b_val = *lhs_b; T rhs_a_val = *rhs_a; T rhs_b_val = *rhs_b; - // T r1_val = *r1; - // T r2_val = *r2; + T r1_val = *r1; + T r2_val = *r2; T mul = (lhs_a_val * rhs_a_val) + (lhs_b_val * rhs_a_val) + - (lhs_a_val * rhs_b_val); // + r1_val - r2_val; + (lhs_a_val * rhs_b_val) + r1_val - r2_val; *res_a = lhs_a_val + rhs_a_val - 2 * mul; } diff --git a/iris-mpc-gpu/src/threshold_ring/protocol.rs b/iris-mpc-gpu/src/threshold_ring/protocol.rs index 34620eaac..597758326 100644 --- a/iris-mpc-gpu/src/threshold_ring/protocol.rs +++ b/iris-mpc-gpu/src/threshold_ring/protocol.rs @@ -717,7 +717,7 @@ impl Circuits { fn arithmetic_xor_many_pre_assign( &mut self, x1: &mut ChunkShareView, - x2: &ChunkShare, + x2: &ChunkShareView, idx: usize, streams: &[CudaStream], ) { @@ -1205,8 +1205,8 @@ impl Circuits { fn split_for_arithmetic_xor( &mut self, inp: &[ChunkShareView], - x1: &mut [ChunkShare], - x2: &mut [ChunkShare], + x1: &mut [ChunkShareView], + x2: &mut [ChunkShareView], x3: &mut [ChunkShareView], streams: &[CudaStream], ) { @@ -1264,23 +1264,18 @@ impl Circuits { // buffer is aligned properly for the transmute let mut x1 = Vec::with_capacity(x1_.len()); for (idx, x) in x1_.iter().enumerate() { - let a = self.devs[idx].alloc_zeros(64 * self.chunk_size).unwrap(); - let b = self.devs[idx].alloc_zeros(64 * self.chunk_size).unwrap(); - // let a: CudaView = unsafe { x.a.transmute(64 * self.chunk_size).unwrap() - // }; let b: CudaView = unsafe { x.b.transmute(64 * - // self.chunk_size).unwrap() }; - // let view = ChunkShareView { a, b }; - let view = ChunkShare { a, b }; + let a: CudaView = unsafe { x.a.transmute(64 * self.chunk_size).unwrap() }; + let b: CudaView = unsafe { x.b.transmute(64 * self.chunk_size).unwrap() }; + let view = ChunkShareView { a, b }; + // let view = ChunkShare { a, b }; x1.push(view); } let mut x2 = Vec::with_capacity(x2_.len()); for (idx, x) in x2_.iter().enumerate() { - let a = self.devs[idx].alloc_zeros(64 * self.chunk_size).unwrap(); - let b = self.devs[idx].alloc_zeros(64 * self.chunk_size).unwrap(); - // let a: CudaView = unsafe { x.a.transmute(64 * self.chunk_size).unwrap() - // }; let b: CudaView = unsafe { x.b.transmute(64 * - // self.chunk_size).unwrap() }; let view = ChunkShareView { a, b }; - let view = ChunkShare { a, b }; + let a: CudaView = unsafe { x.a.transmute(64 * self.chunk_size).unwrap() }; + let b: CudaView = unsafe { x.b.transmute(64 * self.chunk_size).unwrap() }; + let view = ChunkShareView { a, b }; + // let view = ChunkShare { a, b }; x2.push(view); } @@ -2536,13 +2531,6 @@ impl Circuits { // Result is in the first bit of the result buffer let result = self.take_result_buffer(); - let test = - dtoh_on_stream_sync(&result[0].a.slice(0..16), &self.devs[0], &streams[0]).unwrap(); - tracing::warn!("result.a: id: {} {:?}", self.prev_id, test); - let test = - dtoh_on_stream_sync(&result[0].b.slice(0..16), &self.devs[0], &streams[0]).unwrap(); - tracing::warn!("result.b: id: {} {:?}", self.prev_id, test); - let mut bits = Vec::with_capacity(self.n_devices); for r in result.iter() { // Result is in the first bit of the input