Skip to content

Commit

Permalink
Allow caller to change batch size (#224)
Browse files Browse the repository at this point in the history
* wip

* up

* wip

* dbg

* wip

* query fixed size

* up

* dbg

* clean up

* fix

* fix tests and clippy

* move DB_SIZE and DB_BUFFER out of actor

* fix clippy

* readd msg group id

* check if empty

* update batch size instantly

* cleaner

* batch_size in [1,MAX_BATCH_SIZE]

* clippy

* pr feedback
  • Loading branch information
philsippl authored Aug 12, 2024
1 parent 5d4d19d commit a09e816
Show file tree
Hide file tree
Showing 10 changed files with 152 additions and 72 deletions.
2 changes: 2 additions & 0 deletions iris-mpc-common/src/helpers/sqs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub struct SQSMessage {

#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SMPCRequest {
// TODO: make this a message attribute; the SQS message will be refactored soon anyway.
pub batch_size: Option<usize>,
pub request_id: String,
pub iris_code: String,
pub mask_code: String,
Expand Down
5 changes: 2 additions & 3 deletions iris-mpc-gpu/benches/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,15 @@ fn bench_memcpy(c: &mut Criterion) {
let preprocessed_query = preprocess_query(&query);
let streams = device_manager.fork_streams();
let blass = device_manager.create_cublas(&streams);
let db_slices = engine.load_db(&db, DB_SIZE, DB_SIZE, false);
let db_sizes = vec![DB_SIZE; 8];
let (db_slices, db_sizes) = engine.load_db(&db, DB_SIZE, DB_SIZE, false);

group.throughput(Throughput::Elements((DB_SIZE * QUERY_SIZE / 31) as u64));
group.sample_size(10);

group.bench_function(format!("matmul {} x {}", DB_SIZE, QUERY_SIZE), |b| {
b.iter(|| {
let preprocessed_query = device_manager
.htod_transfer_query(&preprocessed_query, &streams)
.htod_transfer_query(&preprocessed_query, &streams, QUERY_SIZE)
.unwrap();
let query_sums = engine.query_sums(&preprocessed_query, &streams, &blass);
engine.dot(
Expand Down
60 changes: 37 additions & 23 deletions iris-mpc-gpu/src/dot/share_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ impl ShareDB {
db_length: usize, // TODO: should handle different sizes for each device
max_db_length: usize,
alternating_chunks: bool,
) -> SlicedProcessedDatabase {
) -> (SlicedProcessedDatabase, Vec<usize>) {
let mut a1_host = db_entries
.par_iter()
.map(|&x: &u16| (x >> 8) as i8)
Expand Down Expand Up @@ -387,6 +387,18 @@ impl ShareDB {
alternating_chunks,
);

assert!(
db0.iter()
.zip(db1.iter())
.all(|(chunk0, chunk1)| chunk0.len() == chunk1.len()),
"db0 and db1 chunks must have the same length"
);

let db_lens = db0
.iter()
.map(|chunk| chunk.len() / IRIS_CODE_LENGTH)
.collect::<Vec<_>>();

let db1 = db1
.iter()
.map(|chunk| unsafe {
Expand Down Expand Up @@ -418,16 +430,19 @@ impl ShareDB {
dev.synchronize().unwrap();
}

SlicedProcessedDatabase {
code_gr: CudaVec2DSlicerRawPointer {
limb_0: db0,
limb_1: db1,
},
code_sums_gr: CudaVec2DSlicerU32 {
limb_0: db0_sums,
limb_1: db1_sums,
(
SlicedProcessedDatabase {
code_gr: CudaVec2DSlicerRawPointer {
limb_0: db0,
limb_1: db1,
},
code_sums_gr: CudaVec2DSlicerU32 {
limb_0: db0_sums,
limb_1: db1_sums,
},
},
}
db_lens,
)
}

pub fn query_sums(
Expand Down Expand Up @@ -841,7 +856,6 @@ mod tests {
let n_devices = device_manager.device_count();

let mut gpu_result = vec![0u16; DB_SIZE / n_devices * QUERY_SIZE];
let db_sizes = vec![DB_SIZE / n_devices; n_devices];

let mut engine = ShareDB::init(
0,
Expand All @@ -855,10 +869,10 @@ mod tests {
let streams = device_manager.fork_streams();
let blass = device_manager.create_cublas(&streams);
let preprocessed_query = device_manager
.htod_transfer_query(&preprocessed_query, &streams)
.htod_transfer_query(&preprocessed_query, &streams, QUERY_SIZE)
.unwrap();
let query_sums = engine.query_sums(&preprocessed_query, &streams, &blass);
let db_slices = engine.load_db(&db, DB_SIZE, DB_SIZE, false);
let (db_slices, db_sizes) = engine.load_db(&db, DB_SIZE, DB_SIZE, false);

engine.dot(
&preprocessed_query,
Expand Down Expand Up @@ -915,8 +929,6 @@ mod tests {
vec![0u16; DB_SIZE * QUERY_SIZE / n_devices],
];

let db_sizes = vec![DB_SIZE / n_devices; n_devices];

for i in 0..3 {
let device_manager = Arc::clone(&device_manager);

Expand Down Expand Up @@ -955,10 +967,10 @@ mod tests {
let streams = device_manager.fork_streams();
let blass = device_manager.create_cublas(&streams);
let preprocessed_query = device_manager
.htod_transfer_query(&preprocessed_query, &streams)
.htod_transfer_query(&preprocessed_query, &streams, QUERY_SIZE)
.unwrap();
let query_sums = engine.query_sums(&preprocessed_query, &streams, &blass);
let db_slices = engine.load_db(&codes_db, DB_SIZE, DB_SIZE, false);
let (db_slices, db_sizes) = engine.load_db(&codes_db, DB_SIZE, DB_SIZE, false);
engine.dot(
&preprocessed_query,
&db_slices.code_gr,
Expand Down Expand Up @@ -991,8 +1003,6 @@ mod tests {

let db = IrisDB::new_random_par(DB_SIZE, &mut rng);

let db_sizes = vec![DB_SIZE / n_devices; n_devices];

let mut results_codes = [
vec![0u16; DB_SIZE / n_devices * QUERY_SIZE],
vec![0u16; DB_SIZE / n_devices * QUERY_SIZE],
Expand Down Expand Up @@ -1083,15 +1093,19 @@ mod tests {
let streams = device_manager.fork_streams();
let blass = device_manager.create_cublas(&streams);
let code_query = device_manager
.htod_transfer_query(&code_query, &streams)
.htod_transfer_query(&code_query, &streams, QUERY_SIZE)
.unwrap();
let mask_query = device_manager
.htod_transfer_query(&mask_query, &streams)
.htod_transfer_query(&mask_query, &streams, QUERY_SIZE)
.unwrap();
let code_query_sums = codes_engine.query_sums(&code_query, &streams, &blass);
let mask_query_sums = masks_engine.query_sums(&mask_query, &streams, &blass);
let code_db_slices = codes_engine.load_db(&codes_db, DB_SIZE, DB_SIZE, false);
let mask_db_slices = codes_engine.load_db(&masks_db, DB_SIZE, DB_SIZE, false);
let (code_db_slices, db_sizes) =
codes_engine.load_db(&codes_db, DB_SIZE, DB_SIZE, false);
let (mask_db_slices, mask_db_sizes) =
codes_engine.load_db(&masks_db, DB_SIZE, DB_SIZE, false);

assert_eq!(db_sizes, mask_db_sizes);

codes_engine.dot(
&code_query,
Expand Down
13 changes: 7 additions & 6 deletions iris-mpc-gpu/src/helpers/device_manager.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::query_processor::{CudaVec2DSlicerU8, StreamAwareCudaSlice};
use crate::dot::{IRIS_CODE_LENGTH, ROTATIONS};
use cudarc::{
cublas::CudaBlas,
driver::{
Expand Down Expand Up @@ -111,33 +112,33 @@ impl DeviceManager {
&self,
preprocessed_query: &[Vec<u8>],
streams: &[CudaStream],
query_batch_size: usize,
) -> eyre::Result<CudaVec2DSlicerU8> {
let mut slices0 = vec![];
let mut slices1 = vec![];
let query_size = query_batch_size * ROTATIONS * IRIS_CODE_LENGTH;
for idx in 0..self.device_count() {
let device = self.device(idx);
device.bind_to_thread().unwrap();

let query0 =
unsafe { malloc_async(streams[idx].stream, preprocessed_query[0].len()).unwrap() };
let query0 = unsafe { malloc_async(streams[idx].stream, query_size).unwrap() };

let slice0 = StreamAwareCudaSlice::<u8>::upgrade_ptr_stream(
query0,
streams[idx].stream,
preprocessed_query[0].len(),
query_size,
);

unsafe {
memcpy_htod_async(query0, &preprocessed_query[0], streams[idx].stream).unwrap();
}

let query1 =
unsafe { malloc_async(streams[idx].stream, preprocessed_query[1].len()).unwrap() };
let query1 = unsafe { malloc_async(streams[idx].stream, query_size).unwrap() };

let slice1 = StreamAwareCudaSlice::<u8>::upgrade_ptr_stream(
query1,
streams[idx].stream,
preprocessed_query[1].len(),
query_size,
);

unsafe {
Expand Down
17 changes: 13 additions & 4 deletions iris-mpc-gpu/src/helpers/query_processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,21 @@ impl CompactQuery {
&self,
device: &DeviceManager,
streams: &[CudaStream],
query_size: usize,
) -> eyre::Result<DeviceCompactQuery> {
Ok(DeviceCompactQuery {
code_query: device.htod_transfer_query(&self.code_query, streams)?,
mask_query: device.htod_transfer_query(&self.mask_query, streams)?,
code_query_insert: device.htod_transfer_query(&self.code_query_insert, streams)?,
mask_query_insert: device.htod_transfer_query(&self.mask_query_insert, streams)?,
code_query: device.htod_transfer_query(&self.code_query, streams, query_size)?,
mask_query: device.htod_transfer_query(&self.mask_query, streams, query_size)?,
code_query_insert: device.htod_transfer_query(
&self.code_query_insert,
streams,
query_size,
)?,
mask_query_insert: device.htod_transfer_query(
&self.mask_query_insert,
streams,
query_size,
)?,
})
}
}
Expand Down
Loading

0 comments on commit a09e816

Please sign in to comment.