Skip to content

Commit

Permalink
[POP-2042] Anonymized stats through buckets (#878)
Browse files Browse the repository at this point in the history
* start to add code for bucket comparison

* progress on buckets

* .

* .

* .

* .

* .

* send/receive u32 with chacha encryption

* bitinject

* compare multiple thresholds function done, untested

* also make <=

* fix: Also fix the threshold comparison

* wip: keep results around

* wip

* clippy

* wip

* fix: ignore phantom matchers

* wip

* wip

* fix: wrong kernel for assign_u32 function

* fix: use correct buffer in buckets-GPU function

* feat: adapt threshold test for new test structure

* minor fix

* .

* another testcase adapted

* another testcase adapted

* .

* another testcase adapted

* .

* another testcase adapted

* .

* .

* add new test case for bucketing

* .

* .

* fix an error

* add another test

* .

* add buckets test

* .

* .

* wip

* use the open_bucket function in the testcase

* .

* make open_buckets not overwrite input

* function for threshold translation

* .

* match_distances_counter_idx

* minor

* fix an int/size_t error for kernels!

* sort results for consistency across nodes

* clean + clippy

* cleanup and fixes

* clippy

* remove keeping two results around

* up

* add synchronize streams after loading to GPU in gpu_dependant testcases

* clippy fix

* debug for testcases

* .

* .

* fix the len issues in sending/receiving

* fix of fix

* remove or-tree test

* fix?

* buckets as config

* Ps/buckets config improvements (#965)

* buckets as config

* stage config

* revert nccl changes

* custom image build

* custom image deployment

* await streams

* prod config

* clippy

* revert staging deployment

---------

Co-authored-by: Roman Walch <[email protected]>
Co-authored-by: Carlo Mazzaferro <[email protected]>
Co-authored-by: Carlo Mazzaferro <[email protected]>
  • Loading branch information
4 people authored Jan 28, 2025
1 parent c04b7f7 commit 29c1c9d
Show file tree
Hide file tree
Showing 25 changed files with 3,260 additions and 563 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/temp-branch-build-and-push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Branch - Build and push docker image
on:
push:
branches:
- "reduce-size-docker-image"
- "ps/buckets"

concurrency:
group: '${{ github.workflow }} @ ${{ github.event.pull_request.head.label || github.head_ref || github.ref }}'
Expand Down
5 changes: 2 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ env:
- name: SMPC__MAX_BATCH_SIZE
value: "64"

- name: SMPC__MATCH_DISTANCES_BUFFER_SIZE
value: "128"

- name: SMPC__N_BUCKETS
value: "10"

- name: SMPC__SERVICE__METRICS__HOST
valueFrom:
fieldRef:
Expand Down
6 changes: 6 additions & 0 deletions deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ env:
- name: SMPC__MAX_BATCH_SIZE
value: "64"

- name: SMPC__MATCH_DISTANCES_BUFFER_SIZE
value: "128"

- name: SMPC__N_BUCKETS
value: "10"

- name: SMPC__SERVICE__METRICS__HOST
valueFrom:
fieldRef:
Expand Down
6 changes: 6 additions & 0 deletions deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ env:
- name: SMPC__MAX_BATCH_SIZE
value: "64"

- name: SMPC__MATCH_DISTANCES_BUFFER_SIZE
value: "128"

- name: SMPC__N_BUCKETS
value: "10"

- name: SMPC__SERVICE__METRICS__HOST
valueFrom:
fieldRef:
Expand Down
14 changes: 14 additions & 0 deletions iris-mpc-common/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ pub struct Config {

#[serde(default)]
pub fixed_shared_secrets: bool,

#[serde(default = "default_match_distances_buffer_size")]
pub match_distances_buffer_size: usize,

#[serde(default = "default_n_buckets")]
pub n_buckets: usize,
}

fn default_load_chunks_parallelism() -> usize {
Expand Down Expand Up @@ -145,6 +151,14 @@ fn default_db_load_safety_overlap_seconds() -> i64 {
60
}

/// Serde default for `Config::match_distances_buffer_size`: 2^19 (= 524,288) entries.
fn default_match_distances_buffer_size() -> usize {
    2usize.pow(19)
}

/// Serde default for `Config::n_buckets`: number of histogram buckets used for
/// the anonymized match-distance statistics.
fn default_n_buckets() -> usize {
    375
}

impl Config {
pub fn load_config(prefix: &str) -> eyre::Result<Config> {
let settings = config::Config::builder();
Expand Down
5 changes: 4 additions & 1 deletion iris-mpc-gpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ repository.workspace = true

[dependencies]
bincode = "1.3.3"
cudarc = { version = "0.13.3", features = ["cuda-12020", "nccl"] }
cudarc = { git = "https://github.com/worldcoin/cudarc-fork.git", features = [
"cuda-12020",
"nccl",
] }
eyre.workspace = true
tracing.workspace = true
bytemuck.workspace = true
Expand Down
12 changes: 3 additions & 9 deletions iris-mpc-gpu/src/bin/nccl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,9 @@ async fn main() -> eyre::Result<()> {
for i in 0..n_devices {
devs[i].bind_to_thread().unwrap();

comms[i]
.broadcast(slices[i].as_ref(), &mut slices1[i], 0)
.unwrap();
comms[i]
.broadcast(slices[i].as_ref(), &mut slices2[i], 1)
.unwrap();
comms[i]
.broadcast(slices[i].as_ref(), &mut slices3[i], 2)
.unwrap();
comms[i].broadcast(&slices[i], &mut slices1[i], 0).unwrap();
comms[i].broadcast(&slices[i], &mut slices2[i], 1).unwrap();
comms[i].broadcast(&slices[i], &mut slices3[i], 2).unwrap();
}

for dev in devs.iter() {
Expand Down
154 changes: 148 additions & 6 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
use super::ROTATIONS;
use crate::helpers::{
device_manager::DeviceManager, launch_config_from_elements_and_threads,
DEFAULT_LAUNCH_CONFIG_THREADS,
use crate::{
helpers::{
device_manager::DeviceManager, launch_config_from_elements_and_threads,
DEFAULT_LAUNCH_CONFIG_THREADS,
},
threshold_ring::protocol::{ChunkShare, ChunkShareView},
};
use cudarc::{
driver::{CudaFunction, CudaSlice, CudaStream, CudaView, LaunchAsync},
driver::{
result::{launch_kernel, memset_d8_sync},
sys, CudaFunction, CudaSlice, CudaStream, CudaView, DevicePtr, DeviceSlice, LaunchAsync,
},
nvrtc::compile_ptx,
};
use std::{cmp::min, sync::Arc};
use std::{cmp::min, ffi::c_void, sync::Arc};

const PTX_SRC: &str = include_str!("kernel.cu");
const OPEN_RESULTS_FUNCTION: &str = "openResults";
const OPEN_RESULTS_BATCH_FUNCTION: &str = "openResultsBatch";
const MERGE_DB_RESULTS_FUNCTION: &str = "mergeDbResults";
const MERGE_BATCH_RESULTS_FUNCTION: &str = "mergeBatchResults";
const ALL_MATCHES_LEN: usize = 256;

pub struct DistanceComparator {
pub device_manager: Arc<DeviceManager>,
pub open_kernels: Vec<CudaFunction>,
pub open_batch_kernels: Vec<CudaFunction>,
pub merge_db_kernels: Vec<CudaFunction>,
pub merge_batch_kernels: Vec<CudaFunction>,
pub query_length: usize,
Expand All @@ -37,6 +45,7 @@ impl DistanceComparator {
pub fn init(query_length: usize, device_manager: Arc<DeviceManager>) -> Self {
let ptx = compile_ptx(PTX_SRC).unwrap();
let mut open_kernels: Vec<CudaFunction> = Vec::new();
let mut open_batch_kernels: Vec<CudaFunction> = Vec::new();
let mut merge_db_kernels = Vec::new();
let mut merge_batch_kernels = Vec::new();
let mut opened_results = vec![];
Expand All @@ -58,12 +67,15 @@ impl DistanceComparator {
device
.load_ptx(ptx.clone(), "", &[
OPEN_RESULTS_FUNCTION,
OPEN_RESULTS_BATCH_FUNCTION,
MERGE_DB_RESULTS_FUNCTION,
MERGE_BATCH_RESULTS_FUNCTION,
])
.unwrap();

let open_results_function = device.get_func("", OPEN_RESULTS_FUNCTION).unwrap();
let open_results_batch_function =
device.get_func("", OPEN_RESULTS_BATCH_FUNCTION).unwrap();
let merge_db_results_function = device.get_func("", MERGE_DB_RESULTS_FUNCTION).unwrap();
let merge_batch_results_function =
device.get_func("", MERGE_BATCH_RESULTS_FUNCTION).unwrap();
Expand All @@ -90,13 +102,15 @@ impl DistanceComparator {
);

open_kernels.push(open_results_function);
open_batch_kernels.push(open_results_batch_function);
merge_db_kernels.push(merge_db_results_function);
merge_batch_kernels.push(merge_batch_results_function);
}

Self {
device_manager,
open_kernels,
open_batch_kernels,
merge_db_kernels,
merge_batch_kernels,
query_length,
Expand All @@ -115,6 +129,85 @@ impl DistanceComparator {

#[allow(clippy::too_many_arguments)]
/// Opens (reconstructs) the XOR-shared comparison results on every device and
/// marks matches in `matches_bitmap`, while also recording the code/mask
/// dot-product shares of each match into the match-distances buffers for the
/// anonymized bucket statistics.
///
/// Launches the `openResults` CUDA kernel once per device via the raw driver
/// API; the parameter order in `params` must stay in sync with the kernel
/// signature in `kernel.cu`. Devices whose DB chunk was artificially padded
/// (flagged in `ignore_db_results`) are skipped.
pub fn open_results(
    &self,
    results1: &[CudaView<u64>],
    results2: &[CudaView<u64>],
    results3: &[CudaView<u64>],
    matches_bitmap: &[CudaSlice<u64>],
    db_sizes: &[usize],
    real_db_sizes: &[usize],
    offset: usize,
    total_db_sizes: &[usize],
    ignore_db_results: &[bool],
    match_distances_buffers_codes: &[ChunkShare<u16>],
    match_distances_buffers_masks: &[ChunkShare<u16>],
    match_distances_counters: &[CudaSlice<u32>],
    match_distances_indices: &[CudaSlice<u32>],
    code_dots: &[ChunkShareView<u16>],
    mask_dots: &[ChunkShareView<u16>],
    batch_size: usize,
    max_bucket_distances: usize,
    streams: &[CudaStream],
) {
    for i in 0..self.device_manager.device_count() {
        // Those correspond to 0 length dbs, which were just artificially increased to
        // length 1 to avoid division by zero in the kernel
        if ignore_db_results[i] {
            continue;
        }
        // One bit per (query-rotation, db) comparison, packed into u64 words.
        let num_elements = (db_sizes[i] * self.query_length).div_ceil(64);
        let threads_per_block = DEFAULT_LAUNCH_CONFIG_THREADS; // ON CHANGE: sync with kernel
        let cfg = launch_config_from_elements_and_threads(
            num_elements as u32,
            threads_per_block,
            &self.device_manager.devices()[i],
        );
        self.device_manager.device(i).bind_to_thread().unwrap();

        // FIX: bind the derived query length to a named local. Previously this
        // was passed as `usize_param(&(batch_size * ROTATIONS))`, which borrows
        // a temporary that is dropped at the end of the `let params = ...;`
        // statement — the kernel launch then read a dangling stack pointer
        // (undefined behavior). A named local outlives the launch call.
        let query_length = batch_size * ROTATIONS;

        let ptr_param = |ptr: *const sys::CUdeviceptr| ptr as *mut c_void;
        let usize_param = |val: &usize| val as *const usize as *mut _;

        // Order must match the `openResults` kernel signature in kernel.cu.
        let params = &mut [
            // Results arrays
            ptr_param(results1[i].device_ptr()),
            ptr_param(results2[i].device_ptr()),
            ptr_param(results3[i].device_ptr()),
            ptr_param(matches_bitmap[i].device_ptr()),
            usize_param(&db_sizes[i]),
            usize_param(&query_length),
            usize_param(&offset),
            usize_param(&num_elements),
            usize_param(&real_db_sizes[i]),
            usize_param(&total_db_sizes[i]),
            ptr_param(match_distances_buffers_codes[i].a.device_ptr()),
            ptr_param(match_distances_buffers_codes[i].b.device_ptr()),
            ptr_param(match_distances_buffers_masks[i].a.device_ptr()),
            ptr_param(match_distances_buffers_masks[i].b.device_ptr()),
            ptr_param(match_distances_counters[i].device_ptr()),
            ptr_param(match_distances_indices[i].device_ptr()),
            ptr_param(code_dots[i].a.device_ptr()),
            ptr_param(code_dots[i].b.device_ptr()),
            ptr_param(mask_dots[i].a.device_ptr()),
            ptr_param(mask_dots[i].b.device_ptr()),
            usize_param(&max_bucket_distances),
        ];

        unsafe {
            // SAFETY: every pointer in `params` refers either to device memory
            // owned by the caller's slices or to locals/parameters of this
            // function, all of which outlive this launch call.
            launch_kernel(
                self.open_kernels[i].cu_function(),
                cfg.grid_dim,
                cfg.block_dim,
                0,
                streams[i].stream,
                params,
            )
            .unwrap();
        }
    }
}

#[allow(clippy::too_many_arguments)]
pub fn open_batch_results(
&self,
results1: &[CudaView<u64>],
results2: &[CudaView<u64>],
Expand Down Expand Up @@ -143,7 +236,7 @@ impl DistanceComparator {
self.device_manager.device(i).bind_to_thread().unwrap();

unsafe {
self.open_kernels[i]
self.open_batch_kernels[i]
.clone()
.launch_on_stream(
&streams[i],
Expand Down Expand Up @@ -347,4 +440,53 @@ impl DistanceComparator {
})
.collect::<Vec<_>>()
}

/// Allocates, on every device, a pair of `u16` buffers of `max_size` entries
/// that the `openResults` kernel fills with the code/mask dot-product shares
/// of matches (input to the anonymized bucket statistics).
///
/// Both halves are filled with 0xFF bytes (every `u16` becomes `0xFFFF`) as a
/// sentinel for "slot not yet written".
/// NOTE(review): `alloc_zeros` followed by the full-range 0xff memset makes
/// the zeroing redundant — presumably only the sentinel fill matters; confirm
/// before switching to a plain allocation.
pub fn prepare_match_distances_buffer(&self, max_size: usize) -> Vec<ChunkShare<u16>> {
    (0..self.device_manager.device_count())
        .map(|i| {
            // Two share halves (`a`, `b`) per device, wrapped in a ChunkShare.
            let a = self.device_manager.device(i).alloc_zeros(max_size).unwrap();
            let b = self.device_manager.device(i).alloc_zeros(max_size).unwrap();

            self.device_manager.device(i).bind_to_thread().unwrap();
            // Raw driver-level memset; pointers and byte counts come straight
            // from the allocations above, so the writes stay in bounds.
            unsafe {
                memset_d8_sync(*a.device_ptr(), 0xff, a.num_bytes()).unwrap();
                memset_d8_sync(*b.device_ptr(), 0xff, b.num_bytes()).unwrap();
            }

            ChunkShare::new(a, b)
        })
        .collect::<Vec<_>>()
}

/// Allocates one zero-initialized `u32` counter per device; the kernel bumps
/// it atomically to reserve slots in the match-distances buffers.
pub fn prepare_match_distances_counter(&self) -> Vec<CudaSlice<u32>> {
    let n_devices = self.device_manager.device_count();
    let mut counters = Vec::with_capacity(n_devices);
    for device_idx in 0..n_devices {
        let counter = self.device_manager.device(device_idx).alloc_zeros(1).unwrap();
        counters.push(counter);
    }
    counters
}

/// Allocates, on every device, a `u32` index buffer of `max_size` entries.
/// The `openResults` kernel writes the packed comparison index of each
/// recorded match into it, so entries can later be traced back.
///
/// Filled with 0xFF bytes (`u32::MAX`) as a sentinel for "unused slot".
/// NOTE(review): unlike `prepare_match_distances_buffer`, no `bind_to_thread`
/// is called before the memset here — verify this is intentional.
pub fn prepare_match_distances_index(&self, max_size: usize) -> Vec<CudaSlice<u32>> {
    (0..self.device_manager.device_count())
        .map(|i| {
            let a = self.device_manager.device(i).alloc_zeros(max_size).unwrap();
            // Raw driver-level memset over exactly the allocation made above.
            unsafe {
                memset_d8_sync(*a.device_ptr(), 0xff, a.num_bytes()).unwrap();
            }
            a
        })
        .collect::<Vec<_>>()
}

/// Allocates a zero-initialized pair of `u32` arrays of length `n_buckets` on
/// device 0, used to accumulate the shared bucket histogram.
pub fn prepare_match_distances_buckets(&self, n_buckets: usize) -> ChunkShare<u32> {
    let share_a = self.device_manager.device(0).alloc_zeros(n_buckets).unwrap();
    let share_b = self.device_manager.device(0).alloc_zeros(n_buckets).unwrap();
    ChunkShare::new(share_a, share_b)
}
}
38 changes: 37 additions & 1 deletion iris-mpc-gpu/src/dot/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ extern "C" __global__ void matmul_correct_and_reduce(int *c, unsigned short *out
}
}

extern "C" __global__ void openResults(unsigned long long *result1, unsigned long long *result2, unsigned long long *result3, unsigned long long *output, size_t chunkLength, size_t queryLength, size_t offset, size_t numElements, size_t realChunkLen, size_t totalDbLen)
extern "C" __global__ void openResultsBatch(unsigned long long *result1, unsigned long long *result2, unsigned long long *result3, unsigned long long *output, size_t chunkLength, size_t queryLength, size_t offset, size_t numElements, size_t realChunkLen, size_t totalDbLen)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
Expand All @@ -51,6 +51,42 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon
}
}

// Reconstructs the XOR-shared comparison results (result = r1 ^ r2 ^ r3, one
// bit per comparison) and, for every in-bounds match:
//   1. atomically reserves a slot via match_distances_counter and stores the
//      comparison index plus the code/mask dot-product shares there (input to
//      the anonymized bucket statistics), and
//   2. sets the corresponding bit in the per-(query, db-entry) output bitmap
//      (rotations collapsed via ALL_ROTATIONS).
// NOTE(review): the counter keeps incrementing past max_bucket_distances, so
// its final value can exceed the number of entries actually recorded; only
// the first max_bucket_distances matches are stored.
extern "C" __global__ void openResults(unsigned long long *result1, unsigned long long *result2, unsigned long long *result3, unsigned long long *output, size_t chunkLength, size_t queryLength, size_t offset, size_t numElements, size_t realChunkLen, size_t totalDbLen, unsigned short *match_distances_buffer_codes_a, unsigned short *match_distances_buffer_codes_b, unsigned short *match_distances_buffer_masks_a, unsigned short *match_distances_buffer_masks_b, unsigned int *match_distances_counter, unsigned int *match_distances_indices, unsigned short *code_dots_a, unsigned short *code_dots_b, unsigned short *mask_dots_a, unsigned short *mask_dots_b, size_t max_bucket_distances)
{
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < numElements)
    {
        // Open this 64-comparison word by XORing the three shares.
        unsigned long long result = result1[idx] ^ result2[idx] ^ result3[idx];
        for (int i = 0; i < 64; i++)
        {
            // idx*64 + i is the flat comparison index; chunkLength comparisons
            // per query (rotation).
            unsigned int queryIdx = (idx * 64 + i) / chunkLength;
            unsigned int dbIdx = (idx * 64 + i) % chunkLength;
            bool match = (result & (1ULL << i));

            // Check if we are out of bounds for the query or db
            if (queryIdx >= queryLength || dbIdx >= realChunkLen || !match)
            {
                continue;
            }

            // Save the corresponding code and mask dots for later (match distributions)
            unsigned int match_distances_counter_idx = atomicAdd(&match_distances_counter[0], 1);
            if (match_distances_counter_idx < max_bucket_distances)
            {
                match_distances_indices[match_distances_counter_idx] = idx * 64 + i;
                match_distances_buffer_codes_a[match_distances_counter_idx] = code_dots_a[idx * 64 + i];
                match_distances_buffer_codes_b[match_distances_counter_idx] = code_dots_b[idx * 64 + i];
                match_distances_buffer_masks_a[match_distances_counter_idx] = mask_dots_a[idx * 64 + i];
                match_distances_buffer_masks_b[match_distances_counter_idx] = mask_dots_b[idx * 64 + i];
            }

            // Mark which results are matches with a bit in the output
            unsigned int outputIdx = totalDbLen * (queryIdx / ALL_ROTATIONS) + dbIdx + offset;
            atomicOr(&output[outputIdx / 64], (1ULL << (outputIdx % 64)));
        }
    }
}

extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *matchCounterLeft, unsigned int *matchCounterRight, unsigned int *partialResultsLeft, unsigned int *partialResultsRight)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down
Loading

0 comments on commit 29c1c9d

Please sign in to comment.