diff --git a/Cargo.lock b/Cargo.lock index 80faf4f91..c8a3933e7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2755,6 +2755,7 @@ dependencies = [ "num-traits", "prost", "rand", + "rand_distr", "rstest", "serde", "serde_json", diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml index 6940012ac..dcfdceb1e 100644 --- a/iris-mpc-cpu/Cargo.toml +++ b/iris-mpc-cpu/Cargo.toml @@ -26,6 +26,7 @@ itertools.workspace = true num-traits.workspace = true prost = "0.13" rand.workspace = true +rand_distr = "0.4.3" rstest = "0.23.0" serde.workspace = true serde_json.workspace = true @@ -57,4 +58,4 @@ path = "bin/local_hnsw.rs" [[bin]] name = "generate_benchmark_data" -path = "bin/generate_benchmark_data.rs" \ No newline at end of file +path = "bin/generate_benchmark_data.rs" diff --git a/iris-mpc-cpu/benches/hnsw.rs b/iris-mpc-cpu/benches/hnsw.rs index 00e45e9a4..b72d787d0 100644 --- a/iris-mpc-cpu/benches/hnsw.rs +++ b/iris-mpc-cpu/benches/hnsw.rs @@ -1,11 +1,12 @@ use aes_prng::AesRng; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, SamplingMode}; -use hawk_pack::{graph_store::GraphMem, HawkSearcher}; +use hawk_pack::graph_store::GraphMem; use iris_mpc_common::iris_db::{db::IrisDB, iris::IrisCode}; use iris_mpc_cpu::{ database_generators::{create_random_sharing, generate_galois_iris_shares}, execution::local::LocalRuntime, hawkers::{aby3_store::Aby3Store, plaintext_store::PlaintextStore}, + hnsw::searcher::HnswSearcher, protocol::ops::{ batch_signed_lift_vec, cross_compare, galois_ring_pairwise_distance, galois_ring_to_rep3, }, @@ -28,7 +29,7 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { let mut rng = AesRng::seed_from_u64(0_u64); let mut vector = PlaintextStore::default(); let mut graph = GraphMem::new(); - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); for _ in 0..database_size { let raw_query = IrisCode::random_rng(&mut rng); @@ -44,7 +45,7 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { b.to_async(&rt).iter_batched( || (vector.clone(), graph.clone()), |(mut db_vectors, mut graph)| async move { - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let query = db_vectors.prepare_query(on_the_fly_query); @@ -201,7 +202,7 @@ fn bench_gr_ready_made_hnsw(c: &mut Criterion) { b.to_async(&rt).iter_batched( || secret_searcher.clone(), |vectors_graphs| async move { - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let raw_query = generate_galois_iris_shares(&mut rng, on_the_fly_query); @@ -235,7 +236,7 @@ fn bench_gr_ready_made_hnsw(c: &mut Criterion) { b.to_async(&rt).iter_batched( || secret_searcher.clone(), |vectors_graphs| async move { - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let raw_query = generate_galois_iris_shares(&mut rng, on_the_fly_query); diff --git a/iris-mpc-cpu/examples/hnsw-ex.rs b/iris-mpc-cpu/examples/hnsw-ex.rs index 71c925028..1b179a560 100644 --- a/iris-mpc-cpu/examples/hnsw-ex.rs +++ b/iris-mpc-cpu/examples/hnsw-ex.rs @@ -1,7 +1,7 @@ use aes_prng::AesRng; -use hawk_pack::{graph_store::GraphMem, HawkSearcher}; +use hawk_pack::graph_store::GraphMem; use 
iris_mpc_common::iris_db::iris::IrisCode; -use iris_mpc_cpu::hawkers::plaintext_store::PlaintextStore; +use iris_mpc_cpu::{hawkers::plaintext_store::PlaintextStore, hnsw::searcher::HnswSearcher}; use rand::SeedableRng; const DATABASE_SIZE: usize = 1_000; @@ -16,7 +16,7 @@ fn main() { let mut rng = AesRng::seed_from_u64(0_u64); let mut vector = PlaintextStore::default(); let mut graph = GraphMem::new(); - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); for idx in 0..DATABASE_SIZE { let raw_query = IrisCode::random_rng(&mut rng); diff --git a/iris-mpc-cpu/src/hawkers/aby3_store.rs b/iris-mpc-cpu/src/hawkers/aby3_store.rs index 23f02c459..b859a4f91 100644 --- a/iris-mpc-cpu/src/hawkers/aby3_store.rs +++ b/iris-mpc-cpu/src/hawkers/aby3_store.rs @@ -1,4 +1,3 @@ -use super::plaintext_store::PlaintextStore; use crate::{ database_generators::{generate_galois_iris_shares, GaloisRingSharedIris}, execution::{ @@ -6,7 +5,8 @@ use crate::{ player::Identity, session::{Session, SessionHandles}, }, - hawkers::plaintext_store::PointId, + hawkers::plaintext_store::{PlaintextStore, PointId}, + hnsw::HnswSearcher, network::NetworkType, protocol::ops::{ batch_signed_lift_vec, compare_threshold_and_open, cross_compare, @@ -22,7 +22,7 @@ use aes_prng::AesRng; use hawk_pack::{ data_structures::queue::FurthestQueue, graph_store::{graph_mem::Layer, GraphMem}, - GraphStore, HawkSearcher, VectorStore, + GraphStore, VectorStore, }; use iris_mpc_common::iris_db::db::IrisDB; use rand::{CryptoRng, RngCore, SeedableRng}; @@ -563,7 +563,7 @@ impl Aby3Store { .collect::>(); jobs.spawn(async move { let mut graph_store = GraphMem::new(); - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); // insert queries for query in queries.iter() { searcher @@ -603,9 +603,9 @@ impl Aby3Store { #[cfg(test)] mod tests { use super::*; - use crate::database_generators::generate_galois_iris_shares; + use crate::{database_generators::generate_galois_iris_shares, hnsw::HnswSearcher}; use aes_prng::AesRng; - use hawk_pack::{graph_store::GraphMem, HawkSearcher}; + use hawk_pack::graph_store::GraphMem; use itertools::Itertools; use rand::SeedableRng; use tracing_test::traced_test; @@ -634,7 +634,7 @@ mod tests { let mut rng = rng.clone(); jobs.spawn(async move { let mut aby3_graph = GraphMem::new(); - let db = HawkSearcher::default(); + let db = HnswSearcher::default(); let mut inserted = vec![]; // insert queries @@ -693,7 +693,7 @@ mod tests { premade_v.storage.body.read().unwrap().points ); } - let hawk_searcher = HawkSearcher::default(); + let hawk_searcher = HnswSearcher::default(); for i in 0..database_size { let cleartext_neighbors = hawk_searcher @@ -840,7 +840,7 @@ mod tests { async fn test_gr_scratch_hnsw() { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 2; - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); let mut vectors_and_graphs = Aby3Store::shared_random_setup(&mut rng, database_size, NetworkType::LocalChannel) .await diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index 80a462a82..7fcbcd10c 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -1,5 +1,6 @@ +use crate::hnsw::HnswSearcher; use aes_prng::AesRng; -use hawk_pack::{graph_store::GraphMem, HawkSearcher, VectorStore}; +use hawk_pack::{graph_store::GraphMem, VectorStore}; use iris_mpc_common::iris_db::{ db::IrisDB, iris::{IrisCode, 
MATCH_THRESHOLD_RATIO}, @@ -148,7 +149,7 @@ impl PlaintextStore { let mut plaintext_vector_store = PlaintextStore::default(); let mut plaintext_graph_store = GraphMem::new(); - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); for raw_query in cleartext_database.iter() { let query = plaintext_vector_store.prepare_query(raw_query.clone()); @@ -189,7 +190,7 @@ impl PlaintextStore { let mut rng_searcher1 = AesRng::from_rng(rng.clone())?; let mut plaintext_graph_store = GraphMem::new(); - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); for i in 0..graph_size { searcher @@ -209,8 +210,8 @@ impl PlaintextStore { #[cfg(test)] mod tests { use super::*; + use crate::hnsw::HnswSearcher; use aes_prng::AesRng; - use hawk_pack::HawkSearcher; use iris_mpc_common::iris_db::db::IrisDB; use rand::SeedableRng; use tracing_test::traced_test; @@ -292,7 +293,7 @@ mod tests { async fn test_plaintext_hnsw_matcher() { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 1; - let searcher = HawkSearcher::default(); + let searcher = HnswSearcher::default(); let (mut ptxt_vector, mut ptxt_graph) = PlaintextStore::create_random(&mut rng, database_size) .await diff --git a/iris-mpc-cpu/src/hnsw/mod.rs b/iris-mpc-cpu/src/hnsw/mod.rs new file mode 100644 index 000000000..8cc4cecef --- /dev/null +++ b/iris-mpc-cpu/src/hnsw/mod.rs @@ -0,0 +1,3 @@ +pub mod searcher; + +pub use searcher::HnswSearcher; diff --git a/iris-mpc-cpu/src/hnsw/searcher.rs b/iris-mpc-cpu/src/hnsw/searcher.rs new file mode 100644 index 000000000..1dd72a1b5 --- /dev/null +++ b/iris-mpc-cpu/src/hnsw/searcher.rs @@ -0,0 +1,515 @@ +//* Implementation of HNSW algorithm for k-nearest-neighbor search over iris +//* biometric templates with high-latency MPC comparison operations. 
Based on
+//* the `HawkSearcher` class of the hawk-pack crate:
+//*
+//* https://github.com/Inversed-Tech/hawk-pack/
+
+pub use hawk_pack::data_structures::queue::{
+    FurthestQueue, FurthestQueueV, NearestQueue, NearestQueueV,
+};
+use hawk_pack::{GraphStore, VectorStore};
+use rand::RngCore;
+use rand_distr::{Distribution, Geometric};
+use serde::{Deserialize, Serialize};
+use std::collections::HashSet;
+
+// Specify construction and search parameters by layer up to this value minus 1;
+// any higher layers will use the last set of parameters.
+pub const N_PARAM_LAYERS: usize = 5;
+
+#[allow(non_snake_case)]
+#[derive(PartialEq, Clone, Serialize, Deserialize)]
+pub struct HnswParams {
+    pub M: [usize; N_PARAM_LAYERS],                // number of neighbors for insertion
+    pub M_max: [usize; N_PARAM_LAYERS],            // maximum number of neighbors
+    pub ef_constr_search: [usize; N_PARAM_LAYERS], // ef_constr for search layers
+    pub ef_constr_insert: [usize; N_PARAM_LAYERS], // ef_constr for insertion layers
+    pub ef_search: [usize; N_PARAM_LAYERS],        // ef for search
+    pub layer_probability: f64,                    /* p for geometric distribution of layer
+                                                    * densities */
+}
+
+#[allow(non_snake_case, clippy::too_many_arguments)]
+impl HnswParams {
+    /// Construct an `HnswParams` object corresponding to a parameter
+    /// configuration providing the functionality described in the original
+    /// HNSW paper:
+    /// - ef_construction exploration factor used for insertion layers
+    /// - ef_search exploration factor used for layer 0 in search
+    /// - higher layers in both insertion and search use exploration factor 1,
+    ///   representing simple greedy search
+    /// - vertex degrees bounded by M_max = M in positive layer graphs
+    /// - vertex degrees bounded by M_max0 = 2*M in layer 0 graph
+    /// - m_L = 1 / ln(M) so that layer density decreases by a factor of M at
+    ///   each successive hierarchical layer
+    pub fn new(ef_construction: usize, ef_search: usize, M: usize) -> Self {
+        let M_arr = [M; N_PARAM_LAYERS];
+        let mut M_max_arr = [M; N_PARAM_LAYERS];
+        M_max_arr[0] = 2 * M;
+        let ef_constr_search_arr = [1usize; N_PARAM_LAYERS];
+        let ef_constr_insert_arr = [ef_construction; N_PARAM_LAYERS];
+        let mut ef_search_arr = [1usize; N_PARAM_LAYERS];
+        ef_search_arr[0] = ef_search;
+        let layer_probability = (M as f64).recip();
+
+        Self {
+            M: M_arr,
+            M_max: M_max_arr,
+            ef_constr_search: ef_constr_search_arr,
+            ef_constr_insert: ef_constr_insert_arr,
+            ef_search: ef_search_arr,
+            layer_probability,
+        }
+    }
+
+    /// Same as standard constructor but with an extra input for a non-standard
+    /// `layer_probability` parameter.
+    pub fn new_with_layer_probability(
+        ef_construction: usize,
+        ef_search: usize,
+        M: usize,
+        layer_probability: f64,
+    ) -> Self {
+        let M_arr = [M; N_PARAM_LAYERS];
+        let mut M_max_arr = [M; N_PARAM_LAYERS];
+        M_max_arr[0] = 2 * M;
+        let ef_constr_search_arr = [1usize; N_PARAM_LAYERS];
+        let ef_constr_insert_arr = [ef_construction; N_PARAM_LAYERS];
+        let mut ef_search_arr = [1usize; N_PARAM_LAYERS];
+        ef_search_arr[0] = ef_search;
+
+        Self {
+            M: M_arr,
+            M_max: M_max_arr,
+            ef_constr_search: ef_constr_search_arr,
+            ef_constr_insert: ef_constr_insert_arr,
+            ef_search: ef_search_arr,
+            layer_probability,
+        }
+    }
+
+    /// Parameter configuration using a fixed exploration factor for all layer
+    /// search operations, both for insertion and for search.
+    pub fn new_uniform(ef: usize, M: usize) -> Self {
+        let M_arr = [M; N_PARAM_LAYERS];
+        let mut M_max_arr = [M; N_PARAM_LAYERS];
+        M_max_arr[0] = 2 * M;
+        let ef_constr_search_arr = [ef; N_PARAM_LAYERS];
+        let ef_constr_insert_arr = [ef; N_PARAM_LAYERS];
+        let ef_search_arr = [ef; N_PARAM_LAYERS];
+        let layer_probability = (M as f64).recip();
+
+        Self {
+            M: M_arr,
+            M_max: M_max_arr,
+            ef_constr_search: ef_constr_search_arr,
+            ef_constr_insert: ef_constr_insert_arr,
+            ef_search: ef_search_arr,
+            layer_probability,
+        }
+    }
+
+    /// Compute the parameter m_L associated with a geometric distribution
+    /// parameter q describing the random layer of newly inserted graph nodes.
+    ///
+    /// E.g. for a graph hierarchy where each layer has a factor of 32 fewer
+    /// entries than the last, the `layer_probability` input is 1/32.
+    pub fn m_L_from_layer_probability(layer_probability: f64) -> f64 {
+        -layer_probability.ln().recip()
+    }
+
+    /// Compute the parameter q for the geometric distribution used to select
+    /// the insertion layer for newly inserted graph nodes, from the parameter
+    /// m_L of the original HNSW paper.
+    pub fn layer_probability_from_m_L(m_L: f64) -> f64 {
+        (-m_L.recip()).exp()
+    }
+
+    pub fn get_M(&self, lc: usize) -> usize {
+        Self::get_val(&self.M, lc)
+    }
+
+    pub fn get_M_max(&self, lc: usize) -> usize {
+        Self::get_val(&self.M_max, lc)
+    }
+
+    pub fn get_ef_constr_search(&self, lc: usize) -> usize {
+        Self::get_val(&self.ef_constr_search, lc)
+    }
+
+    pub fn get_ef_constr_insert(&self, lc: usize) -> usize {
+        Self::get_val(&self.ef_constr_insert, lc)
+    }
+
+    pub fn get_ef_search(&self, lc: usize) -> usize {
+        Self::get_val(&self.ef_search, lc)
+    }
+
+    pub fn get_layer_probability(&self) -> f64 {
+        self.layer_probability
+    }
+
+    pub fn get_m_L(&self) -> f64 {
+        Self::m_L_from_layer_probability(self.layer_probability)
+    }
+
+    #[inline(always)]
+    /// Select the value at index `lc` from the input fixed-size array, or the
+    /// value at the last index of this array if `lc` is larger than the array
+    /// size.
+    fn get_val(arr: &[usize; N_PARAM_LAYERS], lc: usize) -> usize {
+        arr[lc.min(N_PARAM_LAYERS - 1)]
+    }
+}
+
+/// An implementation of the HNSW algorithm.
+///
+/// Operations on vectors are delegated to a VectorStore.
+/// Operations on the graph are delegated to a GraphStore.
+#[derive(Clone, Serialize, Deserialize)]
+pub struct HnswSearcher {
+    pub params: HnswParams,
+}
+
+// TODO remove default value; this varies too much between applications
+// to make sense to specify something "obvious"
+impl Default for HnswSearcher {
+    fn default() -> Self {
+        HnswSearcher {
+            params: HnswParams::new(64, 32, 32),
+        }
+    }
+}
+
+#[allow(non_snake_case)]
+impl HnswSearcher {
+    async fn connect_bidir<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        vector_store: &mut V,
+        graph_store: &mut G,
+        q: &V::VectorRef,
+        mut neighbors: FurthestQueueV<V>,
+        lc: usize,
+    ) {
+        let M = self.params.get_M(lc);
+        let max_links = self.params.get_M_max(lc);
+
+        neighbors.trim_to_k_nearest(M);
+
+        // Connect all n -> q.
+        for (n, nq) in neighbors.iter() {
+            let mut links = graph_store.get_links(n, lc).await;
+            links.insert(vector_store, q.clone(), nq.clone()).await;
+            links.trim_to_k_nearest(max_links);
+            graph_store.set_links(n.clone(), links, lc).await;
+        }
+
+        // Connect q -> all n.
+        graph_store.set_links(q.clone(), neighbors, lc).await;
+    }
+
+    pub fn select_layer(&self, rng: &mut impl RngCore) -> usize {
+        let p_geom = 1f64 - self.params.get_layer_probability();
+        let geom_distr = Geometric::new(p_geom).unwrap();
+
+        geom_distr.sample(rng) as usize
+    }
+
+    /// Return a tuple containing a distance-sorted list initialized with the
+    /// entry point for the HNSW graph search (with distance to the query
+    /// pre-computed), and the number of search layers of the graph hierarchy,
+    /// that is, the layer of the entry point plus 1.
+    ///
+    /// If no entry point is initialized, returns an empty list and layer 0.
+    #[allow(non_snake_case)]
+    async fn search_init<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        vector_store: &mut V,
+        graph_store: &mut G,
+        query: &V::QueryRef,
+    ) -> (FurthestQueueV<V>, usize) {
+        if let Some((entry_point, layer)) = graph_store.get_entry_point().await {
+            let distance = vector_store.eval_distance(query, &entry_point).await;
+
+            let mut W = FurthestQueueV::<V>::new();
+            W.insert(vector_store, entry_point, distance).await;
+
+            (W, layer + 1)
+        } else {
+            (FurthestQueue::new(), 0)
+        }
+    }
+
+    /// Mutate `W` into the `ef` nearest neighbors of query vector `q` in the
+    /// given layer using depth-first graph traversal. Terminates when `W`
+    /// contains vectors which are the nearest to `q` among all traversed
+    /// vertices and their neighbors.
+    #[allow(non_snake_case)]
+    async fn search_layer<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        vector_store: &mut V,
+        graph_store: &mut G,
+        q: &V::QueryRef,
+        W: &mut FurthestQueueV<V>,
+        ef: usize,
+        lc: usize,
+    ) {
+        // v: The set of already visited vectors.
+        let mut v = HashSet::<V::VectorRef>::from_iter(W.iter().map(|(e, _eq)| e.clone()));
+
+        // C: The set of vectors to visit, ordered by increasing distance to the query.
+        let mut C = NearestQueue::from_furthest_queue(W);
+
+        // fq: The current furthest distance in W.
+        let (_, mut fq) = W.get_furthest().expect("W cannot be empty").clone();
+
+        while C.len() > 0 {
+            let (c, cq) = C.pop_nearest().expect("C cannot be empty").clone();
+
+            // If the nearest distance in C is greater than the furthest distance in W, then
+            // we can stop.
+            if vector_store.less_than(&fq, &cq).await {
+                break;
+            }
+
+            // Open the node c and explore its neighbors.
+
+            // Visit all neighbors of c.
+            let c_links = graph_store.get_links(&c, lc).await;
+
+            // Evaluate the distances of the neighbors to the query, as a batch.
+            let c_links = {
+                let e_batch = c_links
+                    .iter()
+                    .map(|(e, _ec)| e.clone())
+                    .filter(|e| {
+                        // Visit any node at most once.
+                        v.insert(e.clone())
+                    })
+                    .collect::<Vec<_>>();
+
+                let distances = vector_store.eval_distance_batch(q, &e_batch).await;
+
+                e_batch
+                    .into_iter()
+                    .zip(distances.into_iter())
+                    .collect::<Vec<_>>()
+            };
+
+            for (e, eq) in c_links.into_iter() {
+                if W.len() == ef {
+                    // When W is full, we decide whether to replace the furthest element.
+                    if vector_store.less_than(&eq, &fq).await {
+                        // Make room for the new better candidate…
+                        W.pop_furthest();
+                    } else {
+                        // …or ignore the candidate and do not continue on this path.
+                        continue;
+                    }
+                }
+
+                // Track the new candidate in C so we will continue this path later.
+                C.insert(vector_store, e.clone(), eq.clone()).await;
+
+                // Track the new candidate as a potential k-nearest.
+                W.insert(vector_store, e, eq).await;
+
+                // fq stays the furthest distance in W.
+                (_, fq) = W.get_furthest().expect("W cannot be empty").clone();
+            }
+        }
+    }
+
+    #[allow(non_snake_case)]
+    pub async fn search<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        vector_store: &mut V,
+        graph_store: &mut G,
+        query: &V::QueryRef,
+        k: usize,
+    ) -> FurthestQueueV<V> {
+        let (mut W, layer_count) = self.search_init(vector_store, graph_store, query).await;
+
+        // Search from the top layer down to layer 0
+        for lc in (0..layer_count).rev() {
+            let ef = self.params.get_ef_search(lc);
+            self.search_layer(vector_store, graph_store, query, &mut W, ef, lc)
+                .await;
+        }
+
+        W.trim_to_k_nearest(k);
+        W
+    }
+
+    /// Insert `query` into the HNSW index represented by `vector_store` and
+    /// `graph_store`. Return a `V::VectorRef` representing the inserted
+    /// vector.
+    pub async fn insert<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        vector_store: &mut V,
+        graph_store: &mut G,
+        query: &V::QueryRef,
+        rng: &mut impl RngCore,
+    ) -> V::VectorRef {
+        let insertion_layer = self.select_layer(rng);
+        let (neighbors, set_ep) = self
+            .search_to_insert(vector_store, graph_store, query, insertion_layer)
+            .await;
+        let inserted = vector_store.insert(query).await;
+        self.insert_from_search_results(
+            vector_store,
+            graph_store,
+            inserted.clone(),
+            neighbors,
+            set_ep,
+        )
+        .await;
+        inserted
+    }
+
+    /// Conduct the search phase of HNSW insertion of `query` into the graph at
+    /// a specified insertion layer. Layer search uses the "search" type
+    /// `ef_constr` parameter(s) for layers above the insertion layer (1 in
+    /// standard HNSW), and the "insertion" type `ef_constr` parameter(s) for
+    /// layers at or below the insertion layer (a single fixed `ef_constr`
+    /// parameter in standard HNSW).
+    ///
+    /// The output is a vector of the nearest neighbors found in each insertion
+    /// layer, and a boolean indicating if the insertion sets the entry point.
+    /// Nearest neighbors are provided in the output for each layer in which
+    /// the query is to be inserted, including empty neighbor lists for
+    /// insertion in any layers higher than the current entry point.
+    ///
+    /// If no entry point is initialized for the index, then the insertion will
+    /// set `query` as the index entry point.
+    #[allow(non_snake_case)]
+    pub async fn search_to_insert<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        vector_store: &mut V,
+        graph_store: &mut G,
+        query: &V::QueryRef,
+        insertion_layer: usize,
+    ) -> (Vec<FurthestQueueV<V>>, bool) {
+        let mut links = vec![];
+
+        let (mut W, n_layers) = self.search_init(vector_store, graph_store, query).await;
+
+        // Search from the top layer down to layer 0
+        for lc in (0..n_layers).rev() {
+            let ef = if lc > insertion_layer {
+                self.params.get_ef_constr_search(lc)
+            } else {
+                self.params.get_ef_constr_insert(lc)
+            };
+            self.search_layer(vector_store, graph_store, query, &mut W, ef, lc)
+                .await;
+
+            // Save links in output only for layers in which query is inserted
+            if lc <= insertion_layer {
+                links.push(W.clone());
+            }
+        }
+
+        // We inserted top-down, so reverse to match the layer indices (bottom=0)
+        links.reverse();
+
+        // If query is to be inserted at a new highest layer as a new entry
+        // point, insert additional empty neighborhoods for any new layers
+        let set_ep = insertion_layer + 1 > n_layers;
+        for _ in links.len()..insertion_layer + 1 {
+            links.push(FurthestQueue::new());
+        }
+        debug_assert!(links.len() == insertion_layer + 1);
+
+        (links, set_ep)
+    }
+
+    /// Insert a vector using the search results from `search_to_insert`,
+    /// that is, the nearest neighbor links at each insertion layer, and a flag
+    /// indicating whether the vector is to be inserted as the new entry point.
+    pub async fn insert_from_search_results<V: VectorStore, G: GraphStore<V>>(
+        &self,
+        vector_store: &mut V,
+        graph_store: &mut G,
+        inserted_vector: V::VectorRef,
+        links: Vec<FurthestQueueV<V>>,
+        set_ep: bool,
+    ) {
+        // If required, set vector as new entry point
+        if set_ep {
+            let insertion_layer = links.len() - 1;
+            graph_store
+                .set_entry_point(inserted_vector.clone(), insertion_layer)
+                .await;
+        }
+
+        // Connect the new vector to its neighbors in each layer.
+        for (lc, layer_links) in links.into_iter().enumerate().rev() {
+            self.connect_bidir(vector_store, graph_store, &inserted_vector, layer_links, lc)
+                .await;
+        }
+    }
+
+    pub async fn is_match<V: VectorStore>(
+        &self,
+        vector_store: &mut V,
+        neighbors: &[FurthestQueueV<V>],
+    ) -> bool {
+        match neighbors
+            .first()
+            .and_then(|bottom_layer| bottom_layer.get_nearest())
+        {
+            None => false, // Empty database.
+            Some((_, smallest_distance)) => vector_store.is_match(smallest_distance).await,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use aes_prng::AesRng;
+    use hawk_pack::{
+        graph_store::graph_mem::GraphMem, vector_store::lazy_memory_store::LazyMemoryStore,
+    };
+    use rand::SeedableRng;
+    use tokio;
+
+    #[tokio::test]
+    async fn test_hnsw_db() {
+        let vector_store = &mut LazyMemoryStore::new();
+        let graph_store = &mut GraphMem::new();
+        let rng = &mut AesRng::seed_from_u64(0_u64);
+        let db = HnswSearcher::default();
+
+        let queries1 = (0..100)
+            .map(|raw_query| vector_store.prepare_query(raw_query))
+            .collect::<Vec<_>>();
+
+        // Insert the codes.
+        for query in queries1.iter() {
+            let insertion_layer = db.select_layer(rng);
+            let (neighbors, set_ep) = db
+                .search_to_insert(vector_store, graph_store, query, insertion_layer)
+                .await;
+            assert!(!db.is_match(vector_store, &neighbors).await);
+            // Insert the new vector into the store.
+            let inserted = vector_store.insert(query).await;
+            db.insert_from_search_results(vector_store, graph_store, inserted, neighbors, set_ep)
+                .await;
+        }
+
+        let queries2 = (101..200)
+            .map(|raw_query| vector_store.prepare_query(raw_query))
+            .collect::<Vec<_>>();
+
+        // Insert the codes with helper function
+        for query in queries2.iter() {
+            db.insert(vector_store, graph_store, query, rng).await;
+        }
+
+        // Search for the same codes and find matches.
+        for query in queries1.iter().chain(queries2.iter()) {
+            let neighbors = db.search(vector_store, graph_store, query, 1).await;
+            assert!(db.is_match(vector_store, &[neighbors]).await);
+        }
+    }
+}
diff --git a/iris-mpc-cpu/src/lib.rs b/iris-mpc-cpu/src/lib.rs
index bf4a96011..94d4eaf6b 100644
--- a/iris-mpc-cpu/src/lib.rs
+++ b/iris-mpc-cpu/src/lib.rs
@@ -1,6 +1,7 @@
 pub mod database_generators;
 pub mod execution;
 pub mod hawkers;
+pub mod hnsw;
 pub(crate) mod network;
 #[rustfmt::skip]
 pub(crate) mod proto_generated;
diff --git a/iris-mpc-cpu/src/network/grpc.rs b/iris-mpc-cpu/src/network/grpc.rs
index 5e67542ff..1e1a0542a 100644
--- a/iris-mpc-cpu/src/network/grpc.rs
+++ b/iris-mpc-cpu/src/network/grpc.rs
@@ -341,9 +341,9 @@ mod tests {
     use crate::{
         execution::{local::generate_local_identities, player::Role},
         hawkers::aby3_store::Aby3Store,
+        hnsw::HnswSearcher,
     };
     use aes_prng::AesRng;
-    use hawk_pack::HawkSearcher;
     use rand::SeedableRng;
     use tokio::task::JoinSet;
     use tracing_test::traced_test;
@@ -570,7 +570,7 @@ mod tests {
     async fn test_hnsw_local() {
         let mut rng = AesRng::seed_from_u64(0_u64);
         let database_size = 2;
-        let searcher = HawkSearcher::default();
+        let searcher = HnswSearcher::default();
         let mut vectors_and_graphs = Aby3Store::shared_random_setup(
             &mut rng,
             database_size,
diff --git a/iris-mpc-cpu/src/py_bindings/hnsw.rs b/iris-mpc-cpu/src/py_bindings/hnsw.rs
index 471de784a..cf22767a3 100644
--- a/iris-mpc-cpu/src/py_bindings/hnsw.rs
+++ b/iris-mpc-cpu/src/py_bindings/hnsw.rs
@@ -1,6 +1,9 @@
 use super::plaintext_store::Base64IrisCode;
-use crate::hawkers::plaintext_store::{PlaintextStore, PointId};
-use hawk_pack::{graph_store::GraphMem, HawkSearcher};
+use crate::{
+    hawkers::plaintext_store::{PlaintextStore, PointId},
+    hnsw::HnswSearcher,
+};
+use hawk_pack::graph_store::GraphMem;
 use iris_mpc_common::iris_db::iris::IrisCode;
 use rand::rngs::ThreadRng;
 use serde_json::{self, Deserializer};
@@ -8,7 +11,7 @@ use std::{fs::File, io::BufReader};
 
 pub fn search(
     query: IrisCode,
-    searcher: &HawkSearcher,
+    searcher: &HnswSearcher,
     vector: &mut PlaintextStore,
     graph: &mut GraphMem<PlaintextStore>,
 ) -> (PointId, f64) {
@@ -28,7 +31,7 @@ pub fn search(
 // TODO could instead take iterator of IrisCodes to make more flexible
 pub fn insert(
     iris: IrisCode,
-    searcher: &HawkSearcher,
+    searcher: &HnswSearcher,
     vector: &mut PlaintextStore,
     graph: &mut GraphMem<PlaintextStore>,
 ) -> PointId {
@@ -46,7 +49,7 @@ pub fn insert(
 }
 
 pub fn insert_uniform_random(
-    searcher: &HawkSearcher,
+    searcher: &HnswSearcher,
     vector: &mut PlaintextStore,
     graph: &mut GraphMem<PlaintextStore>,
 ) -> PointId {
@@ -58,7 +61,7 @@ pub fn insert_uniform_random(
 
 pub fn fill_uniform_random(
     num: usize,
-    searcher: &HawkSearcher,
+    searcher: &HnswSearcher,
     vector: &mut PlaintextStore,
     graph: &mut GraphMem<PlaintextStore>,
 ) {
@@ -84,7 +87,7 @@ pub fn fill_uniform_random(
 pub fn fill_from_ndjson_file(
     filename: &str,
     limit: Option<usize>,
-    searcher: &HawkSearcher,
+    searcher: &HnswSearcher,
     vector: &mut PlaintextStore,
     graph: &mut GraphMem<PlaintextStore>,
 ) {
diff --git a/iris-mpc-py/README.md b/iris-mpc-py/README.md
index aea78cecb..5cd591589 100644
--- a/iris-mpc-py/README.md
+++ b/iris-mpc-py/README.md
@@ -21,9 +21,9 @@ See the [Maturin User Guide Tutorial](https://www.maturin.rs/tutorial#build-and-
 
 Once successfully installed, the native rust module `iris_mpc_py` can be imported in your Python environment as usual with `import iris_mpc_py`. 
Example usage: ```python -from iris_mpc_py import PyHawkSearcher, PyPlaintextStore, PyGraphStore, PyIrisCode +from iris_mpc_py import PyHnswSearcher, PyPlaintextStore, PyGraphStore, PyIrisCode -hnsw = PyHawkSearcher(32, 64, 32) # M, ef_constr, ef_search +hnsw = PyHnswSearcher(32, 64, 32) # M, ef_constr, ef_search vector = PyPlaintextStore() graph = PyGraphStore() @@ -45,7 +45,7 @@ hnsw.write_to_json("searcher.json") vector.write_to_ndjson("vector.ndjson") graph.write_to_bin("graph.dat") -hnsw2 = PyHawkSearcher.read_from_json("searcher.json") +hnsw2 = PyHnswSearcher.read_from_json("searcher.json") vector2 = PyPlaintextStore.read_from_ndjson("vector.ndjson") graph2 = PyGraphStore.read_from_bin("graph.dat") ``` @@ -61,7 +61,7 @@ graph = PyGraphStore.read_from_bin("graph.dat") Second, to construct an HNSW index dynamically from streamed database entries: ```python -hnsw = PyHawkSearcher(32, 64, 32) +hnsw = PyHnswSearcher(32, 64, 32) vector = PyPlaintextStore() graph = PyGraphStore() hnsw.fill_from_ndjson_file("large_vector_database.ndjson", vector, graph, 10000) diff --git a/iris-mpc-py/examples-py/.gitignore b/iris-mpc-py/examples-py/.gitignore new file mode 100644 index 000000000..6320cd248 --- /dev/null +++ b/iris-mpc-py/examples-py/.gitignore @@ -0,0 +1 @@ +data \ No newline at end of file diff --git a/iris-mpc-py/examples-py/test_integration.py b/iris-mpc-py/examples-py/test_integration.py index d22bad8ee..e89757f55 100644 --- a/iris-mpc-py/examples-py/test_integration.py +++ b/iris-mpc-py/examples-py/test_integration.py @@ -1,4 +1,10 @@ -from iris_mpc_py import PyIrisCode, PyPlaintextStore, PyGraphStore, PyHawkSearcher +import os +from iris_mpc_py import PyIrisCode, PyPlaintextStore, PyGraphStore, PyHnswSearcher + +vector_filename = "./data/vector.ndjson" +graph_filename = "./data/graph1.dat" +if not os.path.exists("./data/"): + os.makedirs("./data/") print("Generating 100k uniform random iris codes...") vector_init = PyPlaintextStore() @@ -9,13 +15,13 @@ # write vector store to file print("Writing vector store to file...") -vector_init.write_to_ndjson("vector.ndjson") +vector_init.write_to_ndjson(vector_filename) print("Generating HNSW graphs for 10k imported iris codes...") -hnsw = PyHawkSearcher(32, 64, 32) +hnsw = PyHnswSearcher(32, 64, 32) vector1 = PyPlaintextStore() graph1 = PyGraphStore() -hnsw.fill_from_ndjson_file("vector.ndjson", vector1, graph1, 10000) +hnsw.fill_from_ndjson_file(vector_filename, vector1, graph1, 10000) print("Imported length:", vector1.len()) @@ -27,11 +33,11 @@ # write graph store to file print("Writing graph store to file...") -graph1.write_to_bin("graph1.dat") +graph1.write_to_bin(graph_filename) # read HNSW graphs from disk print("Reading vector and graph stores from file...") -vector2 = PyPlaintextStore.read_from_ndjson("vector.ndjson", 10000) -graph2 = PyGraphStore.read_from_bin("graph1.dat") +vector2 = PyPlaintextStore.read_from_ndjson(vector_filename, 10000) +graph2 = PyGraphStore.read_from_bin(graph_filename) print("Search for random query iris code:", hnsw.search(query, vector2, graph2)) diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs index 1d154346a..049b61b1f 100644 --- a/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs +++ b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs @@ -1,18 +1,17 @@ use super::{graph_store::PyGraphStore, iris_code::PyIrisCode, plaintext_store::PyPlaintextStore}; -use hawk_pack::{ - hawk_searcher::{HawkerParams, N_PARAM_LAYERS}, - HawkSearcher, +use 
iris_mpc_cpu::{
+    hnsw::searcher::{HnswParams, HnswSearcher, N_PARAM_LAYERS},
+    py_bindings,
 };
-use iris_mpc_cpu::py_bindings;
 use pyo3::{exceptions::PyIOError, prelude::*};
 
 #[pyclass]
 #[derive(Clone, Default)]
-pub struct PyHawkSearcher(pub HawkSearcher);
+pub struct PyHnswSearcher(pub HnswSearcher);
 
 #[pymethods]
 #[allow(non_snake_case)]
-impl PyHawkSearcher {
+impl PyHnswSearcher {
     #[new]
     pub fn new(M: usize, ef_constr: usize, ef_search: usize) -> Self {
         Self::new_standard(ef_constr, ef_search, M)
     }
@@ -20,17 +19,17 @@ impl PyHawkSearcher {
     #[staticmethod]
     pub fn new_standard(M: usize, ef_constr: usize, ef_search: usize) -> Self {
-        let params = HawkerParams::new(ef_constr, ef_search, M);
-        Self(HawkSearcher { params })
+        let params = HnswParams::new(ef_constr, ef_search, M);
+        Self(HnswSearcher { params })
     }
 
     #[staticmethod]
     pub fn new_uniform(M: usize, ef: usize) -> Self {
-        let params = HawkerParams::new_uniform(ef, M);
-        Self(HawkSearcher { params })
+        let params = HnswParams::new_uniform(ef, M);
+        Self(HnswSearcher { params })
     }
 
-    /// Construct `HawkSearcher` with fully general parameters, specifying the
+    /// Construct `HnswSearcher` with fully general parameters, specifying the
     /// values of various parameters used during construction and search at
     /// different levels of the graph hierarchy.
     #[staticmethod]
@@ -42,7 +41,7 @@ impl PyHawkSearcher {
         ef_search: [usize; N_PARAM_LAYERS],
         layer_probability: f64,
     ) -> Self {
-        let params = HawkerParams {
+        let params = HnswParams {
             M,
             M_max,
             ef_constr_search,
@@ -50,7 +49,7 @@ impl PyHawkSearcher {
             ef_search,
             layer_probability,
         };
-        Self(HawkSearcher { params })
+        Self(HnswSearcher { params })
     }
 
     pub fn insert(
diff --git a/iris-mpc-py/src/py_hnsw/pymodule.rs b/iris-mpc-py/src/py_hnsw/pymodule.rs
index b0ceae8e3..d93d641aa 100644
--- a/iris-mpc-py/src/py_hnsw/pymodule.rs
+++ b/iris-mpc-py/src/py_hnsw/pymodule.rs
@@ -1,5 +1,5 @@
 use super::pyclasses::{
-    graph_store::PyGraphStore, hawk_searcher::PyHawkSearcher, iris_code::PyIrisCode,
+    graph_store::PyGraphStore, hawk_searcher::PyHnswSearcher, iris_code::PyIrisCode,
     iris_code_array::PyIrisCodeArray, plaintext_store::PyPlaintextStore,
 };
 use pyo3::prelude::*;
@@ -10,6 +10,6 @@ fn iris_mpc_py(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PyIrisCode>()?;
     m.add_class::<PyPlaintextStore>()?;
     m.add_class::<PyGraphStore>()?;
-    m.add_class::<PyHawkSearcher>()?;
+    m.add_class::<PyHnswSearcher>()?;
     Ok(())
 }
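
The layer-assignment math documented in `searcher.rs` (`m_L = -1 / ln(p)`, with insertion layers drawn from a geometric distribution) can be checked standalone. Below is a minimal sketch against the same `rand`/`rand_distr` APIs the patch uses (`rand_distr = "0.4.3"` per the Cargo.toml hunk); the `main` harness and the seeded `StdRng` are illustrative, not part of the patch:

```rust
use rand::{rngs::StdRng, SeedableRng};
use rand_distr::{Distribution, Geometric};

fn main() {
    // Default params use M = 32, so layer_probability p = 1/M = 1/32.
    let layer_probability: f64 = (32f64).recip();

    // m_L = -1 / ln(p), as in HnswParams::m_L_from_layer_probability.
    let m_l = -layer_probability.ln().recip();
    // Inverse map, as in HnswParams::layer_probability_from_m_L: p = exp(-1/m_L).
    let round_trip = (-m_l.recip()).exp();
    assert!((round_trip - layer_probability).abs() < 1e-12);

    // select_layer samples Geometric(1 - p): the number of failures before the
    // first success, so P(layer >= k) = p^k and each layer is ~M times sparser.
    let geom = Geometric::new(1.0 - layer_probability).unwrap();
    let mut rng = StdRng::seed_from_u64(0);
    let layer = geom.sample(&mut rng);
    println!("m_L = {m_l:.4}, sampled insertion layer = {layer}");
}
```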
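Since `HnswSearcher::default()` carries a TODO for removal, downstream code may prefer to pin parameters explicitly. The following sketch constructs a searcher from the types this diff adds; the parameter values are example choices, and note that `HnswParams::new` takes `(ef_construction, ef_search, M)` while `PyHnswSearcher::new` takes `(M, ef_constr, ef_search)`:

```rust
use iris_mpc_cpu::hnsw::searcher::{HnswParams, HnswSearcher};

fn main() {
    // Argument order is (ef_construction, ef_search, M).
    let params = HnswParams::new(320, 256, 32);
    let searcher = HnswSearcher { params };

    // Layer 0 allows up to 2*M links; higher layers allow up to M.
    assert_eq!(searcher.params.get_M_max(0), 64);
    assert_eq!(searcher.params.get_M_max(1), 32);

    // Layers at N_PARAM_LAYERS and beyond reuse the last configured values.
    assert_eq!(searcher.params.get_ef_search(0), 256);
    assert_eq!(
        searcher.params.get_ef_search(10),
        searcher.params.get_ef_search(4)
    );
}
```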