From 784747aa57b30c3bc73b96d04cde19dfd8fc4a4c Mon Sep 17 00:00:00 2001 From: Roman Walch <9820846+rw0x0@users.noreply.github.com> Date: Mon, 13 Jan 2025 15:33:49 +0100 Subject: [PATCH] fix an error --- iris-mpc-gpu/src/server/actor.rs | 2 +- iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu | 8 ++++---- iris-mpc-gpu/src/threshold_ring/protocol.rs | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 3e63075ba..78457ad28 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -1146,7 +1146,7 @@ impl ServerActor { &match_distances_buffers_codes_view, &match_distances_buffers_masks_view, batch_streams, - &[24577], + &[16384], &mut self.buckets, ); diff --git a/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu b/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu index 397ec57fc..1e0defbe7 100644 --- a/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu +++ b/iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu @@ -30,8 +30,8 @@ __device__ void arithmetic_xor_inner(T *res_a, T *lhs_a, T *lhs_b, T *rhs_a, T rhs_b_val = *rhs_b; T r1_val = *r1; T r2_val = *r2; - - T mul = (lhs_a_val * rhs_a_val) + (lhs_b_val * rhs_a_val) + + + T mul = (lhs_a_val * rhs_a_val) + (lhs_b_val * rhs_a_val) + (lhs_a_val * rhs_b_val) + r1_val - r2_val; *res_a = lhs_a_val + rhs_a_val - 2 * mul; } @@ -457,10 +457,10 @@ extern "C" __global__ void shared_lifted_sub(U32 *mask_a, U32 *mask_b, lifted_sub(&mask_b[i], &code_b[i], &output_b[i], a); switch (id) { case 0: - mask_a[i] -= 1; // Transforms the <= into < + output_a[i] -= 1; // Transforms the <= into < break; case 1: - mask_b[i] -= 1; // Transforms the <= into < + output_b[i] -= 1; // Transforms the <= into < break; default: break; diff --git a/iris-mpc-gpu/src/threshold_ring/protocol.rs b/iris-mpc-gpu/src/threshold_ring/protocol.rs index 7fa7bd1ce..5e126bbe9 100644 --- a/iris-mpc-gpu/src/threshold_ring/protocol.rs +++ b/iris-mpc-gpu/src/threshold_ring/protocol.rs @@ -2414,7 +2414,7 @@ impl Circuits { streams: &[CudaStream], ) { const A: u64 = ((1. - 2. * iris_mpc_common::iris_db::iris::MATCH_THRESHOLD_RATIO) - * ((1 << 16) as f64)) as u64; + * ((1u64 << 16) as f64)) as u64; assert_eq!(self.n_devices, code_dots.len()); assert_eq!(self.n_devices, mask_dots.len());