Skip to content

Commit

Permalink
prefetch the next chunk to gpu mem manually (#824)
Browse files Browse the repository at this point in the history
* prefetch the next chunk to gpu mem manually
  • Loading branch information
philsippl authored Dec 18, 2024
1 parent 29bee9d commit 4f7ba95
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 36 deletions.
84 changes: 70 additions & 14 deletions iris-mpc-gpu/src/dot/share_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use cudarc::{
CudaBlas,
},
driver::{
result::{self, malloc_async, malloc_managed},
sys::{CUdeviceptr, CUmemAttach_flags},
result::{self, malloc_async},
sys::CUdeviceptr,
CudaFunction, CudaSlice, CudaStream, CudaView, DevicePtr, DeviceSlice, LaunchAsync,
},
nccl,
Expand Down Expand Up @@ -114,6 +114,12 @@ pub struct SlicedProcessedDatabase {
pub code_sums_gr: CudaVec2DSlicerU32,
}

/// Per-GPU staging buffers holding one chunk of the processed database in
/// device memory.
///
/// Each `Vec` contains one `CudaSlice` per device; `limb_0` and `limb_1` are
/// presumably the two limbs of the share representation (matches the limb
/// fields used elsewhere in this file — TODO confirm). Buffers are filled
/// asynchronously by `ShareDB::prefetch_db_chunk` and sized by
/// `ShareDB::alloc_db_chunk_buffer`.
#[derive(Clone)]
pub struct DBChunkBuffers {
    pub limb_0: Vec<CudaSlice<u8>>,
    pub limb_1: Vec<CudaSlice<u8>>,
}

pub struct ShareDB {
peer_id: usize,
is_remote: bool,
Expand Down Expand Up @@ -237,22 +243,17 @@ impl ShareDB {
.devices()
.iter()
.map(|device| unsafe {
let mut host_mem0: *mut c_void = std::ptr::null_mut();
let mut host_mem1: *mut c_void = std::ptr::null_mut();
let _ = cudarc::driver::sys::lib()
.cuMemAllocHost_v2(&mut host_mem0, max_size * self.code_length);
let _ = cudarc::driver::sys::lib()
.cuMemAllocHost_v2(&mut host_mem1, max_size * self.code_length);
(
StreamAwareCudaSlice::from(device.alloc(max_size).unwrap()),
(
StreamAwareCudaSlice::from(device.alloc(max_size).unwrap()),
(
malloc_managed(
max_size * self.code_length,
CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL,
)
.unwrap(),
malloc_managed(
max_size * self.code_length,
CUmemAttach_flags::CU_MEM_ATTACH_GLOBAL,
)
.unwrap(),
),
(host_mem0 as u64, host_mem1 as u64),
),
)
})
Expand Down Expand Up @@ -450,6 +451,61 @@ impl ShareDB {
}
}

/// Allocates per-device staging buffers large enough for one database chunk
/// of up to `max_chunk_size` entries (each `self.code_length` bytes), one
/// buffer per limb per GPU.
///
/// The returned buffers are uninitialized until a prefetch copies data in.
pub fn alloc_db_chunk_buffer(&self, max_chunk_size: usize) -> DBChunkBuffers {
    let bytes_per_limb = max_chunk_size * self.code_length;
    let (limb_0, limb_1) = self
        .device_manager
        .devices()
        .iter()
        .map(|device| unsafe {
            // SAFETY: raw device allocation via cudarc; the memory is only
            // read after a later prefetch has written it.
            (
                device.alloc(bytes_per_limb).unwrap(),
                device.alloc(bytes_per_limb).unwrap(),
            )
        })
        .unzip();
    DBChunkBuffers { limb_0, limb_1 }
}

/// Asynchronously copies the database chunk starting at `offset[idx]`
/// (per-device element offsets) into the staging `buffers`, using
/// host-to-device copies on the given per-device `streams`.
///
/// * `chunk_sizes` — number of entries to copy per device.
/// * `db_sizes` — current number of valid entries per device; devices whose
///   requested range falls outside their database are skipped.
///
/// NOTE(review): `cuMemcpyHtoDAsync` implies `db.code_gr` limb pointers refer
/// to host-accessible memory — confirm against how the database is allocated.
pub fn prefetch_db_chunk(
    &self,
    db: &SlicedProcessedDatabase,
    buffers: &DBChunkBuffers,
    chunk_sizes: &[usize],
    offset: &[usize],
    db_sizes: &[usize],
    streams: &[CudaStream],
) {
    for idx in 0..self.device_manager.device_count() {
        let device = self.device_manager.device(idx);
        // CUDA driver calls operate on the device bound to the current thread.
        device.bind_to_thread().unwrap();

        // Skip devices where the requested chunk would start past the end of
        // the database or run beyond its current size.
        if offset[idx] >= db_sizes[idx] || offset[idx] + chunk_sizes[idx] > db_sizes[idx] {
            continue;
        }

        unsafe {
            // SAFETY: the range check above keeps the source span inside the
            // database, and `buffers` were sized for at least one full chunk
            // (caller must pass chunk_sizes within that allocation).
            cudarc::driver::sys::lib()
                .cuMemcpyHtoDAsync_v2(
                    *buffers.limb_0[idx].device_ptr(),
                    // Source pointer advanced by the element offset in bytes.
                    (db.code_gr.limb_0[idx] as usize + offset[idx] * self.code_length)
                        as *mut _,
                    chunk_sizes[idx] * self.code_length,
                    streams[idx].stream,
                )
                .result()
                .unwrap();

            cudarc::driver::sys::lib()
                .cuMemcpyHtoDAsync_v2(
                    *buffers.limb_1[idx].device_ptr(),
                    (db.code_gr.limb_1[idx] as usize + offset[idx] * self.code_length)
                        as *mut _,
                    chunk_sizes[idx] * self.code_length,
                    streams[idx].stream,
                )
                .result()
                .unwrap();
        }
    }
}

pub fn dot<T>(
&mut self,
queries: &CudaVec2DSlicer<T>,
Expand Down
28 changes: 27 additions & 1 deletion iris-mpc-gpu/src/helpers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::threshold_ring::protocol::ChunkShare;
use cudarc::driver::{
result::{self, memcpy_dtoh_async, memcpy_htod_async, stream},
sys::{CUdeviceptr, CUstream, CUstream_st},
sys::{lib, CUdeviceptr, CUstream, CUstream_st},
CudaDevice, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, DriverError,
LaunchConfig,
};
Expand Down Expand Up @@ -104,6 +104,32 @@ pub unsafe fn dtod_at_offset(
}
}

/// Asynchronously copies `len` bytes from device memory to host memory,
/// applying the given byte offsets to the destination and source pointers.
///
/// # Safety
///
/// The caller must ensure that `dst` and `src` are valid pointers and that
/// the regions starting at `dst + dst_offset` and `src + src_offset` remain
/// in bounds for `len` bytes while the copy is in flight on `stream_ptr`.
pub unsafe fn dtoh_at_offset(
    dst: u64,
    dst_offset: usize,
    src: CUdeviceptr,
    src_offset: usize,
    len: usize,
    stream_ptr: CUstream,
) {
    let host_dst = (dst + dst_offset as u64) as *mut _;
    let device_src = (src + src_offset as u64) as CUdeviceptr;
    unsafe {
        lib()
            .cuMemcpyDtoHAsync_v2(host_dst, device_src, len, stream_ptr)
            .result()
            .unwrap();
    }
}

pub fn dtoh_on_stream_sync<T: Default + Clone, U: DevicePtr<T>>(
input: &U,
device: &Arc<CudaDevice>,
Expand Down
19 changes: 14 additions & 5 deletions iris-mpc-gpu/src/helpers/query_processor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
dot::{
share_db::{ShareDB, SlicedProcessedDatabase},
share_db::{DBChunkBuffers, ShareDB, SlicedProcessedDatabase},
IRIS_CODE_LENGTH, MASK_CODE_LENGTH,
},
helpers::device_manager::DeviceManager,
Expand Down Expand Up @@ -82,6 +82,15 @@ impl<T> From<&CudaVec2DSlicer<T>> for CudaVec2DSlicerRawPointer {
}
}

impl From<&DBChunkBuffers> for CudaVec2DSlicerRawPointer {
    /// Extracts the raw device pointers from each per-GPU staging buffer so
    /// the chunk can be passed to code that works on raw pointer slicers.
    fn from(buffers: &DBChunkBuffers) -> Self {
        let mut limb_0 = Vec::with_capacity(buffers.limb_0.len());
        for slice in &buffers.limb_0 {
            limb_0.push(*slice.device_ptr());
        }
        let mut limb_1 = Vec::with_capacity(buffers.limb_1.len());
        for slice in &buffers.limb_1 {
            limb_1.push(*slice.device_ptr());
        }
        CudaVec2DSlicerRawPointer { limb_0, limb_1 }
    }
}

pub struct CudaVec2DSlicer<T> {
pub limb_0: Vec<StreamAwareCudaSlice<T>>,
pub limb_1: Vec<StreamAwareCudaSlice<T>>,
Expand Down Expand Up @@ -193,24 +202,24 @@ impl DeviceCompactQuery {
&self,
code_engine: &mut ShareDB,
mask_engine: &mut ShareDB,
sliced_code_db: &SlicedProcessedDatabase,
sliced_mask_db: &SlicedProcessedDatabase,
sliced_code_db: &CudaVec2DSlicerRawPointer,
sliced_mask_db: &CudaVec2DSlicerRawPointer,
database_sizes: &[usize],
offset: usize,
streams: &[CudaStream],
blass: &[CudaBlas],
) {
code_engine.dot(
&self.code_query,
&sliced_code_db.code_gr,
sliced_code_db,
database_sizes,
offset,
streams,
blass,
);
mask_engine.dot(
&self.mask_query,
&sliced_mask_db.code_gr,
sliced_mask_db,
database_sizes,
offset,
streams,
Expand Down
77 changes: 61 additions & 16 deletions iris-mpc-gpu/src/server/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ use super::{BatchQuery, Eye, ServerJob, ServerJobResult};
use crate::{
dot::{
distance_comparator::DistanceComparator,
share_db::{preprocess_query, ShareDB, SlicedProcessedDatabase},
share_db::{preprocess_query, DBChunkBuffers, ShareDB, SlicedProcessedDatabase},
IRIS_CODE_LENGTH, MASK_CODE_LENGTH, ROTATIONS,
},
helpers::{
self,
comm::NcclComm,
device_manager::DeviceManager,
query_processor::{CompactQuery, DeviceCompactQuery, DeviceCompactSums},
query_processor::{
CompactQuery, CudaVec2DSlicerRawPointer, DeviceCompactQuery, DeviceCompactSums,
},
},
threshold_ring::protocol::{ChunkShare, Circuits},
};
Expand Down Expand Up @@ -103,6 +105,8 @@ pub struct ServerActor {
max_db_size: usize,
return_partial_results: bool,
disable_persistence: bool,
code_chunk_buffers: Vec<DBChunkBuffers>,
mask_chunk_buffers: Vec<DBChunkBuffers>,
}

const NON_MATCH_ID: u32 = u32::MAX;
Expand Down Expand Up @@ -317,9 +321,11 @@ impl ServerActor {
let batch_match_list_right = distance_comparator.prepare_db_match_list(n_queries);

let query_db_size = vec![n_queries; device_manager.device_count()];

let current_db_sizes = vec![0; device_manager.device_count()];

let code_chunk_buffers = vec![codes_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];
let mask_chunk_buffers = vec![masks_engine.alloc_db_chunk_buffer(DB_CHUNK_SIZE); 2];

for dev in device_manager.devices() {
dev.synchronize().unwrap();
}
Expand Down Expand Up @@ -355,6 +361,8 @@ impl ServerActor {
max_db_size,
return_partial_results,
disable_persistence,
code_chunk_buffers,
mask_chunk_buffers,
})
}

Expand Down Expand Up @@ -1111,21 +1119,43 @@ impl ServerActor {
let mut current_phase2_event = self.device_manager.create_events();
let mut next_phase2_event = self.device_manager.create_events();

let chunk_sizes = |chunk_idx: usize| {
self.current_db_sizes
.iter()
.map(|s| (s - DB_CHUNK_SIZE * chunk_idx).clamp(1, DB_CHUNK_SIZE))
.collect::<Vec<_>>()
};

self.codes_engine.prefetch_db_chunk(
code_db_slices,
&self.code_chunk_buffers[0],
&chunk_sizes(0),
&vec![0; self.device_manager.device_count()],
&self.current_db_sizes,
&self.streams[0],
);
self.masks_engine.prefetch_db_chunk(
mask_db_slices,
&self.mask_chunk_buffers[0],
&chunk_sizes(0),
&vec![0; self.device_manager.device_count()],
&self.current_db_sizes,
&self.streams[0],
);

// ---- START DATABASE DEDUP ----
tracing::info!(party_id = self.party_id, "Start DB deduplication");
let ignore_device_results: Vec<bool> =
self.current_db_sizes.iter().map(|&s| s == 0).collect();
let mut db_chunk_idx = 0;
loop {
let request_streams = &self.streams[db_chunk_idx % 2];
let next_request_streams = &self.streams[(db_chunk_idx + 1) % 2];
let request_cublas_handles = &self.cublas_handles[db_chunk_idx % 2];

let offset = db_chunk_idx * DB_CHUNK_SIZE;
let chunk_size = self
.current_db_sizes
.iter()
.map(|s| (s - DB_CHUNK_SIZE * db_chunk_idx).clamp(1, DB_CHUNK_SIZE))
.collect::<Vec<_>>();
let chunk_size = chunk_sizes(db_chunk_idx);
let next_chunk_size = chunk_sizes(db_chunk_idx + 1);

// We need to pad the chunk size for two reasons:
// 1. Chunk size needs to be a multiple of 4, because the underlying
Expand All @@ -1149,6 +1179,24 @@ impl ServerActor {
.record_event(request_streams, &current_phase2_event);
}

// Prefetch next chunk
self.codes_engine.prefetch_db_chunk(
code_db_slices,
&self.code_chunk_buffers[(db_chunk_idx + 1) % 2],
&next_chunk_size,
&chunk_size.iter().map(|s| offset + s).collect::<Vec<_>>(),
&self.current_db_sizes,
next_request_streams,
);
self.masks_engine.prefetch_db_chunk(
mask_db_slices,
&self.mask_chunk_buffers[(db_chunk_idx + 1) % 2],
&next_chunk_size,
&chunk_size.iter().map(|s| offset + s).collect::<Vec<_>>(),
&self.current_db_sizes,
next_request_streams,
);

self.device_manager
.await_event(request_streams, &current_dot_event);

Expand All @@ -1157,10 +1205,10 @@ impl ServerActor {
compact_device_queries.dot_products_against_db(
&mut self.codes_engine,
&mut self.masks_engine,
code_db_slices,
mask_db_slices,
&CudaVec2DSlicerRawPointer::from(&self.code_chunk_buffers[db_chunk_idx % 2]),
&CudaVec2DSlicerRawPointer::from(&self.mask_chunk_buffers[db_chunk_idx % 2]),
&dot_chunk_size,
offset,
0,
request_streams,
request_cublas_handles,
);
Expand Down Expand Up @@ -1191,9 +1239,6 @@ impl ServerActor {
self.device_manager
.record_event(request_streams, &next_dot_event);

self.device_manager
.await_event(request_streams, &next_dot_event);

record_stream_time!(
&self.device_manager,
request_streams,
Expand Down Expand Up @@ -1621,7 +1666,7 @@ fn write_db_at_index(
),
] {
unsafe {
helpers::dtod_at_offset(
helpers::dtoh_at_offset(
db.code_gr.limb_0[device_index],
dst_index * code_length,
*query.limb_0[device_index].device_ptr(),
Expand All @@ -1630,7 +1675,7 @@ fn write_db_at_index(
streams[device_index].stream,
);

helpers::dtod_at_offset(
helpers::dtoh_at_offset(
db.code_gr.limb_1[device_index],
dst_index * code_length,
*query.limb_1[device_index].device_ptr(),
Expand Down

0 comments on commit 4f7ba95

Please sign in to comment.