Commit
aurel/parallel-search: Introduce parallelism with sessions
Aurélien Nicolas committed Jan 21, 2025
1 parent 26745a3 commit 453abc0
Showing 2 changed files with 48 additions and 26 deletions.
68 changes: 45 additions & 23 deletions iris-mpc-cpu/src/execution/hawk_main.rs
@@ -20,7 +20,7 @@ use iris_mpc_common::iris_db::db::IrisDB;
 use itertools::{izip, Itertools};
 use rand::{thread_rng, Rng, SeedableRng};
 use std::{collections::HashMap, sync::Arc, time::Duration};
-use tokio::{task::JoinSet, time::sleep};
+use tokio::{sync::RwLock, task::JoinSet, time::sleep};
 use tonic::transport::Server;
 
 const TEST_WAIT: Duration = Duration::from_secs(3);
@@ -58,6 +58,8 @@ pub struct HawkSession {
     shared_rng: AesRng,
 }
 
+type HawkSessionRef = Arc<RwLock<HawkSession>>;
+
 /// HawkRequest contains a batch of items to search.
 pub struct HawkRequest {
     pub my_iris_shares: Vec<GaloisRingSharedIris>,
@@ -184,56 +186,74 @@ impl HawkActor {
         })
     }
 
-    // TODO: Implement actual parallelism.
     pub async fn search_to_insert(
-        &mut self,
-        sessions: &mut [HawkSession],
+        &self,
+        sessions: &[HawkSessionRef],
         req: HawkRequest,
     ) -> Result<Vec<InsertPlan>> {
         let mut plans = vec![];
-        for (i, iris) in req.my_iris_shares.into_iter().enumerate() {
-            let session = &mut sessions[i % sessions.len()];
-            plans.push(self.search_to_insert_one(session, iris).await?);
+
+        for chunk in req.my_iris_shares.chunks(sessions.len()) {
+            let tasks = izip!(chunk, sessions).map(|(iris, session)| {
+                let search_params = self.search_params.clone();
+                let graph_store = self.graph_store.clone();
+                let session = session.clone();
+                let iris = iris.clone();
+
+                let task = {
+                    async move {
+                        let mut session = session.write().await;
+                        Self::search_to_insert_one(&search_params, &graph_store, &mut session, iris)
+                            .await
+                    }
+                };
+                tokio::spawn(task)
+            });
+
+            for t in tasks {
+                let plan = t.await?;
+                plans.push(plan);
+            }
         }
 
         Ok(plans)
     }
 
-    // TODO: Remove `&mut self` requirement to support parallel sessions.
     async fn search_to_insert_one(
-        &mut self,
+        search_params: &HnswSearcher,
+        graph_store: &GraphMem<Aby3Store>,
         session: &mut HawkSession,
         iris: GaloisRingSharedIris,
-    ) -> Result<InsertPlan> {
-        let insertion_layer = self.search_params.select_layer(&mut session.shared_rng);
+    ) -> InsertPlan {
+        let insertion_layer = search_params.select_layer(&mut session.shared_rng);
         let query = session.aby3_store.prepare_query(iris);
 
-        let (links, set_ep) = self
-            .search_params
+        let (links, set_ep) = search_params
            .search_to_insert(
                &mut session.aby3_store,
-                &mut self.graph_store,
+                graph_store,
                &query,
                insertion_layer,
            )
            .await;
 
-        Ok(InsertPlan {
+        InsertPlan {
             query,
             links,
             set_ep,
-        })
+        }
     }
 
+    // TODO: Implement actual parallelism.
     pub async fn insert(
         &mut self,
-        sessions: &mut [HawkSession],
+        sessions: &[HawkSessionRef],
         plans: Vec<InsertPlan>,
     ) -> Result<()> {
         let plans = join_plans(plans);
-        for (i, plan) in izip!(0.., plans) {
-            let session = &mut sessions[i % sessions.len()];
-            self.insert_one(session, plan).await?;
+        for plan in plans {
+            let mut session = sessions[0].write().await;
+            self.insert_one(&mut session, plan).await?;
         }
         Ok(())
     }
@@ -301,7 +321,8 @@ pub async fn hawk_main(args: HawkArgs) -> Result<()> {
 
     let mut sessions = vec![];
     for _ in 0..parallelism {
-        sessions.push(hawk_actor.new_session().await?);
+        let session = hawk_actor.new_session().await?;
+        sessions.push(Arc::new(RwLock::new(session)));
     }
 
     let my_iris_shares = IrisDB::new_random_rng(batch_size, iris_rng)
@@ -311,8 +332,9 @@ pub async fn hawk_main(args: HawkArgs) -> Result<()> {
         .collect_vec();
     let req = HawkRequest { my_iris_shares };
 
-    let plans = hawk_actor.search_to_insert(&mut sessions, req).await?;
-    hawk_actor.insert(&mut sessions, plans).await?;
+    let plans = hawk_actor.search_to_insert(&sessions, req).await?;
+
+    hawk_actor.insert(&sessions, plans).await?;
 
     println!("🎉 Inserted {batch_size} items into the database");
     Ok(())
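Note (not part of the commit): the hunks above wrap each HawkSession in Arc<RwLock<..>>, split the batch into chunks of sessions.len(), and spawn one Tokio task per (iris, session) pair; each task takes the write lock on its own session for the duration of its search, and the handles are awaited in order. Below is a minimal, self-contained sketch of that pattern. Session, Item, Plan and search_one are hypothetical stand-ins for HawkSession, the iris shares, InsertPlan and search_to_insert_one.

// Sketch of the session-parallel search pattern (hypothetical stand-in types).
use std::sync::Arc;
use tokio::sync::RwLock;

#[derive(Default)]
struct Session {
    // Stand-in for per-session state such as the shared RNG and MPC store.
    searches_done: usize,
}

#[derive(Clone)]
struct Item(u64); // stand-in for one iris share

#[derive(Debug)]
struct Plan(u64); // stand-in for an InsertPlan

async fn search_one(session: &mut Session, item: Item) -> Plan {
    // Stand-in for the real search: mutates only its own session.
    session.searches_done += 1;
    Plan(item.0)
}

async fn search_batch(sessions: &[Arc<RwLock<Session>>], batch: Vec<Item>) -> Vec<Plan> {
    let mut plans = Vec::with_capacity(batch.len());

    for chunk in batch.chunks(sessions.len()) {
        // One task per (item, session) pair; each task owns a clone of the Arc,
        // so at most sessions.len() searches run concurrently.
        let tasks: Vec<_> = chunk
            .iter()
            .cloned()
            .zip(sessions.iter().cloned())
            .map(|(item, session)| {
                tokio::spawn(async move {
                    // Exclusive access to this one session for the whole search.
                    let mut session = session.write().await;
                    search_one(&mut session, item).await
                })
            })
            .collect();

        // Await in order so the output keeps the batch order.
        for task in tasks {
            plans.push(task.await.expect("search task panicked"));
        }
    }

    plans
}

#[tokio::main]
async fn main() {
    let parallelism = 4;
    let sessions: Vec<_> = (0..parallelism)
        .map(|_| Arc::new(RwLock::new(Session::default())))
        .collect();

    let batch: Vec<Item> = (0..10u64).map(Item).collect();
    let plans = search_batch(&sessions, batch).await;
    let checksum: u64 = plans.iter().map(|p| p.0).sum();
    println!("planned {} inserts (checksum {checksum})", plans.len());

    for (i, session) in sessions.iter().enumerate() {
        println!("session {i} ran {} searches", session.read().await.searches_done);
    }
}

Awaiting the handles chunk by chunk caps concurrency at sessions.len() and keeps the returned plans in batch order, mirroring search_to_insert above.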
6 changes: 3 additions & 3 deletions iris-mpc-cpu/src/hnsw/searcher.rs
@@ -223,7 +223,7 @@ impl HnswSearcher {
     async fn search_init<V: VectorStore, G: GraphStore<V>>(
         &self,
         vector_store: &mut V,
-        graph_store: &mut G,
+        graph_store: &G,
         query: &V::QueryRef,
     ) -> (FurthestQueueV<V>, usize) {
         if let Some((entry_point, layer)) = graph_store.get_entry_point().await {
@@ -246,7 +246,7 @@ impl HnswSearcher {
     async fn search_layer<V: VectorStore, G: GraphStore<V>>(
         &self,
         vector_store: &mut V,
-        graph_store: &mut G,
+        graph_store: &G,
         q: &V::QueryRef,
         W: &mut FurthestQueueV<V>,
         ef: usize,
@@ -384,7 +384,7 @@ impl HnswSearcher {
     pub async fn search_to_insert<V: VectorStore, G: GraphStore<V>>(
         &self,
         vector_store: &mut V,
-        graph_store: &mut G,
+        graph_store: &G,
         query: &V::QueryRef,
         insertion_layer: usize,
     ) -> (Vec<FurthestQueueV<V>>, bool) {
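Note (not part of the commit): the searcher.rs hunks relax graph_store from &mut G to &G. The search path only reads the graph, so a shared reference is enough, and that is what lets the caller in hawk_main.rs clone the graph per task and pass it into the searcher as &GraphMem<Aby3Store>. A toy illustration of the borrow-level reasoning, with hypothetical stand-ins (Graph, search) rather than the repository's types:

// With &Graph, one graph value can back several concurrent searches at once;
// a &mut Graph parameter would forbid this sharing.
use std::sync::Arc;

#[derive(Default)]
struct Graph {
    entry_point: Option<usize>,
}

impl Graph {
    // Read-only access is all the search needs.
    fn get_entry_point(&self) -> Option<usize> {
        self.entry_point
    }
}

fn search(graph: &Graph, query: usize) -> Option<usize> {
    // A real search would walk the layers; here we just touch the entry point.
    graph.get_entry_point().map(|entry| entry.min(query))
}

#[tokio::main]
async fn main() {
    let graph = Arc::new(Graph { entry_point: Some(0) });

    // Both tasks borrow the same graph immutably at the same time.
    let g1 = Arc::clone(&graph);
    let g2 = Arc::clone(&graph);
    let t1 = tokio::spawn(async move { search(&g1, 1) });
    let t2 = tokio::spawn(async move { search(&g2, 2) });

    println!("{:?} {:?}", t1.await.unwrap(), t2.await.unwrap());
}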
