Skip to content

Commit

Permalink
up
Browse files Browse the repository at this point in the history
  • Loading branch information
philsippl committed Sep 10, 2024
1 parent 5fec298 commit 8f582cd
Showing 1 changed file with 97 additions and 84 deletions.
181 changes: 97 additions & 84 deletions iris-mpc-gpu/src/dot/share_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ mod tests {
galois_engine::degree4::{GaloisRingIrisCodeShare, GaloisRingTrimmedMaskCodeShare},
iris_db::db::IrisDB,
};
use itertools::Itertools;
use ndarray::Array2;
use num_traits::FromPrimitive;
use rand::{rngs::StdRng, Rng, SeedableRng};
Expand Down Expand Up @@ -782,6 +783,18 @@ mod tests {
.collect()
}

fn shard_db(db: &[u16], n_shards: usize) -> Vec<u16> {
let mut res: Vec<Vec<u16>> = vec![vec![]; n_shards];
db.iter()
.chunks(WIDTH)
.into_iter()
.enumerate()
.for_each(|(i, chunk)| {
res[i % n_shards].extend(chunk);
});
res.into_iter().flatten().collect::<Vec<_>>()
}

/// Test to verify the matmul operation for random matrices in the field
#[test]
#[cfg(feature = "gpu_dependent")]
Expand Down Expand Up @@ -824,7 +837,7 @@ mod tests {
engine.dot_reduce(&query_sums, &db_slices.code_sums_gr, &db_sizes, 0, &streams);
device_manager.await_streams(&streams);

let a_nda = random_ndarray::<u16>(db.clone(), DB_SIZE, WIDTH);
let a_nda = random_ndarray::<u16>(shard_db(&db, n_devices), DB_SIZE, WIDTH);
let b_nda = random_ndarray::<u16>(query.clone(), QUERY_SIZE, WIDTH);
let c_nda = a_nda.dot(&b_nda.t());

Expand All @@ -848,95 +861,95 @@ mod tests {
.cloned()
.collect();

assert_eq!(selected_elements[0..10], gpu_result[0..10]);
assert_eq!(selected_elements, gpu_result);
}
}

// /// Checks that the result of a matmul of the original data equals the
// /// reconstructed result of individual matmuls on the shamir shares.
// #[test]
// #[cfg(feature = "gpu_dependent")]
// fn check_shared_matmul() {
// let mut rng = StdRng::seed_from_u64(RNG_SEED);
// let device_manager = Arc::new(DeviceManager::init());
// let n_devices = device_manager.device_count();

// let db = IrisDB::new_random_par(DB_SIZE, &mut rng);

// let mut gpu_result = [
// vec![0u16; DB_SIZE * QUERY_SIZE / n_devices],
// vec![0u16; DB_SIZE * QUERY_SIZE / n_devices],
// vec![0u16; DB_SIZE * QUERY_SIZE / n_devices],
// ];

// for i in 0..3 {
// let device_manager = Arc::clone(&device_manager);

