Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl committed Jan 11, 2025
1 parent 05528e5 commit 0b9a511
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 10 deletions.
14 changes: 14 additions & 0 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,4 +460,18 @@ impl DistanceComparator {
.map(|i| self.device_manager.device(i).alloc_zeros(1).unwrap())
.collect::<Vec<_>>()
}

/// Allocates the bucket buffers used when comparing collected match distances.
///
/// Two buffers of `n_buckets` `u32` elements each are created with `alloc_zeros`
/// on device 0 and paired into a single [`ChunkShare`].
///
/// NOTE(review): the two halves are presumably the `a`/`b` parts of a secret
/// share — confirm against `ChunkShare`.
///
/// # Panics
///
/// Panics if either device allocation fails.
pub fn prepare_match_distances_buckets(&self, n_buckets: usize) -> ChunkShare<u32> {
    // Both halves are placed on the first device.
    let first_device = self.device_manager.device(0);
    let share_a = first_device.alloc_zeros(n_buckets).unwrap();
    let share_b = first_device.alloc_zeros(n_buckets).unwrap();
    ChunkShare::new(share_a, share_b)
}
}
48 changes: 40 additions & 8 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl ServerActorHandle {
const DB_CHUNK_SIZE: usize = 1 << 15;
const KDF_SALT: &str = "111a1a93518f670e9bb0c2c68888e2beb9406d4c4ed571dc77b801e676ae3091"; // Random 32 byte salt
const SUPERMATCH_THRESHOLD: usize = 4_000;
const MIN_MATCH_DISTANCES: usize = 100_000;
const MATCH_DISTANCES_BUFFER_SIZE: usize = 1 << 15;

pub struct ServerActor {
job_queue: mpsc::Receiver<ServerJob>,
Expand All @@ -88,6 +88,7 @@ pub struct ServerActor {
batch_masks_engine: ShareDB,
phase2: Circuits,
phase2_batch: Circuits,
phase2_buckets: Circuits,
distance_comparator: DistanceComparator,
comms: Vec<Arc<NcclComm>>,
// DB slices
Expand Down Expand Up @@ -122,6 +123,7 @@ pub struct ServerActor {
match_distances_buffer_masks_right: Vec<ChunkShare<u16>>,
match_distances_counter_left: Vec<CudaSlice<u32>>,
match_distances_counter_right: Vec<CudaSlice<u32>>,
buckets: ChunkShare<u32>,
}

const NON_MATCH_ID: u32 = u32::MAX;
Expand Down Expand Up @@ -327,8 +329,8 @@ impl ServerActor {

let phase2_buckets = Circuits::new(
party_id,
n_queries,
n_queries / 64,
MATCH_DISTANCES_BUFFER_SIZE,
MATCH_DISTANCES_BUFFER_SIZE / 64,
next_chacha_seeds(chacha_seeds)?,
device_manager.clone(),
comms.clone(),
Expand Down Expand Up @@ -368,15 +370,16 @@ impl ServerActor {

// Buffers and counters for match distribution
let match_distances_buffer_codes_left =
distance_comparator.prepare_match_distances_buffer(1_000_000); // TODO
distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE);
let match_distances_buffer_codes_right =
distance_comparator.prepare_match_distances_buffer(1_000_000); // TODO
distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE);
let match_distances_buffer_masks_left =
distance_comparator.prepare_match_distances_buffer(1_000_000); // TODO
distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE);
let match_distances_buffer_masks_right =
distance_comparator.prepare_match_distances_buffer(1_000_000); // TODO
distance_comparator.prepare_match_distances_buffer(MATCH_DISTANCES_BUFFER_SIZE);
let match_distances_counter_left = distance_comparator.prepare_match_distances_counter();
let match_distances_counter_right = distance_comparator.prepare_match_distances_counter();
let buckets = distance_comparator.prepare_match_distances_buckets(1); // TODO

for dev in device_manager.devices() {
dev.synchronize().unwrap();
Expand All @@ -390,6 +393,8 @@ impl ServerActor {
masks_engine,
phase2,
phase2_batch,
phase2_buckets,
buckets,
distance_comparator,
batch_codes_engine,
batch_masks_engine,
Expand Down Expand Up @@ -1123,8 +1128,35 @@ impl ServerActor {

tracing::info!("Matching distances collected: {}", total_distance_counter);

if total_distance_counter < MIN_MATCH_DISTANCES {
if total_distance_counter >= 10 {
// TODO
tracing::info!("Collected enough match distances, starting bucket calculation");

let match_distances_buffers_codes_view = match_distances_buffers_codes
.iter()
.map(|x| x.as_view())
.collect::<Vec<_>>();

let match_distances_buffers_masks_view = match_distances_buffers_masks
.iter()
.map(|x| x.as_view())
.collect::<Vec<_>>();

self.phase2_buckets.compare_multiple_thresholds(
&match_distances_buffers_codes_view,
&match_distances_buffers_masks_view,
batch_streams,
&[24577],
&mut self.buckets,
);

let buckets = self
.phase2_buckets
.open_buckets(&mut self.buckets, batch_streams);

tracing::info!("BUCKETRESULT: {:?}", buckets);

// TODO: reset counter
}

// ---- START BATCH DEDUP ----
Expand Down
31 changes: 29 additions & 2 deletions iris-mpc-gpu/src/threshold_ring/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2465,10 +2465,11 @@ impl Circuits {
// Sum all elements in x so that the result ends up in the first 32-bit word on each GPU
self.collapse_sum(&mut x, streams);
// Get data onto the first GPU
self.collect_graphic_result_u32(&mut x, streams);
if self.n_devices > 1 {
self.collect_graphic_result_u32(&mut x, streams);
}
// Accumulate first result onto bucket
self.collapse_sum_on_gpu(buckets, &x, self.n_devices, bucket_idx, 0, streams);

self.return_result_buffer(result);
}

Expand All @@ -2478,4 +2479,30 @@ impl Circuits {
Buffers::return_buffer(&mut self.buffers.lifting_corrections, corrections_);
self.buffers.check_buffers();
}

/// Combines the bucket halves held locally with one half received from the
/// previous party and returns the element-wise sums on the host.
///
/// Steps, all on `self.devs[0]` / `streams[0]` / `self.comms[0]`:
/// 1. copy `buckets.a` and `buckets.b` to the host,
/// 2. send the `b` half to `self.next_id` and receive into the `a` half from
///    `self.prev_id` inside one NCCL group,
/// 3. copy the received data to the host,
/// 4. return `a + b + received` per element using wrapping `u32` addition.
///
/// NOTE(review): the receive targets a view of `buckets`, so it appears to
/// overwrite `buckets.a` in place — confirm callers do not rely on the old
/// contents afterwards.
///
/// # Panics
///
/// Panics if any device-to-host copy or NCCL call returns an error.
pub fn open_buckets(
    &mut self,
    buckets: &mut ChunkShare<u32>,
    streams: &[CudaStream],
) -> Vec<u32> {
    let device = &self.devs[0];
    let stream = &streams[0];

    // Local halves must be read back before the exchange below touches the buffer.
    let local_a = dtoh_on_stream_sync(&buckets.a, device, stream).unwrap();
    let local_b = dtoh_on_stream_sync(&buckets.b, device, stream).unwrap();
    let mut view = buckets.as_view();

    // Send and receive are issued inside one group.
    result::group_start().unwrap();
    self.comms[0]
        .send_view(&view.b, self.next_id, stream)
        .unwrap();
    self.comms[0]
        .receive_view(&mut view.a, self.prev_id, stream)
        .unwrap();
    result::group_end().unwrap();

    let received = dtoh_on_stream_sync(&view.a, device, stream).unwrap();

    // Output length is that of the shortest of the three vectors.
    local_a
        .iter()
        .zip(local_b.iter())
        .zip(received.iter())
        .map(|((&x, &y), &z)| x.wrapping_add(y).wrapping_add(z))
        .collect()
}
}

0 comments on commit 0b9a511

Please sign in to comment.