diff --git a/iris-mpc-gpu/src/dot/share_db.rs b/iris-mpc-gpu/src/dot/share_db.rs index bd0faf29c..c7e8eeb0f 100644 --- a/iris-mpc-gpu/src/dot/share_db.rs +++ b/iris-mpc-gpu/src/dot/share_db.rs @@ -20,8 +20,8 @@ use cudarc::{ CudaBlas, }, driver::{ - result::{self, malloc_async, malloc_managed}, - sys::{CUdeviceptr, CUmemAttach_flags}, + result::{self, malloc_async}, + sys::CUdeviceptr, CudaFunction, CudaSlice, CudaStream, CudaView, DevicePtr, DeviceSlice, LaunchAsync, }, nccl, @@ -114,6 +114,12 @@ pub struct SlicedProcessedDatabase { pub code_sums_gr: CudaVec2DSlicerU32, } +#[derive(Clone)] +pub struct DBChunkBuffers { + pub limb_0: Vec>, + pub limb_1: Vec>, +} + pub struct ShareDB { peer_id: usize, is_remote: bool, @@ -237,22 +243,17 @@ impl ShareDB { .devices() .iter() .map(|device| unsafe { + let mut host_mem0: *mut c_void = std::ptr::null_mut(); + let mut host_mem1: *mut c_void = std::ptr::null_mut(); + let _ = cudarc::driver::sys::lib() + .cuMemAllocHost_v2(&mut host_mem0, max_size * self.code_length); + let _ = cudarc::driver::sys::lib() + .cuMemAllocHost_v2(&mut host_mem1, max_size * self.code_length); ( StreamAwareCudaSlice::from(device.alloc(max_size).unwrap()), ( StreamAwareCudaSlice::from(device.alloc(max_size).unwrap()), - ( - malloc_managed( - max_size * self.code_length, - CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL, - ) - .unwrap(), - malloc_managed( - max_size * self.code_length, - CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL, - ) - .unwrap(), - ), + (host_mem0 as u64, host_mem1 as u64), ), ) }) @@ -450,6 +451,61 @@ impl ShareDB { } } + pub fn alloc_db_chunk_buffer(&self, max_chunk_size: usize) -> DBChunkBuffers { + let mut limb_0 = vec![]; + let mut limb_1 = vec![]; + for device in self.device_manager.devices() { + unsafe { + limb_0.push(device.alloc(max_chunk_size * self.code_length).unwrap()); + limb_1.push(device.alloc(max_chunk_size * self.code_length).unwrap()); + } + } + DBChunkBuffers { limb_0, limb_1 } + } + + pub fn prefetch_db_chunk( + &self, + 
db: &SlicedProcessedDatabase, + buffers: &DBChunkBuffers, + chunk_sizes: &[usize], + offset: &[usize], + db_sizes: &[usize], + streams: &[CudaStream], + ) { + for idx in 0..self.device_manager.device_count() { + let device = self.device_manager.device(idx); + device.bind_to_thread().unwrap(); + + if offset[idx] >= db_sizes[idx] || offset[idx] + chunk_sizes[idx] > db_sizes[idx] { + continue; + } + + unsafe { + cudarc::driver::sys::lib() + .cuMemcpyHtoDAsync_v2( + *buffers.limb_0[idx].device_ptr(), + (db.code_gr.limb_0[idx] as usize + offset[idx] * self.code_length) + as *mut _, + chunk_sizes[idx] * self.code_length, + streams[idx].stream, + ) + .result() + .unwrap(); + + cudarc::driver::sys::lib() + .cuMemcpyHtoDAsync_v2( + *buffers.limb_1[idx].device_ptr(), + (db.code_gr.limb_1[idx] as usize + offset[idx] * self.code_length) + as *mut _, + chunk_sizes[idx] * self.code_length, + streams[idx].stream, + ) + .result() + .unwrap(); + } + } + } + pub fn dot( &mut self, queries: &CudaVec2DSlicer, diff --git a/iris-mpc-gpu/src/helpers/mod.rs b/iris-mpc-gpu/src/helpers/mod.rs index 149fc10cf..e4bd114ee 100644 --- a/iris-mpc-gpu/src/helpers/mod.rs +++ b/iris-mpc-gpu/src/helpers/mod.rs @@ -1,7 +1,7 @@ use crate::threshold_ring::protocol::ChunkShare; use cudarc::driver::{ result::{self, memcpy_dtoh_async, memcpy_htod_async, stream}, - sys::{CUdeviceptr, CUstream, CUstream_st}, + sys::{lib, CUdeviceptr, CUstream, CUstream_st}, CudaDevice, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, DriverError, LaunchConfig, }; @@ -104,6 +104,32 @@ pub unsafe fn dtod_at_offset( } } +/// Copy a slice from device to host with respective offsets. 
+/// # Safety +/// +/// The caller must ensure that the `dst` and `src` pointers are valid +/// with the respective offsets +pub unsafe fn dtoh_at_offset( + dst: u64, + dst_offset: usize, + src: CUdeviceptr, + src_offset: usize, + len: usize, + stream_ptr: CUstream, +) { + unsafe { + lib() + .cuMemcpyDtoHAsync_v2( + (dst + dst_offset as u64) as *mut _, + (src + src_offset as u64) as CUdeviceptr, + len, + stream_ptr, + ) + .result() + .unwrap(); + } +} + pub fn dtoh_on_stream_sync>( input: &U, device: &Arc, diff --git a/iris-mpc-gpu/src/helpers/query_processor.rs b/iris-mpc-gpu/src/helpers/query_processor.rs index a02a3b4bd..de27e4e5d 100644 --- a/iris-mpc-gpu/src/helpers/query_processor.rs +++ b/iris-mpc-gpu/src/helpers/query_processor.rs @@ -1,6 +1,6 @@ use crate::{ dot::{ - share_db::{ShareDB, SlicedProcessedDatabase}, + share_db::{DBChunkBuffers, ShareDB, SlicedProcessedDatabase}, IRIS_CODE_LENGTH, MASK_CODE_LENGTH, }, helpers::device_manager::DeviceManager, @@ -82,6 +82,15 @@ impl From<&CudaVec2DSlicer> for CudaVec2DSlicerRawPointer { } } +impl From<&DBChunkBuffers> for CudaVec2DSlicerRawPointer { + fn from(buffers: &DBChunkBuffers) -> Self { + CudaVec2DSlicerRawPointer { + limb_0: buffers.limb_0.iter().map(|s| *s.device_ptr()).collect(), + limb_1: buffers.limb_1.iter().map(|s| *s.device_ptr()).collect(), + } + } +} + pub struct CudaVec2DSlicer { pub limb_0: Vec>, pub limb_1: Vec>, @@ -193,8 +202,8 @@ impl DeviceCompactQuery { &self, code_engine: &mut ShareDB, mask_engine: &mut ShareDB, - sliced_code_db: &SlicedProcessedDatabase, - sliced_mask_db: &SlicedProcessedDatabase, + sliced_code_db: &CudaVec2DSlicerRawPointer, + sliced_mask_db: &CudaVec2DSlicerRawPointer, database_sizes: &[usize], offset: usize, streams: &[CudaStream], @@ -202,7 +211,7 @@ impl DeviceCompactQuery { ) { code_engine.dot( &self.code_query, - &sliced_code_db.code_gr, + sliced_code_db, database_sizes, offset, streams, @@ -210,7 +219,7 @@ impl DeviceCompactQuery { ); mask_engine.dot( 
&self.mask_query, - &sliced_mask_db.code_gr, + sliced_mask_db, database_sizes, offset, streams, diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 4d2c86524..95c5aa289 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -2,14 +2,16 @@ use super::{BatchQuery, Eye, ServerJob, ServerJobResult}; use crate::{ dot::{ distance_comparator::DistanceComparator, - share_db::{preprocess_query, ShareDB, SlicedProcessedDatabase}, + share_db::{preprocess_query, DBChunkBuffers, ShareDB, SlicedProcessedDatabase}, IRIS_CODE_LENGTH, MASK_CODE_LENGTH, ROTATIONS, }, helpers::{ self, comm::NcclComm, device_manager::DeviceManager, - query_processor::{CompactQuery, DeviceCompactQuery, DeviceCompactSums}, + query_processor::{ + CompactQuery, CudaVec2DSlicerRawPointer, DeviceCompactQuery, DeviceCompactSums, + }, }, threshold_ring::protocol::{ChunkShare, Circuits}, }; @@ -103,6 +105,8 @@ pub struct ServerActor { max_db_size: usize, return_partial_results: bool, disable_persistence: bool, + code_chunk_buffers: Vec<DBChunkBuffers>, + mask_chunk_buffers: Vec<DBChunkBuffers>, } const NON_MATCH_ID: u32 = u32::MAX; @@ -317,9 +321,11 @@ impl ServerActor { let batch_match_list_right = distance_comparator.prepare_db_match_list(n_queries); let query_db_size = vec![n_queries; device_manager.device_count()]; - let current_db_sizes = vec![0; device_manager.device_count()]; + let code_chunk_buffers = vec![codes_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2]; + let mask_chunk_buffers = vec![masks_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2]; + for dev in device_manager.devices() { dev.synchronize().unwrap(); } @@ -355,6 +361,8 @@ impl ServerActor { max_db_size, return_partial_results, disable_persistence, + code_chunk_buffers, + mask_chunk_buffers, }) } @@ -1111,6 +1119,30 @@ impl ServerActor { let mut current_phase2_event = self.device_manager.create_events(); let mut next_phase2_event = self.device_manager.create_events(); + let chunk_sizes = |chunk_idx: usize| 
{ + self.current_db_sizes + .iter() + .map(|s| (s - DB_CHUNK_SIZE * chunk_idx).clamp(1, DB_CHUNK_SIZE)) + .collect::<Vec<_>>() + }; + + self.codes_engine.prefetch_db_chunk( + code_db_slices, + &self.code_chunk_buffers[0], + &chunk_sizes(0), + &vec![0; self.device_manager.device_count()], + &self.current_db_sizes, + &self.streams[0], + ); + self.masks_engine.prefetch_db_chunk( + mask_db_slices, + &self.mask_chunk_buffers[0], + &chunk_sizes(0), + &vec![0; self.device_manager.device_count()], + &self.current_db_sizes, + &self.streams[0], + ); + // ---- START DATABASE DEDUP ---- tracing::info!(party_id = self.party_id, "Start DB deduplication"); let ignore_device_results: Vec = @@ -1118,14 +1150,12 @@ impl ServerActor { let mut db_chunk_idx = 0; loop { let request_streams = &self.streams[db_chunk_idx % 2]; + let next_request_streams = &self.streams[(db_chunk_idx + 1) % 2]; let request_cublas_handles = &self.cublas_handles[db_chunk_idx % 2]; let offset = db_chunk_idx * DB_CHUNK_SIZE; - let chunk_size = self - .current_db_sizes - .iter() - .map(|s| (s - DB_CHUNK_SIZE * db_chunk_idx).clamp(1, DB_CHUNK_SIZE)) - .collect::<Vec<_>>(); + let chunk_size = chunk_sizes(db_chunk_idx); + let next_chunk_size = chunk_sizes(db_chunk_idx + 1); // We need to pad the chunk size for two reasons: // 1. 
Chunk size needs to be a multiple of 4, because the underlying @@ -1149,6 +1179,24 @@ impl ServerActor { .record_event(request_streams, &current_phase2_event); } + // Prefetch next chunk + self.codes_engine.prefetch_db_chunk( + code_db_slices, + &self.code_chunk_buffers[(db_chunk_idx + 1) % 2], + &next_chunk_size, + &chunk_size.iter().map(|s| offset + s).collect::<Vec<_>>(), + &self.current_db_sizes, + next_request_streams, + ); + self.masks_engine.prefetch_db_chunk( + mask_db_slices, + &self.mask_chunk_buffers[(db_chunk_idx + 1) % 2], + &next_chunk_size, + &chunk_size.iter().map(|s| offset + s).collect::<Vec<_>>(), + &self.current_db_sizes, + next_request_streams, + ); + self.device_manager .await_event(request_streams, &current_dot_event); @@ -1157,10 +1205,10 @@ impl ServerActor { compact_device_queries.dot_products_against_db( &mut self.codes_engine, &mut self.masks_engine, - code_db_slices, - mask_db_slices, + &CudaVec2DSlicerRawPointer::from(&self.code_chunk_buffers[db_chunk_idx % 2]), + &CudaVec2DSlicerRawPointer::from(&self.mask_chunk_buffers[db_chunk_idx % 2]), &dot_chunk_size, - offset, + 0, request_streams, request_cublas_handles, ); @@ -1191,9 +1239,6 @@ impl ServerActor { self.device_manager .record_event(request_streams, &next_dot_event); - self.device_manager - .await_event(request_streams, &next_dot_event); - record_stream_time!( &self.device_manager, request_streams, @@ -1621,7 +1666,7 @@ fn write_db_at_index( ), ] { unsafe { - helpers::dtod_at_offset( + helpers::dtoh_at_offset( db.code_gr.limb_0[device_index], dst_index * code_length, *query.limb_0[device_index].device_ptr(), @@ -1630,7 +1675,7 @@ fn write_db_at_index( streams[device_index].stream, ); - helpers::dtod_at_offset( + helpers::dtoh_at_offset( db.code_gr.limb_1[device_index], dst_index * code_length, *query.limb_1[device_index].device_ptr(),