// let codes_db = db
// .db
// .iter()
// .flat_map(|iris| {
// GaloisRingIrisCodeShare::encode_mask_code(
// &iris.mask,
// &mut StdRng::seed_from_u64(RNG_SEED),
// )[i]
// .coefs
// })
// .collect::<Vec<_>>();
/// Checks that the result of a matmul of the original data equals the
/// reconstructed result of individual matmuls on the shamir shares.
#[test]
#[cfg(feature = "gpu_dependent")]
fn check_shared_matmul() {
let mut rng = StdRng::seed_from_u64(RNG_SEED);
let device_manager = Arc::new(DeviceManager::init());
let n_devices = device_manager.device_count();

// let querys = db.db[0..QUERY_SIZE]
// .iter()
// .flat_map(|iris| {
// let mut shares =
// GaloisRingIrisCodeShare::encode_mask_code(
// &iris.mask, &mut StdRng::seed_from_u64(RNG_SEED),
// );
// shares[i].preprocess_iris_code_query_share();
// shares[i].coefs
// })
// .collect::<Vec<_>>();
let db = IrisDB::new_random_par(DB_SIZE, &mut rng);

let mut gpu_result = [
vec![0u16; DB_SIZE * QUERY_SIZE / n_devices],
vec![0u16; DB_SIZE * QUERY_SIZE / n_devices],
vec![0u16; DB_SIZE * QUERY_SIZE / n_devices],
];

for i in 0..3 {
let device_manager = Arc::clone(&device_manager);

let codes_db = db
.db
.iter()
.flat_map(|iris| {
GaloisRingIrisCodeShare::encode_mask_code(
&iris.mask,
&mut StdRng::seed_from_u64(RNG_SEED),
)[i]
.coefs
})
.collect::<Vec<_>>();

let querys = db.db[0..QUERY_SIZE]
.iter()
.flat_map(|iris| {
let mut shares = GaloisRingIrisCodeShare::encode_mask_code(
&iris.mask,
&mut StdRng::seed_from_u64(RNG_SEED),
);
shares[i].preprocess_iris_code_query_share();
shares[i].coefs
})
.collect::<Vec<_>>();

// let mut engine = ShareDB::init(
// 0,
// device_manager.clone(),
// DB_SIZE,
// QUERY_SIZE,
// IRIS_CODE_LENGTH,
// ([0u32; 8], [0u32; 8]),
// vec![],
// );
// let preprocessed_query = preprocess_query(&querys);
// let streams = device_manager.fork_streams();
// let blass = device_manager.create_cublas(&streams);
// let preprocessed_query = device_manager
// .htod_transfer_query(&preprocessed_query, &streams,
// QUERY_SIZE, IRIS_CODE_LENGTH) .unwrap();
// let query_sums = engine.query_sums(&preprocessed_query, &streams,
// &blass); let mut db_slices = engine.alloc_db(DB_SIZE);
// engine.load_full_db(&mut db_slices, &codes_db);
// let db_sizes = vec![DB_SIZE; n_devices];
let mut engine = ShareDB::init(
0,
device_manager.clone(),
DB_SIZE,
QUERY_SIZE,
IRIS_CODE_LENGTH,
([0u32; 8], [0u32; 8]),
vec![],
);
let preprocessed_query = preprocess_query(&querys);
let streams = device_manager.fork_streams();
let blass = device_manager.create_cublas(&streams);
let preprocessed_query = device_manager
.htod_transfer_query(&preprocessed_query, &streams, QUERY_SIZE, IRIS_CODE_LENGTH)
.unwrap();
let query_sums = engine.query_sums(&preprocessed_query, &streams, &blass);
let mut db_slices = engine.alloc_db(DB_SIZE);
engine.load_full_db(&mut db_slices, &codes_db);
let db_sizes = vec![DB_SIZE / n_devices; n_devices];

engine.dot(
&preprocessed_query,
&db_slices.code_gr,
&db_sizes,
0,
&streams,
&blass,
);
engine.dot_reduce(&query_sums, &db_slices.code_sums_gr, &db_sizes, 0, &streams);
device_manager.await_streams(&streams);
engine.fetch_results(&mut gpu_result[i], &db_sizes, 0);
}

// engine.dot(
// &preprocessed_query,
// &db_slices.code_gr,
// &db_sizes,
// 0,
// &streams,
// &blass,
// );
// engine.dot_reduce(&query_sums, &db_slices.code_sums_gr,
// &db_sizes, 0, &streams); device_manager.await_streams(&
// streams); engine.fetch_results(&mut gpu_result[i], &db_sizes,
// 0); }

// for i in 0..DB_SIZE * QUERY_SIZE / n_devices {
// assert_eq!(
// (gpu_result[0][i] + gpu_result[1][i] + gpu_result[2][i]),
// (db.db[i / (DB_SIZE / n_devices)].mask & db.db[i % (DB_SIZE /
// n_devices)].mask) .count_ones() as u16
// );
// }
// }
for i in 0..DB_SIZE * QUERY_SIZE / n_devices {
assert_eq!(
(gpu_result[0][i] + gpu_result[1][i] + gpu_result[2][i]),
(db.db[i / (DB_SIZE / n_devices)].mask & db.db[i % (DB_SIZE / n_devices)].mask)
.count_ones() as u16
);
}
}

// /// Calculates the distances between a query and a shamir secret shared
// db /// and checks the result against reference plain implementation.
Expand Down

0 comments on commit 8f582cd

Please sign in to comment.