page lock in chunks

eaypek-tfh committed Jan 24, 2025
1 parent dfb6221 commit 5ee9a90
Showing 3 changed files with 57 additions and 37 deletions.
7 changes: 7 additions & 0 deletions iris-mpc-common/src/config/mod.rs
@@ -126,6 +126,13 @@ pub struct Config {
 
     #[serde(default)]
     pub page_lock_at_beginning: bool,
 
+    /// Percentage of the database size to page-lock at each iteration.
+    /// Must be a positive integer in [1, 100].
+    /// The first memory chunk is page-locked before the db & s3
+    /// importers start.
+    #[serde(default)]
+    pub page_lock_chunk_percentage: usize,
 }
 
 fn default_load_chunks_parallelism() -> usize {
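Note: a minimal sketch (hypothetical values, not part of the commit) of the arithmetic that server.rs derives from this setting. Both divisions are integer divisions, so a percentage that does not divide 100 evenly, or a max_db_size not divisible by the iteration count, truncates and can leave the last few records out of the chunked registration.

    // Hypothetical values; mirrors the n_page_lock_iters / page_lock_chunk_size
    // computation in iris-mpc/src/bin/server.rs below.
    let page_lock_chunk_percentage: usize = 25; // from config, must be in [1, 100]
    let max_db_size: usize = 1_000_000;         // total record capacity

    let n_page_lock_iters = 100 / page_lock_chunk_percentage; // 4 iterations
    let page_lock_chunk_size = max_db_size / n_page_lock_iters; // 250_000 records each
    assert_eq!(n_page_lock_iters * page_lock_chunk_size, max_db_size);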
13 changes: 7 additions & 6 deletions iris-mpc-gpu/src/helpers/mod.rs
@@ -173,22 +173,23 @@ pub fn htod_on_stream_sync<T: DeviceRepr>(
 pub fn register_host_memory(
     device_manager: Arc<DeviceManager>,
     db: &CudaVec2DSlicerRawPointer,
-    max_db_length: usize,
+    chunk_length: usize,
+    chunk_offset: usize,
     code_length: usize,
 ) {
-    let max_size = max_db_length / device_manager.device_count();
+    let size = chunk_length / device_manager.device_count();
     for (device_index, device) in device_manager.devices().iter().enumerate() {
         device.bind_to_thread().unwrap();
         unsafe {
             let _ = cudarc::driver::sys::lib().cuMemHostRegister_v2(
-                db.limb_0[device_index] as *mut _,
-                max_size * code_length,
+                (db.limb_0[device_index] + (chunk_offset * code_length) as u64) as *mut _,
+                size * code_length,
                 CU_MEMHOSTALLOC_PORTABLE,
             );
 
             let _ = cudarc::driver::sys::lib().cuMemHostRegister_v2(
-                db.limb_1[device_index] as *mut _,
-                max_size * code_length,
+                (db.limb_1[device_index] + (chunk_offset * code_length) as u64) as *mut _,
+                size * code_length,
                 CU_MEMHOSTALLOC_PORTABLE,
             );
         }
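Note: a usage sketch of the new signature (assumed caller setup, names borrowed from server.rs). chunk_offset is a record offset; the function converts it to a byte offset via chunk_offset * code_length and registers size * code_length bytes per device limb.

    // Sketch only: pin the whole database chunk by chunk. Assumes that
    // device_manager, db, n_page_lock_iters, page_lock_chunk_size and
    // code_length are set up as in server.rs.
    for i in 0..n_page_lock_iters {
        register_host_memory(
            device_manager.clone(),   // Arc<DeviceManager>
            &db,                      // &CudaVec2DSlicerRawPointer
            page_lock_chunk_size,     // records in this chunk
            i * page_lock_chunk_size, // record offset where the chunk starts
            code_length,              // bytes per record
        );
    }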
74 changes: 43 additions & 31 deletions iris-mpc/src/bin/server.rs
@@ -12,7 +12,7 @@ use aws_smithy_runtime_api::client::dns::{DnsFuture, ResolveDns};
 use axum::{response::IntoResponse, routing::get, Router};
 use clap::Parser;
 use eyre::{eyre, Context};
-use futures::{future, stream::BoxStream, StreamExt};
+use futures::{stream::BoxStream, StreamExt};
 use hickory_resolver::{
     config::{ResolverConfig, ResolverOpts},
     TokioAsyncResolver,

@@ -1191,39 +1191,53 @@ async fn server_main(config: Config) -> eyre::Result<()> {
     let db_chunks_s3_store =
         S3Store::new(db_chunks_s3_client.clone(), db_chunks_bucket_name.clone());
 
-    tracing::info!("Page-lock host memory");
-    let dbs = [
-        (actor.left_code_db_slices.code_gr.clone(), IRIS_CODE_LENGTH),
-        (actor.right_code_db_slices.code_gr.clone(), IRIS_CODE_LENGTH),
-        (actor.left_mask_db_slices.code_gr.clone(), MASK_CODE_LENGTH),
-        (actor.right_mask_db_slices.code_gr.clone(), MASK_CODE_LENGTH),
-    ];
-    let mut page_lock_handles = Vec::new();
-    for (db, code_length) in dbs {
-        let max_db_size = config.max_db_size;
-        // spawn_blocking moves its closure to a worker thread
-        let device_manager_clone = actor.device_manager.clone();
-        let handle = spawn_blocking(move || {
+    tokio::runtime::Handle::current().block_on(async {
+        tracing::info!("Page-lock host memory");
+        let dbs = [
+            (actor.left_code_db_slices.code_gr.clone(), IRIS_CODE_LENGTH),
+            (actor.right_code_db_slices.code_gr.clone(), IRIS_CODE_LENGTH),
+            (actor.left_mask_db_slices.code_gr.clone(), MASK_CODE_LENGTH),
+            (actor.right_mask_db_slices.code_gr.clone(), MASK_CODE_LENGTH),
+        ];
+        let n_page_lock_iters = 100 / config.page_lock_chunk_percentage;
+        let page_lock_chunk_size = config.max_db_size / n_page_lock_iters;
+        let dbs_clone = dbs.clone();
+        let now = Instant::now();
+        for (db, code_length) in dbs_clone {
+            let device_manager_clone = actor.device_manager.clone();
             register_host_memory(
                 device_manager_clone,
                 &db,
-                max_db_size,
+                page_lock_chunk_size,
+                0,
                 code_length,
             );
-        });
-        page_lock_handles.push(handle);
         }
-    let mut page_lock_handles = Some(page_lock_handles);
-    tokio::runtime::Handle::current().block_on(async {
-        let mut now = Instant::now();
-        if config.page_lock_at_beginning {
-            future::join_all(page_lock_handles.take().unwrap()).await;
-            tracing::info!("Page-locking took {:?}", now.elapsed());
-        }
-        now = Instant::now();
+        tracing::info!("First chunk page-locking took {:?}", now.elapsed());
+
+        let device_manager_clone = actor.device_manager.clone();
+
+        // prepare the handle for the rest of the page locks
+        let page_lock_handle = spawn_blocking(move || {
+            for i in 1..n_page_lock_iters {
+                let dbs_clone = dbs.clone();
+                let device_manager_clone = device_manager_clone.clone();
+                for (db, code_length) in dbs_clone {
+                    let device_manager_clone = device_manager_clone.clone();
+                    register_host_memory(
+                        device_manager_clone,
+                        &db,
+                        page_lock_chunk_size,
+                        i * page_lock_chunk_size,
+                        code_length,
+                    );
+                }
+            }
+        });
 
         let mut load_summary_ts = Instant::now();
-        let mut time_waiting_for_stream = time::Duration::from_secs(0);
-        let mut time_loading_into_memory = time::Duration::from_secs(0);
+        let mut time_waiting_for_stream = Duration::from_secs(0);
+        let mut time_loading_into_memory = Duration::from_secs(0);
         let mut record_counter = 0;
         let mut all_serial_ids: HashSet<i64> =
             HashSet::from_iter(1..=(store_len as i64));
@@ -1384,10 +1398,8 @@ async fn server_main(config: Config) -> eyre::Result<()> {
tracing::info!("Preprocessing db");
actor.preprocess_db();

if !config.page_lock_at_beginning {
tracing::info!("Waiting for page-lock to finish");
future::join_all(page_lock_handles.take().unwrap()).await;
}
tracing::info!("Waiting for all page-locks to finish");
page_lock_handle.await.expect("Error while page-locking");

tracing::info!(
"Loaded {} records from db into memory in {:?} [DB sizes: {:?}]",
Expand Down
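Note: taken together, server.rs now follows a lock-ahead pattern; a condensed sketch of the control flow (names from the diff, loop bodies elided):

    // 1. Page-lock chunk 0 synchronously, so the db & s3 importers start
    //    against already-pinned memory.
    // 2. Page-lock chunks 1..n_page_lock_iters on a blocking worker thread
    //    while the importers run concurrently.
    // 3. After preprocessing, block until every chunk is pinned.
    let page_lock_handle = spawn_blocking(move || {
        for i in 1..n_page_lock_iters {
            // register_host_memory(.., page_lock_chunk_size,
            //                      i * page_lock_chunk_size, ..) per db slice
        }
    });
    // ... db & s3 import and preprocessing run here ...
    page_lock_handle.await.expect("Error while page-locking");

Compared with the previous all-or-nothing behavior behind page_lock_at_beginning, this overlaps most of the cuMemHostRegister_v2 work with the database import instead of paying for it all up front or at the end.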
