Skip to content

Commit

Permalink
fix the invalid transformation from < to <= in the threshold check (#685
Browse files Browse the repository at this point in the history
)

* add exact threshold check

* also mutate query db

* make it fail

* fix kernel calculation

* make test random around threshold again

* make the test variation smaller

* test fix

---------

Co-authored-by: Daniel Kales <[email protected]>
  • Loading branch information
philsippl and dkales authored Nov 15, 2024
1 parent a73bfce commit 2b1e968
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 9 deletions.
4 changes: 2 additions & 2 deletions iris-mpc-gpu/src/threshold_ring/cuda/kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,10 @@ extern "C" __global__ void shared_lift_mul_sub(U32 *mask_a, U32 *mask_b,
lift_mul_sub(&mask_b[i], &mask_corr_b[i], &mask_corr_b[i + n], &code_b[i]);
switch (id) {
case 0:
mask_a[i] += 1; // Transforms the <= into <
mask_a[i] -= 1; // Transforms the <= into <
break;
case 1:
mask_b[i] += 1; // Transforms the <= into <
mask_b[i] -= 1; // Transforms the <= into <
break;
default:
break;
Expand Down
60 changes: 53 additions & 7 deletions iris-mpc-gpu/tests/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ mod e2e_test {
use eyre::Result;
use iris_mpc_common::{
galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare},
iris_db::{db::IrisDB, iris::IrisCode},
iris_db::{
db::IrisDB,
iris::{IrisCode, IrisCodeArray},
},
};
use iris_mpc_gpu::{
helpers::device_manager::DeviceManager,
Expand All @@ -23,10 +26,16 @@ mod e2e_test {
const NUM_BATCHES: usize = 10;
const MAX_BATCH_SIZE: usize = 64;
const MAX_DELETIONS_PER_BATCH: usize = 10;
const THRESHOLD_ABSOLUTE: usize = 4800; // 0.375 * 12800

fn generate_db(party_id: usize) -> Result<(Vec<u16>, Vec<u16>)> {
let mut rng = StdRng::seed_from_u64(DB_RNG_SEED);
let db = IrisDB::new_random_par(DB_SIZE, &mut rng);
let mut db = IrisDB::new_random_par(DB_SIZE, &mut rng);

// Set the masks to all 1s for the first 10%
for i in 0..DB_SIZE / 10 {
db.db[i].mask = IrisCodeArray::ONES;
}

let codes_db = db
.db
Expand Down Expand Up @@ -193,7 +202,12 @@ mod e2e_test {

// make a test query and send it to server

let db = IrisDB::new_random_par(DB_SIZE, &mut StdRng::seed_from_u64(DB_RNG_SEED));
let mut db = IrisDB::new_random_par(DB_SIZE, &mut StdRng::seed_from_u64(DB_RNG_SEED));

// Set the masks to all 1s for the first 10%
for i in 0..DB_SIZE / 10 {
db.db[i].mask = IrisCodeArray::ONES;
}

let mut rng = StdRng::seed_from_u64(INTERNAL_RNG_SEED);

Expand All @@ -202,6 +216,7 @@ mod e2e_test {
let mut responses: HashMap<u32, IrisCode> = HashMap::new();
let mut deleted_indices_buffer = vec![];
let mut deleted_indices: HashSet<u32> = HashSet::new();
let mut disallowed_queries = Vec::new();

for _ in 0..NUM_BATCHES {
let mut batch0 = BatchQuery::default();
Expand All @@ -212,11 +227,11 @@ mod e2e_test {
let request_id = Uuid::new_v4();
// Automatic random tests
let options = if responses.is_empty() {
2
} else if deleted_indices_buffer.is_empty() {
3
} else {
} else if deleted_indices_buffer.is_empty() {
4
} else {
5
};
let option = rng.gen_range(0..options);
let template = match option {
Expand All @@ -235,14 +250,45 @@ mod e2e_test {
db.db[db_index].clone()
}
2 => {
println!("Sending iris code on the threshold");
let db_index = loop {
let db_index = rng.gen_range(0..DB_SIZE / 10);
if !disallowed_queries.contains(&db_index) {
break db_index;
}
};
if deleted_indices.contains(&(db_index as u32)) {
continue;
}
let variation = rng.gen_range(-1..=1);
expected_results.insert(
request_id.to_string(),
if variation > 0 {
// we flip more than the threshold so this should not match
// however it would afterwards so we no longer pick it
disallowed_queries.push(db_index);
None
} else {
// we flip less or equal to than the threshold so this should match
Some(db_index as u32)
},
);
let mut code = db.db[db_index].clone();
assert!(code.mask == IrisCodeArray::ONES);
for i in 0..(THRESHOLD_ABSOLUTE as i32 + variation) as usize {
code.code.flip_bit(i);
}
code
}
3 => {
println!("Sending freshly inserted iris code");
let keys = responses.keys().collect::<Vec<_>>();
let idx = rng.gen_range(0..keys.len());
let iris_code = responses.get(keys[idx]).unwrap().clone();
expected_results.insert(request_id.to_string(), Some(*keys[idx]));
iris_code
}
3 => {
4 => {
println!("Sending deleted iris code");
let idx = rng.gen_range(0..deleted_indices_buffer.len());
let deleted_idx = deleted_indices_buffer[idx];
Expand Down

0 comments on commit 2b1e968

Please sign in to comment.