Skip to content

Commit

Permalink
add new test case for bucketing
Browse files Browse the repository at this point in the history
  • Loading branch information
rw0x0 committed Jan 13, 2025
1 parent 58f6b5b commit 20b2cd6
Show file tree
Hide file tree
Showing 2 changed files with 403 additions and 0 deletions.
42 changes: 42 additions & 0 deletions iris-mpc-gpu/src/threshold_ring/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2404,6 +2404,48 @@ impl Circuits {
// Result is in the first bit of the result buffer
}

// same as compare_threshold_masked_many, just via the functions used in the
// bucketing
pub fn compare_threshold_masked_many_bucket_functions(
&mut self,
code_dots: &[ChunkShareView<u16>],
mask_dots: &[ChunkShareView<u16>],
streams: &[CudaStream],
) {
const A: u64 = ((1. - 2. * iris_mpc_common::iris_db::iris::MATCH_THRESHOLD_RATIO)
* ((1 << 16) as f64)) as u64;

assert_eq!(self.n_devices, code_dots.len());
assert_eq!(self.n_devices, mask_dots.len());
for chunk in code_dots.iter().chain(mask_dots.iter()) {
assert!(chunk.len() % 64 == 0);
}

let x_ = Buffers::take_buffer(&mut self.buffers.lifted_shares);
let x1_ = Buffers::take_buffer(&mut self.buffers.lifted_shares_buckets1);
let x2_ = Buffers::take_buffer(&mut self.buffers.lifted_shares_buckets2);
let corrections_ = Buffers::take_buffer(&mut self.buffers.lifting_corrections);
let mut masks = Buffers::get_buffer_chunk(&x1_, 64 * self.chunk_size);
let mut codes = Buffers::get_buffer_chunk(&x2_, 64 * self.chunk_size);
let mut x = Buffers::get_buffer_chunk(&x_, 64 * self.chunk_size);
let mut corrections = Buffers::get_buffer_chunk(&corrections_, 128 * self.chunk_size);

self.lift_mpc(mask_dots, &mut masks, &mut corrections, streams);
self.finalize_lifts(&mut x, &mut codes, &corrections, code_dots, streams);
self.lifted_sub(&mut x, &masks, &codes, A as u32, streams);

self.lift_mul_sub(&mut x, &corrections, code_dots, streams);
self.extract_msb(&mut x, streams);

Buffers::return_buffer(&mut self.buffers.lifted_shares, x_);
Buffers::return_buffer(&mut self.buffers.lifted_shares_buckets1, x1_);
Buffers::return_buffer(&mut self.buffers.lifted_shares_buckets2, x2_);
Buffers::return_buffer(&mut self.buffers.lifting_corrections, corrections_);
self.buffers.check_buffers();

// Result is in the first bit of the result buffer
}

// input should be of size: n_devices * input_size
// Result is in the lowest bit of the result buffer on the first gpu
pub fn compare_threshold_masked_many_with_or_tree(
Expand Down
Loading

0 comments on commit 20b2cd6

Please sign in to comment.