Skip to content

Commit

Permalink
expose partial matches (#534)
Browse files Browse the repository at this point in the history
* expose partial matches

* add temp branch build and push

* return partial matches regardless of final matches

* clean up

---------

Co-authored-by: Ertugrul Aypek <[email protected]>
  • Loading branch information
philsippl and eaypek-tfh authored Oct 28, 2024
1 parent d117a4d commit f49ac8f
Show file tree
Hide file tree
Showing 15 changed files with 191 additions and 32 deletions.
3 changes: 3 additions & 0 deletions deploy/prod/smpcv2-0-prod/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ env:
- name: SMPC__SERVICE__METRICS__PREFIX
value: "smpcv2-0"

- name: SMPC__RETURN_PARTIAL_RESULTS
value: "true"

initContainer:
enabled: true
image: "amazon/aws-cli:2.17.62"
Expand Down
3 changes: 3 additions & 0 deletions deploy/prod/smpcv2-1-prod/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ env:
- name: SMPC__SERVICE__METRICS__PREFIX
value: "smpcv2-1"

- name: SMPC__RETURN_PARTIAL_RESULTS
value: "true"

initContainer:
enabled: true
image: "amazon/aws-cli:2.17.62"
Expand Down
3 changes: 3 additions & 0 deletions deploy/prod/smpcv2-2-prod/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ env:
- name: SMPC__SERVICE__METRICS__PREFIX
value: "smpcv2-2"

- name: SMPC__RETURN_PARTIAL_RESULTS
value: "true"

initContainer:
enabled: true
image: "amazon/aws-cli:2.17.62"
Expand Down
3 changes: 3 additions & 0 deletions deploy/stage/smpcv2-0-stage/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ env:
- name: SMPC__SERVICE__METRICS__PREFIX
value: "smpcv2-0"

- name: SMPC__RETURN_PARTIAL_RESULTS
value: "true"

initContainer:
enabled: true
image: "amazon/aws-cli:2.17.62"
Expand Down
3 changes: 3 additions & 0 deletions deploy/stage/smpcv2-1-stage/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ env:
- name: SMPC__SERVICE__METRICS__PREFIX
value: "smpcv2-1"

- name: SMPC__RETURN_PARTIAL_RESULTS
value: "true"

initContainer:
enabled: true
image: "amazon/aws-cli:2.17.62"
Expand Down
3 changes: 3 additions & 0 deletions deploy/stage/smpcv2-2-stage/values-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ env:
- name: SMPC__SERVICE__METRICS__PREFIX
value: "smpcv2-2"

- name: SMPC__RETURN_PARTIAL_RESULTS
value: "true"

initContainer:
enabled: true
image: "amazon/aws-cli:2.17.62"
Expand Down
3 changes: 3 additions & 0 deletions iris-mpc-common/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ pub struct Config {
#[serde(default)]
pub fake_db_size: usize,

#[serde(default)]
pub return_partial_results: bool,

#[serde(default)]
pub disable_persistence: bool,
}
Expand Down
16 changes: 11 additions & 5 deletions iris-mpc-common/src/helpers/smpc_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,13 @@ impl UniquenessRequest {

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct UniquenessResult {
pub node_id: usize,
pub serial_id: Option<u32>,
pub is_match: bool,
pub signup_id: String,
pub matched_serial_ids: Option<Vec<u32>>,
pub node_id: usize,
pub serial_id: Option<u32>,
pub is_match: bool,
pub signup_id: String,
pub matched_serial_ids: Option<Vec<u32>>,
pub matched_serial_ids_left: Option<Vec<u32>>,
pub matched_serial_ids_right: Option<Vec<u32>>,
}

impl UniquenessResult {
Expand All @@ -316,13 +318,17 @@ impl UniquenessResult {
is_match: bool,
signup_id: String,
matched_serial_ids: Option<Vec<u32>>,
matched_serial_ids_left: Option<Vec<u32>>,
matched_serial_ids_right: Option<Vec<u32>>,
) -> Self {
Self {
node_id,
serial_id,
is_match,
signup_id,
matched_serial_ids,
matched_serial_ids_left,
matched_serial_ids_right,
}
}
}
Expand Down
46 changes: 39 additions & 7 deletions iris-mpc-gpu/src/dot/distance_comparator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,26 @@ pub struct DistanceComparator {
pub final_results_init_host: Vec<u32>,
pub match_counters: Vec<CudaSlice<u32>>,
pub all_matches: Vec<CudaSlice<u32>>,
pub match_counters_left: Vec<CudaSlice<u32>>,
pub match_counters_right: Vec<CudaSlice<u32>>,
pub partial_results_left: Vec<CudaSlice<u32>>,
pub partial_results_right: Vec<CudaSlice<u32>>,
}

impl DistanceComparator {
pub fn init(query_length: usize, device_manager: Arc<DeviceManager>) -> Self {
let ptx = compile_ptx(PTX_SRC).unwrap();
let mut open_kernels = Vec::new();
let mut open_kernels: Vec<CudaFunction> = Vec::new();
let mut merge_db_kernels = Vec::new();
let mut merge_batch_kernels = Vec::new();
let mut opened_results = vec![];
let mut final_results = vec![];
let mut match_counters: Vec<CudaSlice<u32>> = vec![];
let mut all_matches: Vec<CudaSlice<u32>> = vec![];
let mut match_counters = vec![];
let mut match_counters_left = vec![];
let mut match_counters_right = vec![];
let mut all_matches = vec![];
let mut partial_results_left = vec![];
let mut partial_results_right = vec![];

let devices_count = device_manager.device_count();

Expand All @@ -63,11 +71,23 @@ impl DistanceComparator {
opened_results.push(device.htod_copy(results_init_host.clone()).unwrap());
final_results.push(device.htod_copy(final_results_init_host.clone()).unwrap());
match_counters.push(device.alloc_zeros(query_length / ROTATIONS).unwrap());
match_counters_left.push(device.alloc_zeros(query_length / ROTATIONS).unwrap());
match_counters_right.push(device.alloc_zeros(query_length / ROTATIONS).unwrap());
all_matches.push(
device
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);
partial_results_left.push(
device
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);
partial_results_right.push(
device
.alloc_zeros(ALL_MATCHES_LEN * query_length / ROTATIONS)
.unwrap(),
);

open_kernels.push(open_results_function);
merge_db_kernels.push(merge_db_results_function);
Expand All @@ -85,7 +105,11 @@ impl DistanceComparator {
results_init_host,
final_results_init_host,
match_counters,
match_counters_left,
match_counters_right,
all_matches,
partial_results_left,
partial_results_right,
}
}

Expand Down Expand Up @@ -213,6 +237,10 @@ impl DistanceComparator {
num_elements as u64,
&self.match_counters[i],
&self.all_matches[i],
&self.match_counters_left[i],
&self.match_counters_right[i],
&self.partial_results_left[i],
&self.partial_results_right[i],
),
)
.unwrap();
Expand All @@ -233,26 +261,30 @@ impl DistanceComparator {
results
}

pub fn fetch_match_counters(&self) -> Vec<Vec<u32>> {
pub fn fetch_match_counters(&self, counters: &[CudaSlice<u32>]) -> Vec<Vec<u32>> {
let mut results = vec![];
for i in 0..self.device_manager.device_count() {
results.push(
self.device_manager
.device(i)
.dtoh_sync_copy(&self.match_counters[i])
.dtoh_sync_copy(&counters[i])
.unwrap(),
);
}
results
}

pub fn fetch_all_match_ids(&self, match_counters: Vec<Vec<u32>>) -> Vec<Vec<u32>> {
pub fn fetch_all_match_ids(
&self,
match_counters: Vec<Vec<u32>>,
matches: &[CudaSlice<u32>],
) -> Vec<Vec<u32>> {
let mut results = vec![];
for i in 0..self.device_manager.device_count() {
results.push(
self.device_manager
.device(i)
.dtoh_sync_copy(&self.all_matches[i])
.dtoh_sync_copy(&matches[i])
.unwrap(),
);
}
Expand Down
18 changes: 16 additions & 2 deletions iris-mpc-gpu/src/dot/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ extern "C" __global__ void openResults(unsigned long long *result1, unsigned lon
}
}

extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches)
extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft, unsigned long long *matchResultsRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *matchCounter, unsigned int *allMatches, unsigned int *matchCounterLeft, unsigned int *matchCounterRight, unsigned int *partialResultsLeft, unsigned int *partialResultsRight)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
Expand All @@ -67,6 +67,20 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft,
if (queryIdx >= queryLength || dbIdx >= dbLength)
continue;

// Check for partial results (only used for debugging)
if (matchLeft)
{
unsigned int queryMatchCounter = atomicAdd(&matchCounterLeft[queryIdx], 1);
if (queryMatchCounter < MAX_MATCHES_LEN)
partialResultsLeft[MAX_MATCHES_LEN * queryIdx + queryMatchCounter] = dbIdx;
}
if (matchRight)
{
unsigned int queryMatchCounter = atomicAdd(&matchCounterRight[queryIdx], 1);
if (queryMatchCounter < MAX_MATCHES_LEN)
partialResultsRight[MAX_MATCHES_LEN * queryIdx + queryMatchCounter] = dbIdx;
}

// Current *AND* policy: only match, if both eyes match
if (matchLeft && matchRight)
{
Expand All @@ -79,7 +93,7 @@ extern "C" __global__ void mergeDbResults(unsigned long long *matchResultsLeft,
}
}

extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *__matchCounter, unsigned int *__allMatches)
extern "C" __global__ void mergeBatchResults(unsigned long long *matchResultsSelfLeft, unsigned long long *matchResultsSelfRight, unsigned int *finalResults, size_t queryLength, size_t dbLength, size_t numElements, unsigned int *__matchCounter, unsigned int *__allMatches, unsigned int *__matchCounterLeft, unsigned int *__matchCounterRight, unsigned int *__partialResultsLeft, unsigned int *__partialResultsRight)
{
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < numElements)
Expand Down
Loading

0 comments on commit f49ac8f

Please sign in to comment.