diff --git a/iris-mpc-gpu/src/helpers/mod.rs b/iris-mpc-gpu/src/helpers/mod.rs index e4bd114ee..a3076dce2 100644 --- a/iris-mpc-gpu/src/helpers/mod.rs +++ b/iris-mpc-gpu/src/helpers/mod.rs @@ -1,10 +1,12 @@ use crate::threshold_ring::protocol::ChunkShare; use cudarc::driver::{ result::{self, memcpy_dtoh_async, memcpy_htod_async, stream}, - sys::{lib, CUdeviceptr, CUstream, CUstream_st}, + sys::{lib, CUdeviceptr, CUstream, CUstream_st, CU_MEMHOSTALLOC_PORTABLE}, CudaDevice, CudaSlice, CudaStream, DevicePtr, DevicePtrMut, DeviceRepr, DriverError, LaunchConfig, }; +use device_manager::DeviceManager; +use query_processor::CudaVec2DSlicerRawPointer; use std::sync::Arc; pub mod comm; @@ -167,3 +169,28 @@ pub fn htod_on_stream_sync( }; Ok(buf) } + +pub fn register_host_memory( + device_manager: Arc, + db: &CudaVec2DSlicerRawPointer, + max_db_length: usize, + code_length: usize, +) { + let max_size = max_db_length / device_manager.device_count(); + for (device_index, device) in device_manager.devices().iter().enumerate() { + device.bind_to_thread().unwrap(); + unsafe { + let _ = cudarc::driver::sys::lib().cuMemHostRegister_v2( + db.limb_0[device_index] as *mut _, + max_size * code_length, + CU_MEMHOSTALLOC_PORTABLE, + ); + + let _ = cudarc::driver::sys::lib().cuMemHostRegister_v2( + db.limb_1[device_index] as *mut _, + max_size * code_length, + CU_MEMHOSTALLOC_PORTABLE, + ); + } + } +} diff --git a/iris-mpc-gpu/src/helpers/query_processor.rs b/iris-mpc-gpu/src/helpers/query_processor.rs index de27e4e5d..7d7f2ea50 100644 --- a/iris-mpc-gpu/src/helpers/query_processor.rs +++ b/iris-mpc-gpu/src/helpers/query_processor.rs @@ -68,6 +68,7 @@ impl Drop for StreamAwareCudaSlice { /// Holds the raw memory pointers for the 2D slices. /// Memory is not freed when the struct is dropped, but must be freed manually. +#[derive(Clone)] pub struct CudaVec2DSlicerRawPointer { pub limb_0: Vec, pub limb_1: Vec, diff --git a/iris-mpc-gpu/src/server/actor.rs b/iris-mpc-gpu/src/server/actor.rs index 6711c0884..07da72879 100644 --- a/iris-mpc-gpu/src/server/actor.rs +++ b/iris-mpc-gpu/src/server/actor.rs @@ -77,44 +77,44 @@ const KDF_SALT: &str = "111a1a93518f670e9bb0c2c68888e2beb9406d4c4ed571dc77b801e6 const SUPERMATCH_THRESHOLD: usize = 4_000; pub struct ServerActor { - job_queue: mpsc::Receiver, - device_manager: Arc, - party_id: usize, + job_queue: mpsc::Receiver, + pub device_manager: Arc, + party_id: usize, // engines - codes_engine: ShareDB, - masks_engine: ShareDB, - batch_codes_engine: ShareDB, - batch_masks_engine: ShareDB, - phase2: Circuits, - phase2_batch: Circuits, - distance_comparator: DistanceComparator, - comms: Vec>, + codes_engine: ShareDB, + masks_engine: ShareDB, + batch_codes_engine: ShareDB, + batch_masks_engine: ShareDB, + phase2: Circuits, + phase2_batch: Circuits, + distance_comparator: DistanceComparator, + comms: Vec>, // DB slices - left_code_db_slices: SlicedProcessedDatabase, - left_mask_db_slices: SlicedProcessedDatabase, - right_code_db_slices: SlicedProcessedDatabase, - right_mask_db_slices: SlicedProcessedDatabase, - streams: Vec>, - cublas_handles: Vec>, - results: Vec>, - batch_results: Vec>, - final_results: Vec>, - db_match_list_left: Vec>, - db_match_list_right: Vec>, - batch_match_list_left: Vec>, - batch_match_list_right: Vec>, - current_db_sizes: Vec, - query_db_size: Vec, - max_batch_size: usize, - max_db_size: usize, - return_partial_results: bool, - disable_persistence: bool, - enable_debug_timing: bool, - code_chunk_buffers: Vec, - mask_chunk_buffers: Vec, - dot_events: Vec>, - exchange_events: Vec>, - phase2_events: Vec>, + pub left_code_db_slices: SlicedProcessedDatabase, + pub left_mask_db_slices: SlicedProcessedDatabase, + pub right_code_db_slices: SlicedProcessedDatabase, + pub right_mask_db_slices: SlicedProcessedDatabase, + streams: Vec>, + cublas_handles: Vec>, + results: Vec>, + batch_results: Vec>, + final_results: Vec>, + db_match_list_left: Vec>, + db_match_list_right: Vec>, + batch_match_list_left: Vec>, + batch_match_list_right: Vec>, + current_db_sizes: Vec, + query_db_size: Vec, + max_batch_size: usize, + max_db_size: usize, + return_partial_results: bool, + disable_persistence: bool, + enable_debug_timing: bool, + code_chunk_buffers: Vec, + mask_chunk_buffers: Vec, + dot_events: Vec>, + exchange_events: Vec>, + phase2_events: Vec>, } const NON_MATCH_ID: u32 = u32::MAX; diff --git a/iris-mpc/src/bin/server.rs b/iris-mpc/src/bin/server.rs index adf7beddb..062e49358 100644 --- a/iris-mpc/src/bin/server.rs +++ b/iris-mpc/src/bin/server.rs @@ -31,9 +31,10 @@ use iris_mpc_common::{ sync::SyncState, task_monitor::TaskMonitor, }, + IRIS_CODE_LENGTH, MASK_CODE_LENGTH, }; use iris_mpc_gpu::{ - helpers::device_manager::DeviceManager, + helpers::{device_manager::DeviceManager, register_host_memory}, server::{ get_dummy_shares_for_deletion, sync_nccl, BatchMetadata, BatchQuery, BatchQueryEntriesPreprocessed, ServerActor, ServerJobResult, @@ -1039,6 +1040,33 @@ async fn server_main(config: Config) -> eyre::Result<()> { } }; + tracing::info!("Page-lock host memory"); + let left_codes = actor.left_code_db_slices.code_gr.clone(); + let right_codes = actor.right_code_db_slices.code_gr.clone(); + let left_masks = actor.left_mask_db_slices.code_gr.clone(); + let right_masks = actor.right_mask_db_slices.code_gr.clone(); + let device_manager_clone = actor.device_manager.clone(); + + let page_lock_handle = spawn_blocking(move || { + for db in [&left_codes, &right_codes] { + register_host_memory( + device_manager_clone.clone(), + db, + config.max_db_size, + IRIS_CODE_LENGTH, + ); + } + + for db in [&left_masks, &right_masks] { + register_host_memory( + device_manager_clone.clone(), + db, + config.max_db_size, + MASK_CODE_LENGTH, + ); + } + }); + let now = Instant::now(); let mut now_load_summary = Instant::now(); let mut time_waiting_for_stream = time::Duration::from_secs(0); @@ -1133,8 +1161,8 @@ async fn server_main(config: Config) -> eyre::Result<()> { tracing::info!("Preprocessing db"); actor.preprocess_db(); - tracing::info!("Page-lock host memory"); - actor.register_host_memory(); + tracing::info!("Waiting for page-lock to finish"); + page_lock_handle.await?; tracing::info!( "Loaded {} records from db into memory [DB sizes: {:?}]",