From 8979037e4d8df5457a23b0e464fa2915ca1ccd45 Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Tue, 10 Dec 2024 15:21:17 -0700 Subject: [PATCH 1/7] Port local copy of HawkSearcher for application specific HNSW changes --- Cargo.lock | 1 + iris-mpc-cpu/Cargo.toml | 3 +- iris-mpc-cpu/benches/hnsw.rs | 15 +- iris-mpc-cpu/examples/hnsw-ex.rs | 6 +- iris-mpc-cpu/src/hawkers/galois_store.rs | 18 +- iris-mpc-cpu/src/hawkers/iris_searcher.rs | 487 ++++++++++++++++++ iris-mpc-cpu/src/hawkers/mod.rs | 1 + iris-mpc-cpu/src/hawkers/plaintext_store.rs | 9 +- iris-mpc-cpu/src/network/grpc.rs | 5 +- iris-mpc-cpu/src/py_bindings/hnsw.rs | 17 +- .../src/py_hnsw/pyclasses/hawk_searcher.rs | 23 +- 11 files changed, 541 insertions(+), 44 deletions(-) create mode 100644 iris-mpc-cpu/src/hawkers/iris_searcher.rs diff --git a/Cargo.lock b/Cargo.lock index fee4658d5..066be7d05 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2755,6 +2755,7 @@ dependencies = [ "num-traits", "prost", "rand", + "rand_distr", "rstest", "serde", "serde_json", diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml index 05292c77d..e7547ca84 100644 --- a/iris-mpc-cpu/Cargo.toml +++ b/iris-mpc-cpu/Cargo.toml @@ -37,6 +37,7 @@ tracing.workspace = true tracing-subscriber.workspace = true tracing-test = "0.2.5" uuid.workspace = true +rand_distr = "0.4.3" [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio"] } @@ -53,4 +54,4 @@ name = "hnsw-ex" [[bin]] name = "local_hnsw" -path = "bin/local_hnsw.rs" \ No newline at end of file +path = "bin/local_hnsw.rs" diff --git a/iris-mpc-cpu/benches/hnsw.rs b/iris-mpc-cpu/benches/hnsw.rs index 48cea2651..9d3534c8c 100644 --- a/iris-mpc-cpu/benches/hnsw.rs +++ b/iris-mpc-cpu/benches/hnsw.rs @@ -1,11 +1,14 @@ use aes_prng::AesRng; use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, SamplingMode}; -use hawk_pack::{graph_store::GraphMem, HawkSearcher}; +use hawk_pack::graph_store::GraphMem; use iris_mpc_common::iris_db::{db::IrisDB, iris::IrisCode}; use iris_mpc_cpu::{ database_generators::{create_random_sharing, generate_galois_iris_shares}, execution::local::LocalRuntime, - hawkers::{galois_store::LocalNetAby3NgStoreProtocol, plaintext_store::PlaintextStore}, + hawkers::{ + galois_store::LocalNetAby3NgStoreProtocol, iris_searcher::IrisSearcher, + plaintext_store::PlaintextStore, + }, protocol::ops::{ batch_signed_lift_vec, cross_compare, galois_ring_pairwise_distance, galois_ring_to_rep3, }, @@ -28,7 +31,7 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { let mut rng = AesRng::seed_from_u64(0_u64); let mut vector = PlaintextStore::default(); let mut graph = GraphMem::new(); - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); for _ in 0..database_size { let raw_query = IrisCode::random_rng(&mut rng); @@ -44,7 +47,7 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { b.to_async(&rt).iter_batched( || (vector.clone(), graph.clone()), |(mut db_vectors, mut graph)| async move { - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let query = db_vectors.prepare_query(on_the_fly_query); @@ -185,7 +188,7 @@ fn bench_gr_ready_made_hnsw(c: &mut Criterion) { b.to_async(&rt).iter_batched( || secret_searcher.clone(), |vectors_graphs| async move { - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let raw_query = generate_galois_iris_shares(&mut rng, on_the_fly_query); @@ -219,7 +222,7 @@ fn bench_gr_ready_made_hnsw(c: &mut Criterion) { b.to_async(&rt).iter_batched( || secret_searcher.clone(), |vectors_graphs| async move { - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let raw_query = generate_galois_iris_shares(&mut rng, on_the_fly_query); diff --git a/iris-mpc-cpu/examples/hnsw-ex.rs b/iris-mpc-cpu/examples/hnsw-ex.rs index 71c925028..debc53aa8 100644 --- a/iris-mpc-cpu/examples/hnsw-ex.rs +++ b/iris-mpc-cpu/examples/hnsw-ex.rs @@ -1,7 +1,7 @@ use aes_prng::AesRng; -use hawk_pack::{graph_store::GraphMem, HawkSearcher}; +use hawk_pack::graph_store::GraphMem; use iris_mpc_common::iris_db::iris::IrisCode; -use iris_mpc_cpu::hawkers::plaintext_store::PlaintextStore; +use iris_mpc_cpu::hawkers::{iris_searcher::IrisSearcher, plaintext_store::PlaintextStore}; use rand::SeedableRng; const DATABASE_SIZE: usize = 1_000; @@ -16,7 +16,7 @@ fn main() { let mut rng = AesRng::seed_from_u64(0_u64); let mut vector = PlaintextStore::default(); let mut graph = GraphMem::new(); - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); for idx in 0..DATABASE_SIZE { let raw_query = IrisCode::random_rng(&mut rng); diff --git a/iris-mpc-cpu/src/hawkers/galois_store.rs b/iris-mpc-cpu/src/hawkers/galois_store.rs index 520570821..385149a07 100644 --- a/iris-mpc-cpu/src/hawkers/galois_store.rs +++ b/iris-mpc-cpu/src/hawkers/galois_store.rs @@ -1,4 +1,4 @@ -use super::plaintext_store::PlaintextStore; +use super::{iris_searcher::IrisSearcher, plaintext_store::PlaintextStore}; use crate::{ database_generators::{generate_galois_iris_shares, GaloisRingSharedIris}, execution::{ @@ -21,7 +21,7 @@ use aes_prng::AesRng; use hawk_pack::{ data_structures::queue::FurthestQueue, graph_store::{graph_mem::Layer, GraphMem}, - GraphStore, HawkSearcher, VectorStore, + GraphStore, VectorStore, }; use iris_mpc_common::iris_db::{db::IrisDB, iris::IrisCode}; use rand::{CryptoRng, RngCore, SeedableRng}; @@ -470,7 +470,7 @@ impl LocalNetAby3NgStoreProtocol { .collect::>(); jobs.spawn(async move { let mut graph_store = GraphMem::new(); - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); // insert queries for query in queries.iter() { searcher @@ -510,9 +510,11 @@ impl LocalNetAby3NgStoreProtocol { #[cfg(test)] mod tests { use super::*; - use crate::database_generators::generate_galois_iris_shares; + use crate::{ + database_generators::generate_galois_iris_shares, hawkers::iris_searcher::IrisSearcher, + }; use aes_prng::AesRng; - use hawk_pack::{graph_store::GraphMem, HawkSearcher}; + use hawk_pack::graph_store::GraphMem; use itertools::Itertools; use rand::SeedableRng; use tracing_test::traced_test; @@ -541,7 +543,7 @@ mod tests { let mut rng = rng.clone(); jobs.spawn(async move { let mut aby3_graph = GraphMem::new(); - let db = HawkSearcher::default(); + let db = IrisSearcher::default(); let mut inserted = vec![]; // insert queries @@ -601,7 +603,7 @@ mod tests { { assert_eq!(v_from_scratch.storage.points, premade_v.storage.points); } - let hawk_searcher = HawkSearcher::default(); + let hawk_searcher = IrisSearcher::default(); for i in 0..database_size { let cleartext_neighbors = hawk_searcher @@ -748,7 +750,7 @@ mod tests { async fn test_gr_scratch_hnsw() { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 2; - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); let mut vectors_and_graphs = LocalNetAby3NgStoreProtocol::shared_random_setup( &mut rng, database_size, diff --git a/iris-mpc-cpu/src/hawkers/iris_searcher.rs b/iris-mpc-cpu/src/hawkers/iris_searcher.rs new file mode 100644 index 000000000..9dfea67ae --- /dev/null +++ b/iris-mpc-cpu/src/hawkers/iris_searcher.rs @@ -0,0 +1,487 @@ +//* Implementation of HNSW algorithm for k-nearest-neighbor search over iris +//* biometric templates with high-latency MPC comparison operations. Based on +//* the `HawkSearcher` class of the hawk-pack crate: +//* +//* https://github.com/Inversed-Tech/hawk-pack/ + +pub use hawk_pack::data_structures::queue::{ + FurthestQueue, FurthestQueueV, NearestQueue, NearestQueueV, +}; +use hawk_pack::{GraphStore, VectorStore}; +use rand::RngCore; +use rand_distr::{Distribution, Geometric}; +use serde::{Deserialize, Serialize}; +use std::collections::HashSet; + +// specify construction and search parameters by layer up to this value minus 1 +// any higher layers will use the last set of parameters +pub const N_PARAM_LAYERS: usize = 5; + +#[allow(non_snake_case)] +#[derive(PartialEq, Clone, Serialize, Deserialize)] +pub struct HnswParams { + pub M: [usize; N_PARAM_LAYERS], // number of neighbors for insertion + pub M_max: [usize; N_PARAM_LAYERS], // maximum number of neighbors + pub ef_constr_search: [usize; N_PARAM_LAYERS], // ef_constr for search layers + pub ef_constr_insert: [usize; N_PARAM_LAYERS], // ef_constr for insertion layers + pub ef_search: [usize; N_PARAM_LAYERS], // ef for search + pub layer_probability: f64, /* p for geometric distribution of layer + * densities */ +} + +#[allow(non_snake_case, clippy::too_many_arguments)] +impl HnswParams { + /// Construct a `Params` object corresponding to parameter configuration + /// providing the functionality described in the original HNSW paper: + /// - ef_construction exploration factor used for insertion layers + /// - ef_search exploration factor used for layer 0 in search + /// - higher layers in both insertion and search use exploration factor 1, + /// representing simple greedy search + /// - vertex degrees bounded by M_max = M in positive layer graphs + /// - vertex degrees bounded by M_max0 = 2*M in layer 0 graph + /// - m_L = 1 / ln(M) so that layer density decreases by a factor of M at + /// each successive hierarchical layer + pub fn new(ef_construction: usize, ef_search: usize, M: usize) -> Self { + let M_arr = [M; N_PARAM_LAYERS]; + let mut M_max_arr = [M; N_PARAM_LAYERS]; + M_max_arr[0] = 2 * M; + let ef_constr_search_arr = [1usize; N_PARAM_LAYERS]; + let ef_constr_insert_arr = [ef_construction; N_PARAM_LAYERS]; + let mut ef_search_arr = [1usize; N_PARAM_LAYERS]; + ef_search_arr[0] = ef_search; + let layer_probability = (M as f64).recip(); + + Self { + M: M_arr, + M_max: M_max_arr, + ef_constr_search: ef_constr_search_arr, + ef_constr_insert: ef_constr_insert_arr, + ef_search: ef_search_arr, + layer_probability, + } + } + + /// Parameter configuration using fixed exploration factor for all layer + /// search operations, both for insertion and for search. + pub fn new_uniform(ef: usize, M: usize) -> Self { + let M_arr = [M; N_PARAM_LAYERS]; + let mut M_max_arr = [M; N_PARAM_LAYERS]; + M_max_arr[0] = 2 * M; + let ef_constr_search_arr = [ef; N_PARAM_LAYERS]; + let ef_constr_insert_arr = [ef; N_PARAM_LAYERS]; + let ef_search_arr = [ef; N_PARAM_LAYERS]; + let layer_probability = (M as f64).recip(); + + Self { + M: M_arr, + M_max: M_max_arr, + ef_constr_search: ef_constr_search_arr, + ef_constr_insert: ef_constr_insert_arr, + ef_search: ef_search_arr, + layer_probability, + } + } + + /// Compute the parameter m_L associated with a geometric distribution + /// parameter q describing the random layer of newly inserted graph nodes. + /// + /// E.g. for graph hierarchy where each layer has a factor of 32 fewer + /// entries than the last, the `layer_probability` input is 1/32. + pub fn m_L_from_layer_probability(layer_probability: f64) -> f64 { + -layer_probability.ln().recip() + } + + /// Compute the parameter q for the geometric distribution used to select + /// the insertion layer for newly inserted graph nodes, from the parameter + /// m_L of the original HNSW paper. + pub fn layer_probability_from_m_L(m_L: f64) -> f64 { + (-m_L.recip()).exp() + } + + pub fn get_M(&self, lc: usize) -> usize { + Self::get_val(&self.M, lc) + } + + pub fn get_M_max(&self, lc: usize) -> usize { + Self::get_val(&self.M_max, lc) + } + + pub fn get_ef_constr_search(&self, lc: usize) -> usize { + Self::get_val(&self.ef_constr_search, lc) + } + + pub fn get_ef_constr_insert(&self, lc: usize) -> usize { + Self::get_val(&self.ef_constr_insert, lc) + } + + pub fn get_ef_search(&self, lc: usize) -> usize { + Self::get_val(&self.ef_search, lc) + } + + pub fn get_layer_probability(&self) -> f64 { + self.layer_probability + } + + pub fn get_m_L(&self) -> f64 { + Self::m_L_from_layer_probability(self.layer_probability) + } + + #[inline(always)] + /// Select value at index `lc` from the input fixed-size array, or the last + /// index of this array if `lc` is larger than the array size. + fn get_val(arr: &[usize; N_PARAM_LAYERS], lc: usize) -> usize { + arr[lc.min(N_PARAM_LAYERS - 1)] + } +} + +/// An implementation of the HNSW algorithm. +/// +/// Operations on vectors are delegated to a VectorStore. +/// Operations on the graph are delegate to a GraphStore. +#[derive(Clone, Serialize, Deserialize)] +pub struct IrisSearcher { + pub params: HnswParams, +} + +// TODO remove default value; this varies too much between applications +// to make sense to specify something "obvious" +impl Default for IrisSearcher { + fn default() -> Self { + IrisSearcher { + params: HnswParams::new(64, 32, 32), + } + } +} + +#[allow(non_snake_case)] +impl IrisSearcher { + async fn connect_bidir>( + &self, + vector_store: &mut V, + graph_store: &mut G, + q: &V::VectorRef, + mut neighbors: FurthestQueueV, + lc: usize, + ) { + let M = self.params.get_M(lc); + let max_links = self.params.get_M_max(lc); + + neighbors.trim_to_k_nearest(M); + + // Connect all n -> q. + for (n, nq) in neighbors.iter() { + let mut links = graph_store.get_links(n, lc).await; + links.insert(vector_store, q.clone(), nq.clone()).await; + links.trim_to_k_nearest(max_links); + graph_store.set_links(n.clone(), links, lc).await; + } + + // Connect q -> all n. + graph_store.set_links(q.clone(), neighbors, lc).await; + } + + pub fn select_layer(&self, rng: &mut impl RngCore) -> usize { + let p_geom = 1f64 - self.params.get_layer_probability(); + let geom_distr = Geometric::new(p_geom).unwrap(); + + geom_distr.sample(rng) as usize + } + + /// Return a tuple containing a distance-sorted list initialized with the + /// entry point for the HNSW graph search (with distance to the query + /// pre-computed), and the number of search layers of the graph hierarchy, + /// that is, the layer of the entry point plus 1. + /// + /// If no entry point is initialized, returns an empty list and layer 0. + #[allow(non_snake_case)] + async fn search_init>( + &self, + vector_store: &mut V, + graph_store: &mut G, + query: &V::QueryRef, + ) -> (FurthestQueueV, usize) { + if let Some((entry_point, layer)) = graph_store.get_entry_point().await { + let distance = vector_store.eval_distance(query, &entry_point).await; + + let mut W = FurthestQueueV::::new(); + W.insert(vector_store, entry_point, distance).await; + + (W, layer + 1) + } else { + (FurthestQueue::new(), 0) + } + } + + /// Mutate `W` into the `ef` nearest neighbors of query vector `q` in the + /// given layer using depth-first graph traversal, Terminates when `W` + /// contains vectors which are the nearest to `q` among all traversed + /// vertices and their neighbors. + #[allow(non_snake_case)] + async fn search_layer>( + &self, + vector_store: &mut V, + graph_store: &mut G, + q: &V::QueryRef, + W: &mut FurthestQueueV, + ef: usize, + lc: usize, + ) { + // v: The set of already visited vectors. + let mut v = HashSet::::from_iter(W.iter().map(|(e, _eq)| e.clone())); + + // C: The set of vectors to visit, ordered by increasing distance to the query. + let mut C = NearestQueue::from_furthest_queue(W); + + // fq: The current furthest distance in W. + let (_, mut fq) = W.get_furthest().expect("W cannot be empty").clone(); + + while C.len() > 0 { + let (c, cq) = C.pop_nearest().expect("C cannot be empty").clone(); + + // If the nearest distance to C is greater than the furthest distance in W, then + // we can stop. + if vector_store.less_than(&fq, &cq).await { + break; + } + + // Visit all neighbors of c. + let c_links = graph_store.get_links(&c, lc).await; + + // Evaluate the distances of the neighbors to the query, as a batch. + let c_links = { + let e_batch = c_links + .iter() + .map(|(e, _ec)| e.clone()) + .filter(|e| { + // Visit any node at most once. + v.insert(e.clone()) + }) + .collect::>(); + + let distances = vector_store.eval_distance_batch(q, &e_batch).await; + + e_batch + .into_iter() + .zip(distances.into_iter()) + .collect::>() + }; + + for (e, eq) in c_links.into_iter() { + if W.len() == ef { + // When W is full, we decide whether to replace the furthest element. + if vector_store.less_than(&eq, &fq).await { + // Make room for the new better candidate… + W.pop_furthest(); + } else { + // …or ignore the candidate and do not continue on this path. + continue; + } + } + + // Track the new candidate in C so we will continue this path later. + C.insert(vector_store, e.clone(), eq.clone()).await; + + // Track the new candidate as a potential k-nearest. + W.insert(vector_store, e, eq).await; + + // fq stays the furthest distance in W. + (_, fq) = W.get_furthest().expect("W cannot be empty").clone(); + } + } + } + + #[allow(non_snake_case)] + pub async fn search>( + &self, + vector_store: &mut V, + graph_store: &mut G, + query: &V::QueryRef, + k: usize, + ) -> FurthestQueueV { + let (mut W, layer_count) = self.search_init(vector_store, graph_store, query).await; + + // Search from the top layer down to layer 0 + for lc in (0..layer_count).rev() { + let ef = self.params.get_ef_search(lc); + self.search_layer(vector_store, graph_store, query, &mut W, ef, lc) + .await; + } + + W.trim_to_k_nearest(k); + W + } + + /// Insert `query` into HNSW index represented by `vector_store` and + /// `graph_store`. Return a `V::VectorRef` representing the inserted + /// vector. + pub async fn insert>( + &self, + vector_store: &mut V, + graph_store: &mut G, + query: &V::QueryRef, + rng: &mut impl RngCore, + ) -> V::VectorRef { + let insertion_layer = self.select_layer(rng); + let (neighbors, set_ep) = self + .search_to_insert(vector_store, graph_store, query, insertion_layer) + .await; + let inserted = vector_store.insert(query).await; + self.insert_from_search_results( + vector_store, + graph_store, + inserted.clone(), + neighbors, + set_ep, + ) + .await; + inserted + } + + /// Conduct the search phase of HNSW insertion of `query` into the graph at + /// a specified insertion layer. Layer search uses the "search" type + /// `ef_constr` parameter(s) for layers above the insertion layer (1 in + /// standard HNSW), and the "insertion" type `ef_constr` parameter(s) for + /// layers below the insertion layer (a single fixed `ef_constr` parameter + /// in standard HNSW). + /// + /// The output is a vector of the nearest neighbors found in each insertion + /// layer, and a boolean indicating if the insertion sets the entry point. + /// Nearest neighbors are provided in the output for each layer in which + /// the query is to be inserted, including empty neighbor lists for + /// insertion in any layers higher than the current entry point. + /// + /// If no entry point is initialized for the index, then the insertion will + /// set `query` as the index entry point. + #[allow(non_snake_case)] + pub async fn search_to_insert>( + &self, + vector_store: &mut V, + graph_store: &mut G, + query: &V::QueryRef, + insertion_layer: usize, + ) -> (Vec>, bool) { + let mut links = vec![]; + + let (mut W, n_layers) = self.search_init(vector_store, graph_store, query).await; + + // Search from the top layer down to layer 0 + for lc in (0..n_layers).rev() { + let ef = if lc > insertion_layer { + self.params.get_ef_constr_search(lc) + } else { + self.params.get_ef_constr_insert(lc) + }; + self.search_layer(vector_store, graph_store, query, &mut W, ef, lc) + .await; + + // Save links in output only for layers in which query is inserted + if lc <= insertion_layer { + links.push(W.clone()); + } + } + + // We inserted top-down, so reverse to match the layer indices (bottom=0) + links.reverse(); + + // If query is to be inserted at a new highest layer as a new entry + // point, insert additional empty neighborhoods for any new layers + let set_ep = insertion_layer + 1 > n_layers; + for _ in links.len()..insertion_layer + 1 { + links.push(FurthestQueue::new()); + } + debug_assert!(links.len() == insertion_layer + 1); + + (links, set_ep) + } + + /// Insert a vector using the search results from `search_to_insert`, + /// that is the nearest neighbor links at each insertion layer, and a flag + /// indicating whether the vector is to be inserted as the new entry point. + pub async fn insert_from_search_results>( + &self, + vector_store: &mut V, + graph_store: &mut G, + inserted_vector: V::VectorRef, + links: Vec>, + set_ep: bool, + ) { + // If required, set vector as new entry point + if set_ep { + let insertion_layer = links.len() - 1; + graph_store + .set_entry_point(inserted_vector.clone(), insertion_layer) + .await; + } + + // Connect the new vector to its neighbors in each layer. + for (lc, layer_links) in links.into_iter().enumerate().rev() { + self.connect_bidir(vector_store, graph_store, &inserted_vector, layer_links, lc) + .await; + } + } + + pub async fn is_match( + &self, + vector_store: &mut V, + neighbors: &[FurthestQueueV], + ) -> bool { + match neighbors + .first() + .and_then(|bottom_layer| bottom_layer.get_nearest()) + { + None => false, // Empty database. + Some((_, smallest_distance)) => vector_store.is_match(smallest_distance).await, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use aes_prng::AesRng; + use hawk_pack::{ + graph_store::graph_mem::GraphMem, vector_store::lazy_memory_store::LazyMemoryStore, + }; + use rand::SeedableRng; + use tokio; + + #[tokio::test] + async fn test_hnsw_db() { + let vector_store = &mut LazyMemoryStore::new(); + let graph_store = &mut GraphMem::new(); + let rng = &mut AesRng::seed_from_u64(0_u64); + let db = IrisSearcher::default(); + + let queries1 = (0..100) + .map(|raw_query| vector_store.prepare_query(raw_query)) + .collect::>(); + + // Insert the codes. + for query in queries1.iter() { + let insertion_layer = db.select_layer(rng); + let (neighbors, set_ep) = db + .search_to_insert(vector_store, graph_store, query, insertion_layer) + .await; + assert!(!db.is_match(vector_store, &neighbors).await); + // Insert the new vector into the store. + let inserted = vector_store.insert(query).await; + db.insert_from_search_results(vector_store, graph_store, inserted, neighbors, set_ep) + .await; + } + + let queries2 = (101..200) + .map(|raw_query| vector_store.prepare_query(raw_query)) + .collect::>(); + + // Insert the codes with helper function + for query in queries2.iter() { + db.insert(vector_store, graph_store, query, rng).await; + } + + // Search for the same codes and find matches. + for query in queries1.iter().chain(queries2.iter()) { + let neighbors = db.search(vector_store, graph_store, query, 1).await; + assert!(db.is_match(vector_store, &[neighbors]).await); + } + } +} diff --git a/iris-mpc-cpu/src/hawkers/mod.rs b/iris-mpc-cpu/src/hawkers/mod.rs index e2ec49a26..0a1bb4e52 100644 --- a/iris-mpc-cpu/src/hawkers/mod.rs +++ b/iris-mpc-cpu/src/hawkers/mod.rs @@ -1,2 +1,3 @@ pub mod galois_store; +pub mod iris_searcher; pub mod plaintext_store; diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index 2e10301e5..27cc29038 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -1,5 +1,6 @@ +use crate::hawkers::iris_searcher::IrisSearcher; use aes_prng::AesRng; -use hawk_pack::{graph_store::GraphMem, HawkSearcher, VectorStore}; +use hawk_pack::{graph_store::GraphMem, VectorStore}; use iris_mpc_common::iris_db::{ db::IrisDB, iris::{IrisCode, MATCH_THRESHOLD_RATIO}, @@ -148,7 +149,7 @@ impl PlaintextStore { let mut plaintext_vector_store = PlaintextStore::default(); let mut plaintext_graph_store = GraphMem::new(); - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); for raw_query in cleartext_database.iter() { let query = plaintext_vector_store.prepare_query(raw_query.clone()); @@ -173,8 +174,8 @@ impl PlaintextStore { #[cfg(test)] mod tests { use super::*; + use crate::hawkers::iris_searcher::IrisSearcher; use aes_prng::AesRng; - use hawk_pack::HawkSearcher; use iris_mpc_common::iris_db::db::IrisDB; use rand::SeedableRng; use tracing_test::traced_test; @@ -256,7 +257,7 @@ mod tests { async fn test_plaintext_hnsw_matcher() { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 1; - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); let (_, mut ptxt_vector, mut ptxt_graph) = PlaintextStore::create_random(&mut rng, database_size) .await diff --git a/iris-mpc-cpu/src/network/grpc.rs b/iris-mpc-cpu/src/network/grpc.rs index bd3185532..ce9be679d 100644 --- a/iris-mpc-cpu/src/network/grpc.rs +++ b/iris-mpc-cpu/src/network/grpc.rs @@ -340,10 +340,9 @@ mod tests { use super::*; use crate::{ execution::{local::generate_local_identities, player::Role}, - hawkers::galois_store::LocalNetAby3NgStoreProtocol, + hawkers::{galois_store::LocalNetAby3NgStoreProtocol, iris_searcher::IrisSearcher}, }; use aes_prng::AesRng; - use hawk_pack::HawkSearcher; use rand::SeedableRng; use tokio::task::JoinSet; use tracing_test::traced_test; @@ -570,7 +569,7 @@ mod tests { async fn test_hnsw_local() { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 2; - let searcher = HawkSearcher::default(); + let searcher = IrisSearcher::default(); let mut vectors_and_graphs = LocalNetAby3NgStoreProtocol::shared_random_setup( &mut rng, database_size, diff --git a/iris-mpc-cpu/src/py_bindings/hnsw.rs b/iris-mpc-cpu/src/py_bindings/hnsw.rs index 471de784a..da18a10c3 100644 --- a/iris-mpc-cpu/src/py_bindings/hnsw.rs +++ b/iris-mpc-cpu/src/py_bindings/hnsw.rs @@ -1,6 +1,9 @@ use super::plaintext_store::Base64IrisCode; -use crate::hawkers::plaintext_store::{PlaintextStore, PointId}; -use hawk_pack::{graph_store::GraphMem, HawkSearcher}; +use crate::hawkers::{ + iris_searcher::IrisSearcher, + plaintext_store::{PlaintextStore, PointId}, +}; +use hawk_pack::graph_store::GraphMem; use iris_mpc_common::iris_db::iris::IrisCode; use rand::rngs::ThreadRng; use serde_json::{self, Deserializer}; @@ -8,7 +11,7 @@ use std::{fs::File, io::BufReader}; pub fn search( query: IrisCode, - searcher: &HawkSearcher, + searcher: &IrisSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) -> (PointId, f64) { @@ -28,7 +31,7 @@ pub fn search( // TODO could instead take iterator of IrisCodes to make more flexible pub fn insert( iris: IrisCode, - searcher: &HawkSearcher, + searcher: &IrisSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) -> PointId { @@ -46,7 +49,7 @@ pub fn insert( } pub fn insert_uniform_random( - searcher: &HawkSearcher, + searcher: &IrisSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) -> PointId { @@ -58,7 +61,7 @@ pub fn insert_uniform_random( pub fn fill_uniform_random( num: usize, - searcher: &HawkSearcher, + searcher: &IrisSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) { @@ -84,7 +87,7 @@ pub fn fill_uniform_random( pub fn fill_from_ndjson_file( filename: &str, limit: Option, - searcher: &HawkSearcher, + searcher: &IrisSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) { diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs index 1d154346a..c31a711cb 100644 --- a/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs +++ b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs @@ -1,14 +1,13 @@ use super::{graph_store::PyGraphStore, iris_code::PyIrisCode, plaintext_store::PyPlaintextStore}; -use hawk_pack::{ - hawk_searcher::{HawkerParams, N_PARAM_LAYERS}, - HawkSearcher, +use iris_mpc_cpu::{ + hawkers::iris_searcher::{HnswParams, IrisSearcher, N_PARAM_LAYERS}, + py_bindings, }; -use iris_mpc_cpu::py_bindings; use pyo3::{exceptions::PyIOError, prelude::*}; #[pyclass] #[derive(Clone, Default)] -pub struct PyHawkSearcher(pub HawkSearcher); +pub struct PyHawkSearcher(pub IrisSearcher); #[pymethods] #[allow(non_snake_case)] @@ -20,17 +19,17 @@ impl PyHawkSearcher { #[staticmethod] pub fn new_standard(M: usize, ef_constr: usize, ef_search: usize) -> Self { - let params = HawkerParams::new(ef_constr, ef_search, M); - Self(HawkSearcher { params }) + let params = HnswParams::new(ef_constr, ef_search, M); + Self(IrisSearcher { params }) } #[staticmethod] pub fn new_uniform(M: usize, ef: usize) -> Self { - let params = HawkerParams::new_uniform(ef, M); - Self(HawkSearcher { params }) + let params = HnswParams::new_uniform(ef, M); + Self(IrisSearcher { params }) } - /// Construct `HawkSearcher` with fully general parameters, specifying the + /// Construct `IrisSearcher` with fully general parameters, specifying the /// values of various parameters used during construction and search at /// different levels of the graph hierarchy. #[staticmethod] @@ -42,7 +41,7 @@ impl PyHawkSearcher { ef_search: [usize; N_PARAM_LAYERS], layer_probability: f64, ) -> Self { - let params = HawkerParams { + let params = HnswParams { M, M_max, ef_constr_search, @@ -50,7 +49,7 @@ impl PyHawkSearcher { ef_search, layer_probability, }; - Self(HawkSearcher { params }) + Self(IrisSearcher { params }) } pub fn insert( From 4a0cfa0c772bb697d319c7d7b1272a27bcfa7710 Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Wed, 11 Dec 2024 12:06:13 -0700 Subject: [PATCH 2/7] Rename local HawkSearcher copy to descriptive HnswSearcher --- iris-mpc-cpu/benches/hnsw.rs | 10 +++++----- iris-mpc-cpu/examples/hnsw-ex.rs | 4 ++-- iris-mpc-cpu/src/hawkers/galois_store.rs | 12 ++++++------ iris-mpc-cpu/src/hawkers/iris_searcher.rs | 10 +++++----- iris-mpc-cpu/src/hawkers/plaintext_store.rs | 10 +++++----- iris-mpc-cpu/src/network/grpc.rs | 4 ++-- iris-mpc-cpu/src/py_bindings/hnsw.rs | 12 ++++++------ iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs | 10 +++++----- 8 files changed, 36 insertions(+), 36 deletions(-) diff --git a/iris-mpc-cpu/benches/hnsw.rs b/iris-mpc-cpu/benches/hnsw.rs index fc15a0530..bf4d3cb10 100644 --- a/iris-mpc-cpu/benches/hnsw.rs +++ b/iris-mpc-cpu/benches/hnsw.rs @@ -6,7 +6,7 @@ use iris_mpc_cpu::{ database_generators::{create_random_sharing, generate_galois_iris_shares}, execution::local::LocalRuntime, hawkers::{ - galois_store::LocalNetAby3NgStoreProtocol, iris_searcher::IrisSearcher, + galois_store::LocalNetAby3NgStoreProtocol, iris_searcher::HnswSearcher, plaintext_store::PlaintextStore, }, protocol::ops::{ @@ -31,7 +31,7 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { let mut rng = AesRng::seed_from_u64(0_u64); let mut vector = PlaintextStore::default(); let mut graph = GraphMem::new(); - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); for _ in 0..database_size { let raw_query = IrisCode::random_rng(&mut rng); @@ -47,7 +47,7 @@ fn bench_plaintext_hnsw(c: &mut Criterion) { b.to_async(&rt).iter_batched( || (vector.clone(), graph.clone()), |(mut db_vectors, mut graph)| async move { - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let query = db_vectors.prepare_query(on_the_fly_query); @@ -204,7 +204,7 @@ fn bench_gr_ready_made_hnsw(c: &mut Criterion) { b.to_async(&rt).iter_batched( || secret_searcher.clone(), |vectors_graphs| async move { - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let raw_query = generate_galois_iris_shares(&mut rng, on_the_fly_query); @@ -238,7 +238,7 @@ fn bench_gr_ready_made_hnsw(c: &mut Criterion) { b.to_async(&rt).iter_batched( || secret_searcher.clone(), |vectors_graphs| async move { - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); let mut rng = AesRng::seed_from_u64(0_u64); let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone(); let raw_query = generate_galois_iris_shares(&mut rng, on_the_fly_query); diff --git a/iris-mpc-cpu/examples/hnsw-ex.rs b/iris-mpc-cpu/examples/hnsw-ex.rs index debc53aa8..1295869df 100644 --- a/iris-mpc-cpu/examples/hnsw-ex.rs +++ b/iris-mpc-cpu/examples/hnsw-ex.rs @@ -1,7 +1,7 @@ use aes_prng::AesRng; use hawk_pack::graph_store::GraphMem; use iris_mpc_common::iris_db::iris::IrisCode; -use iris_mpc_cpu::hawkers::{iris_searcher::IrisSearcher, plaintext_store::PlaintextStore}; +use iris_mpc_cpu::hawkers::{iris_searcher::HnswSearcher, plaintext_store::PlaintextStore}; use rand::SeedableRng; const DATABASE_SIZE: usize = 1_000; @@ -16,7 +16,7 @@ fn main() { let mut rng = AesRng::seed_from_u64(0_u64); let mut vector = PlaintextStore::default(); let mut graph = GraphMem::new(); - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); for idx in 0..DATABASE_SIZE { let raw_query = IrisCode::random_rng(&mut rng); diff --git a/iris-mpc-cpu/src/hawkers/galois_store.rs b/iris-mpc-cpu/src/hawkers/galois_store.rs index 5dbcb4f7f..66b5682d8 100644 --- a/iris-mpc-cpu/src/hawkers/galois_store.rs +++ b/iris-mpc-cpu/src/hawkers/galois_store.rs @@ -1,4 +1,4 @@ -use super::{iris_searcher::IrisSearcher, plaintext_store::PlaintextStore}; +use super::{iris_searcher::HnswSearcher, plaintext_store::PlaintextStore}; use crate::{ database_generators::{generate_galois_iris_shares, GaloisRingSharedIris}, execution::{ @@ -546,7 +546,7 @@ impl LocalNetAby3NgStoreProtocol { .collect::>(); jobs.spawn(async move { let mut graph_store = GraphMem::new(); - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); // insert queries for query in queries.iter() { searcher @@ -587,7 +587,7 @@ impl LocalNetAby3NgStoreProtocol { mod tests { use super::*; use crate::{ - database_generators::generate_galois_iris_shares, hawkers::iris_searcher::IrisSearcher, + database_generators::generate_galois_iris_shares, hawkers::iris_searcher::HnswSearcher, }; use aes_prng::AesRng; use hawk_pack::graph_store::GraphMem; @@ -619,7 +619,7 @@ mod tests { let mut rng = rng.clone(); jobs.spawn(async move { let mut aby3_graph = GraphMem::new(); - let db = IrisSearcher::default(); + let db = HnswSearcher::default(); let mut inserted = vec![]; // insert queries @@ -679,7 +679,7 @@ mod tests { { assert_eq!(v_from_scratch.storage.points, premade_v.storage.points); } - let hawk_searcher = IrisSearcher::default(); + let hawk_searcher = HnswSearcher::default(); for i in 0..database_size { let cleartext_neighbors = hawk_searcher @@ -826,7 +826,7 @@ mod tests { async fn test_gr_scratch_hnsw() { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 2; - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); let mut vectors_and_graphs = LocalNetAby3NgStoreProtocol::shared_random_setup( &mut rng, database_size, diff --git a/iris-mpc-cpu/src/hawkers/iris_searcher.rs b/iris-mpc-cpu/src/hawkers/iris_searcher.rs index 9dfea67ae..e375bf26d 100644 --- a/iris-mpc-cpu/src/hawkers/iris_searcher.rs +++ b/iris-mpc-cpu/src/hawkers/iris_searcher.rs @@ -139,22 +139,22 @@ impl HnswParams { /// Operations on vectors are delegated to a VectorStore. /// Operations on the graph are delegate to a GraphStore. #[derive(Clone, Serialize, Deserialize)] -pub struct IrisSearcher { +pub struct HnswSearcher { pub params: HnswParams, } // TODO remove default value; this varies too much between applications // to make sense to specify something "obvious" -impl Default for IrisSearcher { +impl Default for HnswSearcher { fn default() -> Self { - IrisSearcher { + HnswSearcher { params: HnswParams::new(64, 32, 32), } } } #[allow(non_snake_case)] -impl IrisSearcher { +impl HnswSearcher { async fn connect_bidir>( &self, vector_store: &mut V, @@ -450,7 +450,7 @@ mod tests { let vector_store = &mut LazyMemoryStore::new(); let graph_store = &mut GraphMem::new(); let rng = &mut AesRng::seed_from_u64(0_u64); - let db = IrisSearcher::default(); + let db = HnswSearcher::default(); let queries1 = (0..100) .map(|raw_query| vector_store.prepare_query(raw_query)) diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index 15ba41275..8278ce983 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -1,4 +1,4 @@ -use crate::hawkers::iris_searcher::IrisSearcher; +use crate::hawkers::iris_searcher::HnswSearcher; use aes_prng::AesRng; use hawk_pack::{graph_store::GraphMem, VectorStore}; use iris_mpc_common::iris_db::{ @@ -149,7 +149,7 @@ impl PlaintextStore { let mut plaintext_vector_store = PlaintextStore::default(); let mut plaintext_graph_store = GraphMem::new(); - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); for raw_query in cleartext_database.iter() { let query = plaintext_vector_store.prepare_query(raw_query.clone()); @@ -190,7 +190,7 @@ impl PlaintextStore { let mut rng_searcher1 = AesRng::from_rng(rng.clone())?; let mut plaintext_graph_store = GraphMem::new(); - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); for i in 0..graph_size { searcher @@ -210,7 +210,7 @@ impl PlaintextStore { #[cfg(test)] mod tests { use super::*; - use crate::hawkers::iris_searcher::IrisSearcher; + use crate::hawkers::iris_searcher::HnswSearcher; use aes_prng::AesRng; use iris_mpc_common::iris_db::db::IrisDB; use rand::SeedableRng; @@ -293,7 +293,7 @@ mod tests { async fn test_plaintext_hnsw_matcher() { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 1; - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); let (mut ptxt_vector, mut ptxt_graph) = PlaintextStore::create_random(&mut rng, database_size) .await diff --git a/iris-mpc-cpu/src/network/grpc.rs b/iris-mpc-cpu/src/network/grpc.rs index ce9be679d..3b8e5da54 100644 --- a/iris-mpc-cpu/src/network/grpc.rs +++ b/iris-mpc-cpu/src/network/grpc.rs @@ -340,7 +340,7 @@ mod tests { use super::*; use crate::{ execution::{local::generate_local_identities, player::Role}, - hawkers::{galois_store::LocalNetAby3NgStoreProtocol, iris_searcher::IrisSearcher}, + hawkers::{galois_store::LocalNetAby3NgStoreProtocol, iris_searcher::HnswSearcher}, }; use aes_prng::AesRng; use rand::SeedableRng; @@ -569,7 +569,7 @@ mod tests { async fn test_hnsw_local() { let mut rng = AesRng::seed_from_u64(0_u64); let database_size = 2; - let searcher = IrisSearcher::default(); + let searcher = HnswSearcher::default(); let mut vectors_and_graphs = LocalNetAby3NgStoreProtocol::shared_random_setup( &mut rng, database_size, diff --git a/iris-mpc-cpu/src/py_bindings/hnsw.rs b/iris-mpc-cpu/src/py_bindings/hnsw.rs index da18a10c3..ce17ac649 100644 --- a/iris-mpc-cpu/src/py_bindings/hnsw.rs +++ b/iris-mpc-cpu/src/py_bindings/hnsw.rs @@ -1,6 +1,6 @@ use super::plaintext_store::Base64IrisCode; use crate::hawkers::{ - iris_searcher::IrisSearcher, + iris_searcher::HnswSearcher, plaintext_store::{PlaintextStore, PointId}, }; use hawk_pack::graph_store::GraphMem; @@ -11,7 +11,7 @@ use std::{fs::File, io::BufReader}; pub fn search( query: IrisCode, - searcher: &IrisSearcher, + searcher: &HnswSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) -> (PointId, f64) { @@ -31,7 +31,7 @@ pub fn search( // TODO could instead take iterator of IrisCodes to make more flexible pub fn insert( iris: IrisCode, - searcher: &IrisSearcher, + searcher: &HnswSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) -> PointId { @@ -49,7 +49,7 @@ pub fn insert( } pub fn insert_uniform_random( - searcher: &IrisSearcher, + searcher: &HnswSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) -> PointId { @@ -61,7 +61,7 @@ pub fn insert_uniform_random( pub fn fill_uniform_random( num: usize, - searcher: &IrisSearcher, + searcher: &HnswSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) { @@ -87,7 +87,7 @@ pub fn fill_uniform_random( pub fn fill_from_ndjson_file( filename: &str, limit: Option, - searcher: &IrisSearcher, + searcher: &HnswSearcher, vector: &mut PlaintextStore, graph: &mut GraphMem, ) { diff --git a/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs index c31a711cb..36212dcac 100644 --- a/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs +++ b/iris-mpc-py/src/py_hnsw/pyclasses/hawk_searcher.rs @@ -1,13 +1,13 @@ use super::{graph_store::PyGraphStore, iris_code::PyIrisCode, plaintext_store::PyPlaintextStore}; use iris_mpc_cpu::{ - hawkers::iris_searcher::{HnswParams, IrisSearcher, N_PARAM_LAYERS}, + hawkers::iris_searcher::{HnswParams, HnswSearcher, N_PARAM_LAYERS}, py_bindings, }; use pyo3::{exceptions::PyIOError, prelude::*}; #[pyclass] #[derive(Clone, Default)] -pub struct PyHawkSearcher(pub IrisSearcher); +pub struct PyHawkSearcher(pub HnswSearcher); #[pymethods] #[allow(non_snake_case)] @@ -20,13 +20,13 @@ impl PyHawkSearcher { #[staticmethod] pub fn new_standard(M: usize, ef_constr: usize, ef_search: usize) -> Self { let params = HnswParams::new(ef_constr, ef_search, M); - Self(IrisSearcher { params }) + Self(HnswSearcher { params }) } #[staticmethod] pub fn new_uniform(M: usize, ef: usize) -> Self { let params = HnswParams::new_uniform(ef, M); - Self(IrisSearcher { params }) + Self(HnswSearcher { params }) } /// Construct `IrisSearcher` with fully general parameters, specifying the @@ -49,7 +49,7 @@ impl PyHawkSearcher { ef_search, layer_probability, }; - Self(IrisSearcher { params }) + Self(HnswSearcher { params }) } pub fn insert( From 7432ccaae5fa46dc0e4879190a07661d31b14edb Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Fri, 13 Dec 2024 10:52:08 -0700 Subject: [PATCH 3/7] WIP initial algorithm metrics instrumentation and analysis --- iris-mpc-cpu/Cargo.toml | 4 + iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs | 148 ++++++++++++++++++++ iris-mpc-cpu/src/hawkers/iris_searcher.rs | 71 ++++++++++ iris-mpc-cpu/src/hawkers/plaintext_store.rs | 4 +- 4 files changed, 226 insertions(+), 1 deletion(-) create mode 100644 iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml index 69231b57b..6389a7c54 100644 --- a/iris-mpc-cpu/Cargo.toml +++ b/iris-mpc-cpu/Cargo.toml @@ -56,6 +56,10 @@ name = "hnsw-ex" name = "local_hnsw" path = "bin/local_hnsw.rs" +[[bin]] +name = "hnsw_algorithm_metrics" +path = "bin/hnsw_algorithm_metrics.rs" + [[bin]] name = "generate_benchmark_data" path = "bin/generate_benchmark_data.rs" diff --git a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs new file mode 100644 index 000000000..9a2792b71 --- /dev/null +++ b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs @@ -0,0 +1,148 @@ +use aes_prng::AesRng; +use clap::Parser; +use hawk_pack::graph_store::GraphMem; +use iris_mpc_common::iris_db::iris::IrisCode; +use iris_mpc_cpu::hawkers::{ + iris_searcher::{tracing::{EventCounter, HnswEventCounterLayer, COMPARE_DIST_EVENT, EVAL_DIST_EVENT, LAYER_SEARCH_EVENT, OPEN_NODE_EVENT}, HnswParams, HnswSearcher}, + plaintext_store::PlaintextStore}; +use rand::SeedableRng; +use std::{error::Error, sync::Arc}; + +use tracing_subscriber::prelude::*; + +#[derive(Parser)] +#[allow(non_snake_case)] +struct Args { + #[clap(default_value = "64")] + M: usize, + #[clap(default_value = "128")] + ef_constr: usize, + #[clap(default_value = "64")] + ef_search: usize, + #[clap(default_value = "10000")] + database_size: usize, +} + +#[tokio::main] +#[allow(non_snake_case)] +async fn main() -> Result<(), Box> { + let args = Args::parse(); + let M = args.M; + let ef_constr = args.ef_constr; + let ef_search = args.ef_search; + let database_size = args.database_size; + + let counters = configure_tracing(); + + let mut rng = AesRng::seed_from_u64(0_u64); + let mut vector = PlaintextStore::default(); + let mut graph = GraphMem::new(); + let searcher = HnswSearcher { + params: HnswParams::new(M, ef_constr, ef_search) + }; + + for idx in 0..database_size { + let raw_query = IrisCode::random_rng(&mut rng); + let query = vector.prepare_query(raw_query.clone()); + searcher + .insert(&mut vector, &mut graph, &query, &mut rng) + .await; + + if idx % 1000 == 999 { + print!("{}, ", idx + 1); + print_stats(&counters, false); + } + } + + println!("Final counts:"); + print_stats(&counters, true); + + Ok(()) +} + +fn print_stats(counters: &Arc, verbose: bool) { + let layer_searches = counters.counters.get(LAYER_SEARCH_EVENT as usize).unwrap(); + let opened_nodes = counters.counters.get(OPEN_NODE_EVENT as usize).unwrap(); + let distance_evals = counters.counters.get(EVAL_DIST_EVENT as usize).unwrap(); + let distance_comps = counters.counters.get(COMPARE_DIST_EVENT as usize).unwrap(); + + if verbose { + println!(" Layer search events: {:?}", layer_searches); + println!(" Open node events: {:?}", opened_nodes); + println!(" Evaluate distance events: {:?}", distance_evals); + println!(" Compare distance events: {:?}", distance_comps); + } else { + println!("{:?}, {:?}, {:?}", opened_nodes, distance_evals, distance_comps); + } +} + +fn configure_tracing() -> Arc { + let counters = Arc::new(EventCounter::default()); + + let layer = HnswEventCounterLayer { + counters: counters.clone(), + }; + // Set up how `tracing-subscriber` will deal with tracing data. + tracing_subscriber::registry().with(layer).init(); + + counters +} + +// mod custom_layer { +// use tracing_subscriber::Layer; + +// pub struct CustomLayer; + +// impl Layer for CustomLayer where S: tracing::Subscriber { +// fn on_event( +// &self, +// event: &tracing::Event<'_>, +// _ctx: tracing_subscriber::layer::Context<'_, S>, +// ) { +// println!("Got event!"); +// println!(" level={:?}", event.metadata().level()); +// println!(" target={:?}", event.metadata().target()); +// println!(" name={:?}", event.metadata().name()); +// let mut visitor = PrintlnVisitor; +// event.record(&mut visitor); +// } +// } + +// struct PrintlnVisitor; + +// impl tracing::field::Visit for PrintlnVisitor { +// fn record_f64(&mut self, field: &tracing::field::Field, value: f64) { +// println!(" field={} value={}", field.name(), value) +// } + +// fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { +// println!(" field={} value={}", field.name(), value) +// } + +// fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { +// println!(" field={} value={}", field.name(), value) +// } + +// fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { +// println!(" field={} value={}", field.name(), value) +// } + +// fn record_str(&mut self, field: &tracing::field::Field, value: &str) { +// println!(" field={} value={}", field.name(), value) +// } + +// fn record_error( +// &mut self, +// field: &tracing::field::Field, +// value: &(dyn std::error::Error + 'static), +// ) { +// println!(" field={} value={}", field.name(), value) +// } + +// fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { +// println!(" field={} value={:?}", field.name(), value) +// } +// } + +// } +// use custom_layer::CustomLayer; \ No newline at end of file diff --git a/iris-mpc-cpu/src/hawkers/iris_searcher.rs b/iris-mpc-cpu/src/hawkers/iris_searcher.rs index e375bf26d..a87fda735 100644 --- a/iris-mpc-cpu/src/hawkers/iris_searcher.rs +++ b/iris-mpc-cpu/src/hawkers/iris_searcher.rs @@ -11,6 +11,7 @@ use hawk_pack::{GraphStore, VectorStore}; use rand::RngCore; use rand_distr::{Distribution, Geometric}; use serde::{Deserialize, Serialize}; +use ::tracing::info; use std::collections::HashSet; // specify construction and search parameters by layer up to this value minus 1 @@ -226,6 +227,8 @@ impl HnswSearcher { ef: usize, lc: usize, ) { + info!(event_type = tracing::LAYER_SEARCH_EVENT); + // v: The set of already visited vectors. let mut v = HashSet::::from_iter(W.iter().map(|(e, _eq)| e.clone())); @@ -244,6 +247,9 @@ impl HnswSearcher { break; } + // Open the node c and explore its neighbors. + info!(event_type = tracing::OPEN_NODE_EVENT); + // Visit all neighbors of c. let c_links = graph_store.get_links(&c, lc).await; @@ -435,6 +441,71 @@ impl HnswSearcher { } } +pub mod tracing { + use std::sync::{atomic::{AtomicUsize, Ordering}, Arc}; + + use tracing::{Event, Subscriber}; + use tracing_subscriber::{layer::Context, Layer}; + + pub const LAYER_SEARCH_EVENT: u64 = 0; + pub const OPEN_NODE_EVENT: u64 = 1; + pub const EVAL_DIST_EVENT: u64 = 2; + pub const COMPARE_DIST_EVENT: u64 = 3; + + const NUM_EVENT_TYPES: usize = 4; + + #[derive(Default)] + pub struct EventCounter { + pub counters: [AtomicUsize; NUM_EVENT_TYPES], + } + + pub struct HnswEventCounterLayer { + pub counters: Arc, + } + + impl Layer for HnswEventCounterLayer { + fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) { + let mut visitor = EventVisitor::default(); + event.record(&mut visitor); + + if let Some(event_type) = visitor.event { + if let Some(counter) = self.counters.counters.get(event_type) { + let increment_amount = visitor.amount.unwrap_or(1); + counter.fetch_add(increment_amount, Ordering::Relaxed); + } else { + panic!("Invalid event type specified: {:?}", event_type); + } + } + } + } + + + #[derive(Default)] + struct EventVisitor { + // which event was encountered + event: Option, + + // how much to increment the associated counter + amount: Option, + } + + impl tracing::field::Visit for EventVisitor { + fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { + match field.name() { + "event_type" => { + self.event = Some(value as usize); + }, + "increment_amount" => { + self.amount = Some(value as usize); + }, + _ => {} + } + } + + fn record_debug(&mut self, _field: &tracing::field::Field, _value: &dyn std::fmt::Debug) {} + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index 8278ce983..abdc78956 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -1,4 +1,4 @@ -use crate::hawkers::iris_searcher::HnswSearcher; +use crate::hawkers::iris_searcher::{tracing::{COMPARE_DIST_EVENT, EVAL_DIST_EVENT}, HnswSearcher}; use aes_prng::AesRng; use hawk_pack::{graph_store::GraphMem, VectorStore}; use iris_mpc_common::iris_db::{ @@ -117,6 +117,7 @@ impl VectorStore for PlaintextStore { query: &Self::QueryRef, vector: &Self::VectorRef, ) -> Self::DistanceRef { + tracing::info!(event_type = EVAL_DIST_EVENT); let query_code = &self.points[*query]; let vector_code = &self.points[*vector]; query_code.data.distance_fraction(&vector_code.data) @@ -132,6 +133,7 @@ impl VectorStore for PlaintextStore { distance1: &Self::DistanceRef, distance2: &Self::DistanceRef, ) -> bool { + tracing::info!(event_type = COMPARE_DIST_EVENT); let (a, b) = *distance1; // a/b let (c, d) = *distance2; // c/d (a as i32) * (d as i32) - (b as i32) * (c as i32) < 0 From ed4364e4c201a019ed4586a77844c988203184cb Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Fri, 13 Dec 2024 13:27:44 -0700 Subject: [PATCH 4/7] Move HnswSearcher to its own module --- iris-mpc-cpu/benches/hnsw.rs | 6 +- iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs | 86 ++++--------------- iris-mpc-cpu/examples/hnsw-ex.rs | 2 +- iris-mpc-cpu/src/hawkers/galois_store.rs | 4 +- iris-mpc-cpu/src/hawkers/mod.rs | 1 - iris-mpc-cpu/src/hawkers/plaintext_store.rs | 7 +- iris-mpc-cpu/src/hnsw/mod.rs | 3 + .../iris_searcher.rs => hnsw/searcher.rs} | 13 +-- iris-mpc-cpu/src/lib.rs | 1 + iris-mpc-cpu/src/network/grpc.rs | 3 +- iris-mpc-cpu/src/py_bindings/hnsw.rs | 6 +- 11 files changed, 44 insertions(+), 88 deletions(-) create mode 100644 iris-mpc-cpu/src/hnsw/mod.rs rename iris-mpc-cpu/src/{hawkers/iris_searcher.rs => hnsw/searcher.rs} (99%) diff --git a/iris-mpc-cpu/benches/hnsw.rs b/iris-mpc-cpu/benches/hnsw.rs index bf4d3cb10..b7f01a490 100644 --- a/iris-mpc-cpu/benches/hnsw.rs +++ b/iris-mpc-cpu/benches/hnsw.rs @@ -5,10 +5,8 @@ use iris_mpc_common::iris_db::{db::IrisDB, iris::IrisCode}; use iris_mpc_cpu::{ database_generators::{create_random_sharing, generate_galois_iris_shares}, execution::local::LocalRuntime, - hawkers::{ - galois_store::LocalNetAby3NgStoreProtocol, iris_searcher::HnswSearcher, - plaintext_store::PlaintextStore, - }, + hawkers::{galois_store::LocalNetAby3NgStoreProtocol, plaintext_store::PlaintextStore}, + hnsw::searcher::HnswSearcher, protocol::ops::{ batch_signed_lift_vec, cross_compare, galois_ring_pairwise_distance, galois_ring_to_rep3, }, diff --git a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs index 9a2792b71..9170c8ea2 100644 --- a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs +++ b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs @@ -2,23 +2,29 @@ use aes_prng::AesRng; use clap::Parser; use hawk_pack::graph_store::GraphMem; use iris_mpc_common::iris_db::iris::IrisCode; -use iris_mpc_cpu::hawkers::{ - iris_searcher::{tracing::{EventCounter, HnswEventCounterLayer, COMPARE_DIST_EVENT, EVAL_DIST_EVENT, LAYER_SEARCH_EVENT, OPEN_NODE_EVENT}, HnswParams, HnswSearcher}, - plaintext_store::PlaintextStore}; +use iris_mpc_cpu::{ + hawkers::plaintext_store::PlaintextStore, + hnsw::searcher::{ + tracing::{ + EventCounter, HnswEventCounterLayer, COMPARE_DIST_EVENT, EVAL_DIST_EVENT, + LAYER_SEARCH_EVENT, OPEN_NODE_EVENT, + }, + HnswParams, HnswSearcher, + }, +}; use rand::SeedableRng; use std::{error::Error, sync::Arc}; - use tracing_subscriber::prelude::*; #[derive(Parser)] #[allow(non_snake_case)] struct Args { #[clap(default_value = "64")] - M: usize, + M: usize, #[clap(default_value = "128")] - ef_constr: usize, + ef_constr: usize, #[clap(default_value = "64")] - ef_search: usize, + ef_search: usize, #[clap(default_value = "10000")] database_size: usize, } @@ -38,7 +44,7 @@ async fn main() -> Result<(), Box> { let mut vector = PlaintextStore::default(); let mut graph = GraphMem::new(); let searcher = HnswSearcher { - params: HnswParams::new(M, ef_constr, ef_search) + params: HnswParams::new(M, ef_constr, ef_search), }; for idx in 0..database_size { @@ -72,7 +78,10 @@ fn print_stats(counters: &Arc, verbose: bool) { println!(" Evaluate distance events: {:?}", distance_evals); println!(" Compare distance events: {:?}", distance_comps); } else { - println!("{:?}, {:?}, {:?}", opened_nodes, distance_evals, distance_comps); + println!( + "{:?}, {:?}, {:?}", + opened_nodes, distance_evals, distance_comps + ); } } @@ -87,62 +96,3 @@ fn configure_tracing() -> Arc { counters } - -// mod custom_layer { -// use tracing_subscriber::Layer; - -// pub struct CustomLayer; - -// impl Layer for CustomLayer where S: tracing::Subscriber { -// fn on_event( -// &self, -// event: &tracing::Event<'_>, -// _ctx: tracing_subscriber::layer::Context<'_, S>, -// ) { -// println!("Got event!"); -// println!(" level={:?}", event.metadata().level()); -// println!(" target={:?}", event.metadata().target()); -// println!(" name={:?}", event.metadata().name()); -// let mut visitor = PrintlnVisitor; -// event.record(&mut visitor); -// } -// } - -// struct PrintlnVisitor; - -// impl tracing::field::Visit for PrintlnVisitor { -// fn record_f64(&mut self, field: &tracing::field::Field, value: f64) { -// println!(" field={} value={}", field.name(), value) -// } - -// fn record_i64(&mut self, field: &tracing::field::Field, value: i64) { -// println!(" field={} value={}", field.name(), value) -// } - -// fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { -// println!(" field={} value={}", field.name(), value) -// } - -// fn record_bool(&mut self, field: &tracing::field::Field, value: bool) { -// println!(" field={} value={}", field.name(), value) -// } - -// fn record_str(&mut self, field: &tracing::field::Field, value: &str) { -// println!(" field={} value={}", field.name(), value) -// } - -// fn record_error( -// &mut self, -// field: &tracing::field::Field, -// value: &(dyn std::error::Error + 'static), -// ) { -// println!(" field={} value={}", field.name(), value) -// } - -// fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) { -// println!(" field={} value={:?}", field.name(), value) -// } -// } - -// } -// use custom_layer::CustomLayer; \ No newline at end of file diff --git a/iris-mpc-cpu/examples/hnsw-ex.rs b/iris-mpc-cpu/examples/hnsw-ex.rs index 1295869df..1b179a560 100644 --- a/iris-mpc-cpu/examples/hnsw-ex.rs +++ b/iris-mpc-cpu/examples/hnsw-ex.rs @@ -1,7 +1,7 @@ use aes_prng::AesRng; use hawk_pack::graph_store::GraphMem; use iris_mpc_common::iris_db::iris::IrisCode; -use iris_mpc_cpu::hawkers::{iris_searcher::HnswSearcher, plaintext_store::PlaintextStore}; +use iris_mpc_cpu::{hawkers::plaintext_store::PlaintextStore, hnsw::searcher::HnswSearcher}; use rand::SeedableRng; const DATABASE_SIZE: usize = 1_000; diff --git a/iris-mpc-cpu/src/hawkers/galois_store.rs b/iris-mpc-cpu/src/hawkers/galois_store.rs index 66b5682d8..4f01c96b4 100644 --- a/iris-mpc-cpu/src/hawkers/galois_store.rs +++ b/iris-mpc-cpu/src/hawkers/galois_store.rs @@ -1,4 +1,3 @@ -use super::{iris_searcher::HnswSearcher, plaintext_store::PlaintextStore}; use crate::{ database_generators::{generate_galois_iris_shares, GaloisRingSharedIris}, execution::{ @@ -6,7 +5,8 @@ use crate::{ player::Identity, session::Session, }, - hawkers::plaintext_store::PointId, + hawkers::plaintext_store::{PlaintextStore, PointId}, + hnsw::HnswSearcher, network::NetworkType, protocol::ops::{ batch_signed_lift_vec, compare_threshold_and_open, cross_compare, diff --git a/iris-mpc-cpu/src/hawkers/mod.rs b/iris-mpc-cpu/src/hawkers/mod.rs index 0a1bb4e52..e2ec49a26 100644 --- a/iris-mpc-cpu/src/hawkers/mod.rs +++ b/iris-mpc-cpu/src/hawkers/mod.rs @@ -1,3 +1,2 @@ pub mod galois_store; -pub mod iris_searcher; pub mod plaintext_store; diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index abdc78956..d6c5d8a3e 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -1,4 +1,7 @@ -use crate::hawkers::iris_searcher::{tracing::{COMPARE_DIST_EVENT, EVAL_DIST_EVENT}, HnswSearcher}; +use crate::hnsw::searcher::{ + tracing::{COMPARE_DIST_EVENT, EVAL_DIST_EVENT}, + HnswSearcher, +}; use aes_prng::AesRng; use hawk_pack::{graph_store::GraphMem, VectorStore}; use iris_mpc_common::iris_db::{ @@ -212,7 +215,7 @@ impl PlaintextStore { #[cfg(test)] mod tests { use super::*; - use crate::hawkers::iris_searcher::HnswSearcher; + use crate::hnsw::HnswSearcher; use aes_prng::AesRng; use iris_mpc_common::iris_db::db::IrisDB; use rand::SeedableRng; diff --git a/iris-mpc-cpu/src/hnsw/mod.rs b/iris-mpc-cpu/src/hnsw/mod.rs new file mode 100644 index 000000000..8cc4cecef --- /dev/null +++ b/iris-mpc-cpu/src/hnsw/mod.rs @@ -0,0 +1,3 @@ +pub mod searcher; + +pub use searcher::HnswSearcher; diff --git a/iris-mpc-cpu/src/hawkers/iris_searcher.rs b/iris-mpc-cpu/src/hnsw/searcher.rs similarity index 99% rename from iris-mpc-cpu/src/hawkers/iris_searcher.rs rename to iris-mpc-cpu/src/hnsw/searcher.rs index a87fda735..253803639 100644 --- a/iris-mpc-cpu/src/hawkers/iris_searcher.rs +++ b/iris-mpc-cpu/src/hnsw/searcher.rs @@ -4,6 +4,7 @@ //* //* https://github.com/Inversed-Tech/hawk-pack/ +use ::tracing::info; pub use hawk_pack::data_structures::queue::{ FurthestQueue, FurthestQueueV, NearestQueue, NearestQueueV, }; @@ -11,7 +12,6 @@ use hawk_pack::{GraphStore, VectorStore}; use rand::RngCore; use rand_distr::{Distribution, Geometric}; use serde::{Deserialize, Serialize}; -use ::tracing::info; use std::collections::HashSet; // specify construction and search parameters by layer up to this value minus 1 @@ -442,8 +442,10 @@ impl HnswSearcher { } pub mod tracing { - use std::sync::{atomic::{AtomicUsize, Ordering}, Arc}; - + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }; use tracing::{Event, Subscriber}; use tracing_subscriber::{layer::Context, Layer}; @@ -479,7 +481,6 @@ pub mod tracing { } } - #[derive(Default)] struct EventVisitor { // which event was encountered @@ -494,10 +495,10 @@ pub mod tracing { match field.name() { "event_type" => { self.event = Some(value as usize); - }, + } "increment_amount" => { self.amount = Some(value as usize); - }, + } _ => {} } } diff --git a/iris-mpc-cpu/src/lib.rs b/iris-mpc-cpu/src/lib.rs index bf4a96011..94d4eaf6b 100644 --- a/iris-mpc-cpu/src/lib.rs +++ b/iris-mpc-cpu/src/lib.rs @@ -1,6 +1,7 @@ pub mod database_generators; pub mod execution; pub mod hawkers; +pub mod hnsw; pub(crate) mod network; #[rustfmt::skip] pub(crate) mod proto_generated; diff --git a/iris-mpc-cpu/src/network/grpc.rs b/iris-mpc-cpu/src/network/grpc.rs index 3b8e5da54..57c6ebb8f 100644 --- a/iris-mpc-cpu/src/network/grpc.rs +++ b/iris-mpc-cpu/src/network/grpc.rs @@ -340,7 +340,8 @@ mod tests { use super::*; use crate::{ execution::{local::generate_local_identities, player::Role}, - hawkers::{galois_store::LocalNetAby3NgStoreProtocol, iris_searcher::HnswSearcher}, + hawkers::galois_store::LocalNetAby3NgStoreProtocol, + hnsw::HnswSearcher, }; use aes_prng::AesRng; use rand::SeedableRng; diff --git a/iris-mpc-cpu/src/py_bindings/hnsw.rs b/iris-mpc-cpu/src/py_bindings/hnsw.rs index ce17ac649..cf22767a3 100644 --- a/iris-mpc-cpu/src/py_bindings/hnsw.rs +++ b/iris-mpc-cpu/src/py_bindings/hnsw.rs @@ -1,7 +1,7 @@ use super::plaintext_store::Base64IrisCode; -use crate::hawkers::{ - iris_searcher::HnswSearcher, - plaintext_store::{PlaintextStore, PointId}, +use crate::{ + hawkers::plaintext_store::{PlaintextStore, PointId}, + hnsw::HnswSearcher, }; use hawk_pack::graph_store::GraphMem; use iris_mpc_common::iris_db::iris::IrisCode; From 2c0bb8ea830a2c33c852866de9c47ce1dd30b618 Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Fri, 13 Dec 2024 14:54:39 -0700 Subject: [PATCH 5/7] Move HNSW metrics to new submodule --- iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs | 8 +-- iris-mpc-cpu/src/hawkers/plaintext_store.rs | 4 +- iris-mpc-cpu/src/hnsw/metrics.rs | 63 ++++++++++++++++++ iris-mpc-cpu/src/hnsw/mod.rs | 1 + iris-mpc-cpu/src/hnsw/searcher.rs | 74 ++------------------- 5 files changed, 76 insertions(+), 74 deletions(-) create mode 100644 iris-mpc-cpu/src/hnsw/metrics.rs diff --git a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs index 9170c8ea2..46d1f3042 100644 --- a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs +++ b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs @@ -4,13 +4,13 @@ use hawk_pack::graph_store::GraphMem; use iris_mpc_common::iris_db::iris::IrisCode; use iris_mpc_cpu::{ hawkers::plaintext_store::PlaintextStore, - hnsw::searcher::{ - tracing::{ + hnsw::{ + searcher::{HnswParams, HnswSearcher}, + metrics::{ EventCounter, HnswEventCounterLayer, COMPARE_DIST_EVENT, EVAL_DIST_EVENT, LAYER_SEARCH_EVENT, OPEN_NODE_EVENT, }, - HnswParams, HnswSearcher, - }, + } }; use rand::SeedableRng; use std::{error::Error, sync::Arc}; diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index d6c5d8a3e..fa04b12f2 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -1,6 +1,6 @@ -use crate::hnsw::searcher::{ - tracing::{COMPARE_DIST_EVENT, EVAL_DIST_EVENT}, +use crate::hnsw::{ HnswSearcher, + metrics::{COMPARE_DIST_EVENT, EVAL_DIST_EVENT}, }; use aes_prng::AesRng; use hawk_pack::{graph_store::GraphMem, VectorStore}; diff --git a/iris-mpc-cpu/src/hnsw/metrics.rs b/iris-mpc-cpu/src/hnsw/metrics.rs new file mode 100644 index 000000000..5534b5958 --- /dev/null +++ b/iris-mpc-cpu/src/hnsw/metrics.rs @@ -0,0 +1,63 @@ +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use tracing::{Event, Subscriber}; +use tracing_subscriber::{layer::Context, Layer}; + +pub const LAYER_SEARCH_EVENT: u64 = 0; +pub const OPEN_NODE_EVENT: u64 = 1; +pub const EVAL_DIST_EVENT: u64 = 2; +pub const COMPARE_DIST_EVENT: u64 = 3; + +const NUM_EVENT_TYPES: usize = 4; + +#[derive(Default)] +pub struct EventCounter { + pub counters: [AtomicUsize; NUM_EVENT_TYPES], +} + +pub struct HnswEventCounterLayer { + pub counters: Arc, +} + +impl Layer for HnswEventCounterLayer { + fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) { + let mut visitor = EventVisitor::default(); + event.record(&mut visitor); + + if let Some(event_type) = visitor.event { + if let Some(counter) = self.counters.counters.get(event_type) { + let increment_amount = visitor.amount.unwrap_or(1); + counter.fetch_add(increment_amount, Ordering::Relaxed); + } else { + panic!("Invalid event type specified: {:?}", event_type); + } + } + } +} + +#[derive(Default)] +struct EventVisitor { + // which event was encountered + event: Option, + + // how much to increment the associated counter + amount: Option, +} + +impl tracing::field::Visit for EventVisitor { + fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { + match field.name() { + "event_type" => { + self.event = Some(value as usize); + } + "increment_amount" => { + self.amount = Some(value as usize); + } + _ => {} + } + } + + fn record_debug(&mut self, _field: &tracing::field::Field, _value: &dyn std::fmt::Debug) {} +} diff --git a/iris-mpc-cpu/src/hnsw/mod.rs b/iris-mpc-cpu/src/hnsw/mod.rs index 8cc4cecef..9a6c320ef 100644 --- a/iris-mpc-cpu/src/hnsw/mod.rs +++ b/iris-mpc-cpu/src/hnsw/mod.rs @@ -1,3 +1,4 @@ pub mod searcher; +pub mod metrics; pub use searcher::HnswSearcher; diff --git a/iris-mpc-cpu/src/hnsw/searcher.rs b/iris-mpc-cpu/src/hnsw/searcher.rs index 253803639..9bf36c8e0 100644 --- a/iris-mpc-cpu/src/hnsw/searcher.rs +++ b/iris-mpc-cpu/src/hnsw/searcher.rs @@ -4,7 +4,7 @@ //* //* https://github.com/Inversed-Tech/hawk-pack/ -use ::tracing::info; +use tracing::{info, instrument}; pub use hawk_pack::data_structures::queue::{ FurthestQueue, FurthestQueueV, NearestQueue, NearestQueueV, }; @@ -14,6 +14,8 @@ use rand_distr::{Distribution, Geometric}; use serde::{Deserialize, Serialize}; use std::collections::HashSet; +use super::metrics; + // specify construction and search parameters by layer up to this value minus 1 // any higher layers will use the last set of parameters pub const N_PARAM_LAYERS: usize = 5; @@ -217,6 +219,7 @@ impl HnswSearcher { /// given layer using depth-first graph traversal, Terminates when `W` /// contains vectors which are the nearest to `q` among all traversed /// vertices and their neighbors. + #[instrument(skip(self, vector_store, graph_store))] #[allow(non_snake_case)] async fn search_layer>( &self, @@ -227,7 +230,7 @@ impl HnswSearcher { ef: usize, lc: usize, ) { - info!(event_type = tracing::LAYER_SEARCH_EVENT); + info!(event_type = metrics::LAYER_SEARCH_EVENT); // v: The set of already visited vectors. let mut v = HashSet::::from_iter(W.iter().map(|(e, _eq)| e.clone())); @@ -248,7 +251,7 @@ impl HnswSearcher { } // Open the node c and explore its neighbors. - info!(event_type = tracing::OPEN_NODE_EVENT); + info!(event_type = metrics::OPEN_NODE_EVENT); // Visit all neighbors of c. let c_links = graph_store.get_links(&c, lc).await; @@ -441,71 +444,6 @@ impl HnswSearcher { } } -pub mod tracing { - use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }; - use tracing::{Event, Subscriber}; - use tracing_subscriber::{layer::Context, Layer}; - - pub const LAYER_SEARCH_EVENT: u64 = 0; - pub const OPEN_NODE_EVENT: u64 = 1; - pub const EVAL_DIST_EVENT: u64 = 2; - pub const COMPARE_DIST_EVENT: u64 = 3; - - const NUM_EVENT_TYPES: usize = 4; - - #[derive(Default)] - pub struct EventCounter { - pub counters: [AtomicUsize; NUM_EVENT_TYPES], - } - - pub struct HnswEventCounterLayer { - pub counters: Arc, - } - - impl Layer for HnswEventCounterLayer { - fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) { - let mut visitor = EventVisitor::default(); - event.record(&mut visitor); - - if let Some(event_type) = visitor.event { - if let Some(counter) = self.counters.counters.get(event_type) { - let increment_amount = visitor.amount.unwrap_or(1); - counter.fetch_add(increment_amount, Ordering::Relaxed); - } else { - panic!("Invalid event type specified: {:?}", event_type); - } - } - } - } - - #[derive(Default)] - struct EventVisitor { - // which event was encountered - event: Option, - - // how much to increment the associated counter - amount: Option, - } - - impl tracing::field::Visit for EventVisitor { - fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { - match field.name() { - "event_type" => { - self.event = Some(value as usize); - } - "increment_amount" => { - self.amount = Some(value as usize); - } - _ => {} - } - } - - fn record_debug(&mut self, _field: &tracing::field::Field, _value: &dyn std::fmt::Debug) {} - } -} #[cfg(test)] mod tests { From fcea2b3b7a97b171c85ceae8f3320b2db67f44a1 Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Mon, 16 Dec 2024 15:12:54 -0700 Subject: [PATCH 6/7] Rough implementation of more specific layer search metrics tracing --- iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs | 50 +++++++--- iris-mpc-cpu/src/hawkers/galois_store.rs | 2 +- iris-mpc-cpu/src/hnsw/metrics.rs | 105 ++++++++++++++++++--- iris-mpc-cpu/src/hnsw/searcher.rs | 28 +++++- 4 files changed, 155 insertions(+), 30 deletions(-) diff --git a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs index 46d1f3042..dd13ff310 100644 --- a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs +++ b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs @@ -5,15 +5,13 @@ use iris_mpc_common::iris_db::iris::IrisCode; use iris_mpc_cpu::{ hawkers::plaintext_store::PlaintextStore, hnsw::{ - searcher::{HnswParams, HnswSearcher}, metrics::{ - EventCounter, HnswEventCounterLayer, COMPARE_DIST_EVENT, EVAL_DIST_EVENT, - LAYER_SEARCH_EVENT, OPEN_NODE_EVENT, - }, + EventCounter, HnswEventCounterLayer, VertexOpeningsLayer, COMPARE_DIST_EVENT, EVAL_DIST_EVENT, LAYER_SEARCH_EVENT, OPEN_NODE_EVENT + }, searcher::{HnswParams, HnswSearcher} } }; use rand::SeedableRng; -use std::{error::Error, sync::Arc}; +use std::{collections::HashMap, error::Error, sync::{Arc, Mutex}}; use tracing_subscriber::prelude::*; #[derive(Parser)] @@ -27,6 +25,7 @@ struct Args { ef_search: usize, #[clap(default_value = "10000")] database_size: usize, + layer_probability: Option, } #[tokio::main] @@ -37,15 +36,20 @@ async fn main() -> Result<(), Box> { let ef_constr = args.ef_constr; let ef_search = args.ef_search; let database_size = args.database_size; + let layer_probability = args.layer_probability; - let counters = configure_tracing(); + let (counters, counter_map) = configure_tracing(); - let mut rng = AesRng::seed_from_u64(0_u64); + let mut rng = AesRng::seed_from_u64(42_u64); + // let mut rng = rand::thread_rng(); let mut vector = PlaintextStore::default(); let mut graph = GraphMem::new(); - let searcher = HnswSearcher { - params: HnswParams::new(M, ef_constr, ef_search), + let params = if let Some(p) = layer_probability { + HnswParams::new_with_layer_probability(ef_constr, ef_search, M, p) + } else { + HnswParams::new(ef_constr, ef_search, M) }; + let searcher = HnswSearcher { params }; for idx in 0..database_size { let raw_query = IrisCode::random_rng(&mut rng); @@ -63,6 +67,11 @@ async fn main() -> Result<(), Box> { println!("Final counts:"); print_stats(&counters, true); + println!("Layer search counts:"); + for ((lc, ef), value) in counter_map.lock().unwrap().iter() { + println!(" lc={lc},ef={ef}: {value}"); + } + Ok(()) } @@ -85,14 +94,27 @@ fn print_stats(counters: &Arc, verbose: bool) { } } -fn configure_tracing() -> Arc { +fn configure_tracing() -> (Arc, Arc>>) { let counters = Arc::new(EventCounter::default()); - let layer = HnswEventCounterLayer { + let counting_layer = HnswEventCounterLayer { counters: counters.clone(), }; - // Set up how `tracing-subscriber` will deal with tracing data. - tracing_subscriber::registry().with(layer).init(); - counters + let counter_map: Arc>> + = Arc::new(Mutex::new(HashMap::new())); + + let vertex_openings_layer = VertexOpeningsLayer { + counter_map: counter_map.clone() + }; + + tracing_subscriber::registry() + .with(counting_layer) + .with(vertex_openings_layer) + .init(); + + // tracing_subscriber::fmt() + // .init(); + + (counters, counter_map) } diff --git a/iris-mpc-cpu/src/hawkers/galois_store.rs b/iris-mpc-cpu/src/hawkers/galois_store.rs index 4f01c96b4..497beaa9f 100644 --- a/iris-mpc-cpu/src/hawkers/galois_store.rs +++ b/iris-mpc-cpu/src/hawkers/galois_store.rs @@ -587,7 +587,7 @@ impl LocalNetAby3NgStoreProtocol { mod tests { use super::*; use crate::{ - database_generators::generate_galois_iris_shares, hawkers::iris_searcher::HnswSearcher, + database_generators::generate_galois_iris_shares, hnsw::HnswSearcher, }; use aes_prng::AesRng; use hawk_pack::graph_store::GraphMem; diff --git a/iris-mpc-cpu/src/hnsw/metrics.rs b/iris-mpc-cpu/src/hnsw/metrics.rs index 5534b5958..cd829cefd 100644 --- a/iris-mpc-cpu/src/hnsw/metrics.rs +++ b/iris-mpc-cpu/src/hnsw/metrics.rs @@ -1,9 +1,15 @@ -use std::sync::{ +use std::{collections::HashMap, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, + Arc, Mutex, +}}; +use tracing::{field::{Field, Visit}, Event, Id, Subscriber}; +use tracing_subscriber::{ + layer::{Context, Layer}, }; -use tracing::{Event, Subscriber}; -use tracing_subscriber::{layer::Context, Layer}; +use std::fmt::Debug; +use tracing_subscriber::registry::LookupSpan; + +use tracing::span::Attributes; pub const LAYER_SEARCH_EVENT: u64 = 0; pub const OPEN_NODE_EVENT: u64 = 1; @@ -21,15 +27,17 @@ pub struct HnswEventCounterLayer { pub counters: Arc, } -impl Layer for HnswEventCounterLayer { +impl Layer for HnswEventCounterLayer + where S: Subscriber + for <'a> LookupSpan<'a> +{ fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) { let mut visitor = EventVisitor::default(); event.record(&mut visitor); if let Some(event_type) = visitor.event { - if let Some(counter) = self.counters.counters.get(event_type) { + if let Some(counter) = self.counters.counters.get(event_type as usize) { let increment_amount = visitor.amount.unwrap_or(1); - counter.fetch_add(increment_amount, Ordering::Relaxed); + counter.fetch_add(increment_amount as usize, Ordering::Relaxed); } else { panic!("Invalid event type specified: {:?}", event_type); } @@ -40,24 +48,93 @@ impl Layer for HnswEventCounterLayer { #[derive(Default)] struct EventVisitor { // which event was encountered - event: Option, + event: Option, // how much to increment the associated counter - amount: Option, + amount: Option, } -impl tracing::field::Visit for EventVisitor { - fn record_u64(&mut self, field: &tracing::field::Field, value: u64) { +impl Visit for EventVisitor { + fn record_u64(&mut self, field: &Field, value: u64) { match field.name() { "event_type" => { - self.event = Some(value as usize); + self.event = Some(value); } "increment_amount" => { - self.amount = Some(value as usize); + self.amount = Some(value); } _ => {} } } - fn record_debug(&mut self, _field: &tracing::field::Field, _value: &dyn std::fmt::Debug) {} + fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {} +} + + +/// Tracing library Layer for counting detailed HNSW layer search operations +pub struct VertexOpeningsLayer { + // Measure number of vertex openings for different lc and ef values + pub counter_map: Arc>>, +} + +impl Layer for VertexOpeningsLayer + where S: Subscriber + for <'a> LookupSpan<'a> +{ + // fn register_callsite(&self, _metadata: &'static Metadata<'static>) -> Interest { + // Interest::sometimes() + // } + + // fn enabled(&self, metadata: &Metadata<'_>, _ctx: Context<'_, S>) -> bool { + // let is_search_layer_span = metadata.is_span() && metadata.name() == "search_layer"; + // is_search_layer_span || metadata.is_event() + // } + + fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) { + let span = ctx.span(id).unwrap(); + let mut visitor = LayerSearchFields::default(); + attrs.record(&mut visitor); + span.extensions_mut().insert(visitor); + } + + fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) { + let mut visitor = EventVisitor::default(); + event.record(&mut visitor); + + if let Some(OPEN_NODE_EVENT) = visitor.event { + // open node event must have parent span representing open node function + let current_span = ctx.current_span(); + let span_id = current_span.id().unwrap(); + if let Some(LayerSearchFields { lc: Some(lc), ef: Some(ef) }) + = ctx.span(span_id).unwrap().extensions().get::() + { + let mut counter_map = self.counter_map.lock().unwrap(); + let increment_amount = visitor.amount.unwrap_or(1); + *counter_map.entry((*lc as usize, *ef as usize)).or_insert(0usize) + += increment_amount as usize; + } else { + panic!("Open node event is missing associated span fields"); + } + } + } +} + +#[derive(Default)] +pub struct LayerSearchFields { + lc: Option, + ef: Option, +} + +impl Visit for LayerSearchFields { + fn record_u64(&mut self, field: &Field, value: u64) { + match field.name() { + "lc" => { + self.lc = Some(value); + } + "ef" => { + self.ef = Some(value); + } + _ => {} + } + } + fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {} } diff --git a/iris-mpc-cpu/src/hnsw/searcher.rs b/iris-mpc-cpu/src/hnsw/searcher.rs index 9bf36c8e0..25626ecdc 100644 --- a/iris-mpc-cpu/src/hnsw/searcher.rs +++ b/iris-mpc-cpu/src/hnsw/searcher.rs @@ -64,6 +64,32 @@ impl HnswParams { } } + /// Same as standard constructor but with an extra input for a non-standard + /// `layer_probability` parameter. + pub fn new_with_layer_probability( + ef_construction: usize, + ef_search: usize, + M: usize, + layer_probability: f64) -> Self + { + let M_arr = [M; N_PARAM_LAYERS]; + let mut M_max_arr = [M; N_PARAM_LAYERS]; + M_max_arr[0] = 2 * M; + let ef_constr_search_arr = [1usize; N_PARAM_LAYERS]; + let ef_constr_insert_arr = [ef_construction; N_PARAM_LAYERS]; + let mut ef_search_arr = [1usize; N_PARAM_LAYERS]; + ef_search_arr[0] = ef_search; + + Self { + M: M_arr, + M_max: M_max_arr, + ef_constr_search: ef_constr_search_arr, + ef_constr_insert: ef_constr_insert_arr, + ef_search: ef_search_arr, + layer_probability, + } + } + /// Parameter configuration using fixed exploration factor for all layer /// search operations, both for insertion and for search. pub fn new_uniform(ef: usize, M: usize) -> Self { @@ -219,7 +245,7 @@ impl HnswSearcher { /// given layer using depth-first graph traversal, Terminates when `W` /// contains vectors which are the nearest to `q` among all traversed /// vertices and their neighbors. - #[instrument(skip(self, vector_store, graph_store))] + #[instrument(skip(self, vector_store, graph_store, W))] #[allow(non_snake_case)] async fn search_layer>( &self, From 147f7b5b2bde62a4894d83c49bf86e23ce3f8e6c Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Thu, 9 Jan 2025 14:35:29 -0700 Subject: [PATCH 7/7] Formatting --- iris-mpc-cpu/Cargo.toml | 2 +- iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs | 33 ++++++++----- iris-mpc-cpu/src/hawkers/aby3_store.rs | 4 +- iris-mpc-cpu/src/hawkers/plaintext_store.rs | 2 +- iris-mpc-cpu/src/hnsw/metrics.rs | 53 +++++++++++++-------- iris-mpc-cpu/src/hnsw/mod.rs | 2 +- iris-mpc-cpu/src/hnsw/searcher.rs | 10 ++-- iris-mpc-cpu/src/network/grpc.rs | 2 +- 8 files changed, 63 insertions(+), 45 deletions(-) diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml index 6389a7c54..c0e862807 100644 --- a/iris-mpc-cpu/Cargo.toml +++ b/iris-mpc-cpu/Cargo.toml @@ -26,6 +26,7 @@ itertools.workspace = true num-traits.workspace = true prost = "0.13" rand.workspace = true +rand_distr = "0.4.3" rstest = "0.23.0" serde.workspace = true serde_json.workspace = true @@ -37,7 +38,6 @@ tracing.workspace = true tracing-subscriber.workspace = true tracing-test = "0.2.5" uuid.workspace = true -rand_distr = "0.4.3" [dev-dependencies] criterion = { version = "0.5.1", features = ["async_tokio"] } diff --git a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs index dd13ff310..8e43a9c4f 100644 --- a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs +++ b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs @@ -6,25 +6,31 @@ use iris_mpc_cpu::{ hawkers::plaintext_store::PlaintextStore, hnsw::{ metrics::{ - EventCounter, HnswEventCounterLayer, VertexOpeningsLayer, COMPARE_DIST_EVENT, EVAL_DIST_EVENT, LAYER_SEARCH_EVENT, OPEN_NODE_EVENT - }, searcher::{HnswParams, HnswSearcher} - } + EventCounter, HnswEventCounterLayer, VertexOpeningsLayer, COMPARE_DIST_EVENT, + EVAL_DIST_EVENT, LAYER_SEARCH_EVENT, OPEN_NODE_EVENT, + }, + searcher::{HnswParams, HnswSearcher}, + }, }; use rand::SeedableRng; -use std::{collections::HashMap, error::Error, sync::{Arc, Mutex}}; +use std::{ + collections::HashMap, + error::Error, + sync::{Arc, Mutex}, +}; use tracing_subscriber::prelude::*; #[derive(Parser)] #[allow(non_snake_case)] struct Args { #[clap(default_value = "64")] - M: usize, + M: usize, #[clap(default_value = "128")] - ef_constr: usize, + ef_constr: usize, #[clap(default_value = "64")] - ef_search: usize, + ef_search: usize, #[clap(default_value = "10000")] - database_size: usize, + database_size: usize, layer_probability: Option, } @@ -94,18 +100,21 @@ fn print_stats(counters: &Arc, verbose: bool) { } } -fn configure_tracing() -> (Arc, Arc>>) { +fn configure_tracing() -> ( + Arc, + Arc>>, +) { let counters = Arc::new(EventCounter::default()); let counting_layer = HnswEventCounterLayer { counters: counters.clone(), }; - let counter_map: Arc>> - = Arc::new(Mutex::new(HashMap::new())); + let counter_map: Arc>> = + Arc::new(Mutex::new(HashMap::new())); let vertex_openings_layer = VertexOpeningsLayer { - counter_map: counter_map.clone() + counter_map: counter_map.clone(), }; tracing_subscriber::registry() diff --git a/iris-mpc-cpu/src/hawkers/aby3_store.rs b/iris-mpc-cpu/src/hawkers/aby3_store.rs index 20c82d800..b859a4f91 100644 --- a/iris-mpc-cpu/src/hawkers/aby3_store.rs +++ b/iris-mpc-cpu/src/hawkers/aby3_store.rs @@ -603,9 +603,7 @@ impl Aby3Store { #[cfg(test)] mod tests { use super::*; - use crate::{ - database_generators::generate_galois_iris_shares, hnsw::HnswSearcher, - }; + use crate::{database_generators::generate_galois_iris_shares, hnsw::HnswSearcher}; use aes_prng::AesRng; use hawk_pack::graph_store::GraphMem; use itertools::Itertools; diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs index fa04b12f2..e03dc3e1b 100644 --- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs +++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs @@ -1,6 +1,6 @@ use crate::hnsw::{ - HnswSearcher, metrics::{COMPARE_DIST_EVENT, EVAL_DIST_EVENT}, + HnswSearcher, }; use aes_prng::AesRng; use hawk_pack::{graph_store::GraphMem, VectorStore}; diff --git a/iris-mpc-cpu/src/hnsw/metrics.rs b/iris-mpc-cpu/src/hnsw/metrics.rs index cd829cefd..2c8cb4f86 100644 --- a/iris-mpc-cpu/src/hnsw/metrics.rs +++ b/iris-mpc-cpu/src/hnsw/metrics.rs @@ -1,15 +1,20 @@ -use std::{collections::HashMap, sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex, -}}; -use tracing::{field::{Field, Visit}, Event, Id, Subscriber}; +use std::{ + collections::HashMap, + fmt::Debug, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, +}; +use tracing::{ + field::{Field, Visit}, + span::Attributes, + Event, Id, Subscriber, +}; use tracing_subscriber::{ layer::{Context, Layer}, + registry::LookupSpan, }; -use std::fmt::Debug; -use tracing_subscriber::registry::LookupSpan; - -use tracing::span::Attributes; pub const LAYER_SEARCH_EVENT: u64 = 0; pub const OPEN_NODE_EVENT: u64 = 1; @@ -28,7 +33,8 @@ pub struct HnswEventCounterLayer { } impl Layer for HnswEventCounterLayer - where S: Subscriber + for <'a> LookupSpan<'a> +where + S: Subscriber + for<'a> LookupSpan<'a>, { fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) { let mut visitor = EventVisitor::default(); @@ -70,7 +76,6 @@ impl Visit for EventVisitor { fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {} } - /// Tracing library Layer for counting detailed HNSW layer search operations pub struct VertexOpeningsLayer { // Measure number of vertex openings for different lc and ef values @@ -78,15 +83,16 @@ pub struct VertexOpeningsLayer { } impl Layer for VertexOpeningsLayer - where S: Subscriber + for <'a> LookupSpan<'a> +where + S: Subscriber + for<'a> LookupSpan<'a>, { - // fn register_callsite(&self, _metadata: &'static Metadata<'static>) -> Interest { - // Interest::sometimes() + // fn register_callsite(&self, _metadata: &'static Metadata<'static>) -> + // Interest { Interest::sometimes() // } // fn enabled(&self, metadata: &Metadata<'_>, _ctx: Context<'_, S>) -> bool { - // let is_search_layer_span = metadata.is_span() && metadata.name() == "search_layer"; - // is_search_layer_span || metadata.is_event() + // let is_search_layer_span = metadata.is_span() && metadata.name() == + // "search_layer"; is_search_layer_span || metadata.is_event() // } fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) { @@ -104,13 +110,20 @@ impl Layer for VertexOpeningsLayer // open node event must have parent span representing open node function let current_span = ctx.current_span(); let span_id = current_span.id().unwrap(); - if let Some(LayerSearchFields { lc: Some(lc), ef: Some(ef) }) - = ctx.span(span_id).unwrap().extensions().get::() + if let Some(LayerSearchFields { + lc: Some(lc), + ef: Some(ef), + }) = ctx + .span(span_id) + .unwrap() + .extensions() + .get::() { let mut counter_map = self.counter_map.lock().unwrap(); let increment_amount = visitor.amount.unwrap_or(1); - *counter_map.entry((*lc as usize, *ef as usize)).or_insert(0usize) - += increment_amount as usize; + *counter_map + .entry((*lc as usize, *ef as usize)) + .or_insert(0usize) += increment_amount as usize; } else { panic!("Open node event is missing associated span fields"); } diff --git a/iris-mpc-cpu/src/hnsw/mod.rs b/iris-mpc-cpu/src/hnsw/mod.rs index 9a6c320ef..edcae0c1c 100644 --- a/iris-mpc-cpu/src/hnsw/mod.rs +++ b/iris-mpc-cpu/src/hnsw/mod.rs @@ -1,4 +1,4 @@ -pub mod searcher; pub mod metrics; +pub mod searcher; pub use searcher::HnswSearcher; diff --git a/iris-mpc-cpu/src/hnsw/searcher.rs b/iris-mpc-cpu/src/hnsw/searcher.rs index 25626ecdc..2101604b9 100644 --- a/iris-mpc-cpu/src/hnsw/searcher.rs +++ b/iris-mpc-cpu/src/hnsw/searcher.rs @@ -4,7 +4,7 @@ //* //* https://github.com/Inversed-Tech/hawk-pack/ -use tracing::{info, instrument}; +use super::metrics; pub use hawk_pack::data_structures::queue::{ FurthestQueue, FurthestQueueV, NearestQueue, NearestQueueV, }; @@ -13,8 +13,7 @@ use rand::RngCore; use rand_distr::{Distribution, Geometric}; use serde::{Deserialize, Serialize}; use std::collections::HashSet; - -use super::metrics; +use tracing::{info, instrument}; // specify construction and search parameters by layer up to this value minus 1 // any higher layers will use the last set of parameters @@ -70,8 +69,8 @@ impl HnswParams { ef_construction: usize, ef_search: usize, M: usize, - layer_probability: f64) -> Self - { + layer_probability: f64, + ) -> Self { let M_arr = [M; N_PARAM_LAYERS]; let mut M_max_arr = [M; N_PARAM_LAYERS]; M_max_arr[0] = 2 * M; @@ -470,7 +469,6 @@ impl HnswSearcher { } } - #[cfg(test)] mod tests { use super::*; diff --git a/iris-mpc-cpu/src/network/grpc.rs b/iris-mpc-cpu/src/network/grpc.rs index 74db712b4..1e1a0542a 100644 --- a/iris-mpc-cpu/src/network/grpc.rs +++ b/iris-mpc-cpu/src/network/grpc.rs @@ -340,8 +340,8 @@ mod tests { use super::*; use crate::{ execution::{local::generate_local_identities, player::Role}, - hnsw::HnswSearcher, hawkers::aby3_store::Aby3Store, + hnsw::HnswSearcher, }; use aes_prng::AesRng; use rand::SeedableRng;