diff --git a/iris-mpc-cpu/Cargo.toml b/iris-mpc-cpu/Cargo.toml
index dcfdceb1e..c0e862807 100644
--- a/iris-mpc-cpu/Cargo.toml
+++ b/iris-mpc-cpu/Cargo.toml
@@ -56,6 +56,10 @@ name = "hnsw-ex"
 name = "local_hnsw"
 path = "bin/local_hnsw.rs"
 
+[[bin]]
+name = "hnsw_algorithm_metrics"
+path = "bin/hnsw_algorithm_metrics.rs"
+
 [[bin]]
 name = "generate_benchmark_data"
 path = "bin/generate_benchmark_data.rs"
diff --git a/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs
new file mode 100644
index 000000000..8e43a9c4f
--- /dev/null
+++ b/iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs
@@ -0,0 +1,151 @@
+//! Builds an in-memory HNSW graph over random iris codes held in a
+//! `PlaintextStore` and prints the event counts gathered by the tracing layers
+//! defined in `iris_mpc_cpu::hnsw::metrics`.
+
+use aes_prng::AesRng;
+use clap::Parser;
+use hawk_pack::graph_store::GraphMem;
+use iris_mpc_common::iris_db::iris::IrisCode;
+use iris_mpc_cpu::{
+    hawkers::plaintext_store::PlaintextStore,
+    hnsw::{
+        metrics::{
+            EventCounter, HnswEventCounterLayer, VertexOpeningsLayer, COMPARE_DIST_EVENT,
+            EVAL_DIST_EVENT, LAYER_SEARCH_EVENT, OPEN_NODE_EVENT,
+        },
+        searcher::{HnswParams, HnswSearcher},
+    },
+};
+use rand::SeedableRng;
+use std::{
+    collections::HashMap,
+    error::Error,
+    sync::{Arc, Mutex},
+};
+use tracing_subscriber::prelude::*;
+
+/// Command-line arguments of the metrics tool.
+#[derive(Parser)]
+#[allow(non_snake_case)]
+struct Args {
+    /// `M` value forwarded to the `HnswParams` constructor.
+    #[clap(default_value = "64")]
+    M: usize,
+    /// `ef_constr` value forwarded to the `HnswParams` constructor.
+    #[clap(default_value = "128")]
+    ef_constr: usize,
+    /// `ef_search` value forwarded to the `HnswParams` constructor.
+    #[clap(default_value = "64")]
+    ef_search: usize,
+    /// Number of random iris codes to insert.
+    #[clap(default_value = "10000")]
+    database_size: usize,
+    /// When given, parameters are built with
+    /// `HnswParams::new_with_layer_probability` instead of `HnswParams::new`.
+    layer_probability: Option<f64>,
+}
+
+/// Inserts `database_size` random iris codes into a fresh graph, printing a
+/// progress row after every 1000th insertion and the final counts at the end.
+#[tokio::main]
+#[allow(non_snake_case)]
+async fn main() -> Result<(), Box<dyn Error>> {
+    let args = Args::parse();
+    let M = args.M;
+    let ef_constr = args.ef_constr;
+    let ef_search = args.ef_search;
+    let database_size = args.database_size;
+    let layer_probability = args.layer_probability;
+
+    let (counters, counter_map) = configure_tracing();
+
+    let mut rng = AesRng::seed_from_u64(42_u64);
+    let mut vector = PlaintextStore::default();
+    let mut graph = GraphMem::new();
+    let params = match layer_probability {
+        Some(p) => HnswParams::new_with_layer_probability(ef_constr, ef_search, M, p),
+        None => HnswParams::new(ef_constr, ef_search, M),
+    };
+    let searcher = HnswSearcher { params };
+
+    for idx in 0..database_size {
+        let raw_query = IrisCode::random_rng(&mut rng);
+        // The code is not used again, so move it instead of cloning it.
+        let query = vector.prepare_query(raw_query);
+        searcher
+            .insert(&mut vector, &mut graph, &query, &mut rng)
+            .await;
+
+        if idx % 1000 == 999 {
+            print!("{}, ", idx + 1);
+            print_stats(&counters, false);
+        }
+    }
+
+    println!("Final counts:");
+    print_stats(&counters, true);
+
+    println!("Layer search counts:");
+    let counter_map = counter_map.lock().unwrap();
+    // `HashMap` iteration order is unspecified; sort so the report is stable.
+    let mut layer_counts: Vec<_> = counter_map.iter().collect();
+    layer_counts.sort_unstable();
+    for ((lc, ef), value) in layer_counts {
+        println!(" lc={lc},ef={ef}: {value}");
+    }
+
+    Ok(())
+}
+
+/// Prints the current counter values.
+///
+/// With `verbose` set, each of the four counters is printed on its own
+/// labelled line; otherwise the open-node, distance-evaluation and
+/// distance-comparison counters are printed as one comma-separated row.
+fn print_stats(counters: &Arc<EventCounter>, verbose: bool) {
+    let slots = &counters.counters;
+    let layer_searches = &slots[LAYER_SEARCH_EVENT as usize];
+    let opened_nodes = &slots[OPEN_NODE_EVENT as usize];
+    let distance_evals = &slots[EVAL_DIST_EVENT as usize];
+    let distance_comps = &slots[COMPARE_DIST_EVENT as usize];
+
+    if verbose {
+        println!(" Layer search events: {:?}", layer_searches);
+        println!(" Open node events: {:?}", opened_nodes);
+        println!(" Evaluate distance events: {:?}", distance_evals);
+        println!(" Compare distance events: {:?}", distance_comps);
+    } else {
+        println!(
+            "{:?}, {:?}, {:?}",
+            opened_nodes, distance_evals, distance_comps
+        );
+    }
+}
+
+/// Installs the global tracing subscriber and returns handles to its counters.
+///
+/// The first handle is the [`EventCounter`] shared with
+/// [`HnswEventCounterLayer`]; the second is the `(lc, ef)`-keyed map shared
+/// with [`VertexOpeningsLayer`].
+///
+/// # Panics
+///
+/// Panics if a global default subscriber has already been set.
+fn configure_tracing() -> (
+    Arc<EventCounter>,
+    Arc<Mutex<HashMap<(usize, usize), usize>>>,
+) {
+    let counters = Arc::new(EventCounter::default());
+    let counter_map = Arc::new(Mutex::new(HashMap::new()));
+
+    tracing_subscriber::registry()
+        .with(HnswEventCounterLayer {
+            counters: Arc::clone(&counters),
+        })
+        .with(VertexOpeningsLayer {
+            counter_map: Arc::clone(&counter_map),
+        })
+        .init();
+
+    (counters, counter_map)
+}
diff --git a/iris-mpc-cpu/src/hawkers/plaintext_store.rs b/iris-mpc-cpu/src/hawkers/plaintext_store.rs
index 7fcbcd10c..e03dc3e1b 100644
--- a/iris-mpc-cpu/src/hawkers/plaintext_store.rs
+++ b/iris-mpc-cpu/src/hawkers/plaintext_store.rs
@@ -1,4 +1,7 @@
-use crate::hnsw::HnswSearcher;
+use crate::hnsw::{
+    metrics::{COMPARE_DIST_EVENT, EVAL_DIST_EVENT},
+    HnswSearcher,
+};
 use aes_prng::AesRng;
 use hawk_pack::{graph_store::GraphMem, VectorStore};
 use iris_mpc_common::iris_db::{
@@ -117,6 +120,7 @@ impl VectorStore for PlaintextStore {
         query: &Self::QueryRef,
         vector: &Self::VectorRef,
     ) -> Self::DistanceRef {
+        tracing::info!(event_type = EVAL_DIST_EVENT);
         let query_code = &self.points[*query];
         let vector_code = &self.points[*vector];
         query_code.data.distance_fraction(&vector_code.data)
@@ -132,6 +136,7 @@ impl VectorStore for PlaintextStore {
         distance1: &Self::DistanceRef,
         distance2: &Self::DistanceRef,
     ) -> bool {
+        tracing::info!(event_type = COMPARE_DIST_EVENT);
         let (a, b) = *distance1; // a/b
         let (c, d) = *distance2; // c/d
         (a as i32) * (d as i32) - (b as i32) * (c as i32) < 0
diff --git a/iris-mpc-cpu/src/hnsw/metrics.rs b/iris-mpc-cpu/src/hnsw/metrics.rs
new file mode 100644
index 000000000..2c8cb4f86
--- /dev/null
+++ b/iris-mpc-cpu/src/hnsw/metrics.rs
@@ -0,0 +1,168 @@
+//! `tracing` layers that count HNSW search events.
+//!
+//! An event is counted when it carries a `u64` field named `event_type` set
+//! to one of the `*_EVENT` constants below. A `u64` field named
+//! `increment_amount` sets how much is added; when absent, 1 is added.
+
+use std::{
+    collections::HashMap,
+    fmt::Debug,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc, Mutex,
+    },
+};
+use tracing::{
+    field::{Field, Visit},
+    span::Attributes,
+    Event, Id, Subscriber,
+};
+use tracing_subscriber::{
+    layer::{Context, Layer},
+    registry::LookupSpan,
+};
+
+/// Event type for the start of a layer search.
+pub const LAYER_SEARCH_EVENT: u64 = 0;
+/// Event type for opening a node during a layer search.
+pub const OPEN_NODE_EVENT: u64 = 1;
+/// Event type for evaluating a distance.
+pub const EVAL_DIST_EVENT: u64 = 2;
+/// Event type for comparing two distances.
+pub const COMPARE_DIST_EVENT: u64 = 3;
+
+// Number of `*_EVENT` constants above; sizes the counter array.
+const NUM_EVENT_TYPES: usize = 4;
+
+/// One atomic counter per event type, indexed by the `*_EVENT` constants.
+#[derive(Default)]
+pub struct EventCounter {
+    pub counters: [AtomicUsize; NUM_EVENT_TYPES],
+}
+
+/// Tracing library Layer that adds each counted event to an [`EventCounter`].
+pub struct HnswEventCounterLayer {
+    pub counters: Arc<EventCounter>,
+}
+
+impl<S> Layer<S> for HnswEventCounterLayer
+where
+    S: Subscriber + for<'a> LookupSpan<'a>,
+{
+    /// Adds the event's increment to the counter selected by `event_type`.
+    ///
+    /// Events without an `event_type` field are ignored.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `event_type` is not below `NUM_EVENT_TYPES`.
+    fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) {
+        let mut visitor = EventVisitor::default();
+        event.record(&mut visitor);
+
+        if let Some(event_type) = visitor.event {
+            // Range-check as `u64` before narrowing: an `as usize` cast alone
+            // could wrap a large value onto a valid index on 32-bit targets.
+            if event_type >= NUM_EVENT_TYPES as u64 {
+                panic!("Invalid event type specified: {:?}", event_type);
+            }
+            let increment_amount = visitor.amount.unwrap_or(1);
+            self.counters.counters[event_type as usize]
+                .fetch_add(increment_amount as usize, Ordering::Relaxed);
+        }
+    }
+}
+
+/// Field visitor that picks out the `event_type` and `increment_amount` fields.
+#[derive(Default)]
+struct EventVisitor {
+    // which event was encountered
+    event: Option<u64>,
+
+    // how much to increment the associated counter
+    amount: Option<u64>,
+}
+
+impl Visit for EventVisitor {
+    fn record_u64(&mut self, field: &Field, value: u64) {
+        match field.name() {
+            "event_type" => self.event = Some(value),
+            "increment_amount" => self.amount = Some(value),
+            _ => {}
+        }
+    }
+
+    // Required by `Visit`; fields of other types are deliberately ignored.
+    fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {}
+}
+
+/// Tracing library Layer for counting detailed HNSW layer search operations
+pub struct VertexOpeningsLayer {
+    // Measure number of vertex openings for different lc and ef values
+    pub counter_map: Arc<Mutex<HashMap<(usize, usize), usize>>>,
+}
+
+impl<S> Layer<S> for VertexOpeningsLayer
+where
+    S: Subscriber + for<'a> LookupSpan<'a>,
+{
+    /// Stores the `lc` and `ef` fields of every new span (when it has them) in
+    /// the span's extensions, where `on_event` looks them up.
+    fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
+        let span = ctx.span(id).expect("new span is missing from the registry");
+        let mut visitor = LayerSearchFields::default();
+        attrs.record(&mut visitor);
+        span.extensions_mut().insert(visitor);
+    }
+
+    /// Adds each open-node event to the `(lc, ef)` entry of the current span.
+    ///
+    /// # Panics
+    ///
+    /// Panics if an open-node event is emitted outside of any span, or in a
+    /// span that did not record both `lc` and `ef`.
+    fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
+        let mut visitor = EventVisitor::default();
+        event.record(&mut visitor);
+
+        if visitor.event != Some(OPEN_NODE_EVENT) {
+            return;
+        }
+
+        // open node event must have parent span representing open node function
+        let span = ctx
+            .lookup_current()
+            .expect("Open node event was emitted outside of any span");
+        let key = match span.extensions().get::<LayerSearchFields>() {
+            Some(LayerSearchFields {
+                lc: Some(lc),
+                ef: Some(ef),
+            }) => (*lc as usize, *ef as usize),
+            _ => panic!("Open node event is missing associated span fields"),
+        };
+
+        let increment_amount = visitor.amount.unwrap_or(1);
+        let mut counter_map = self.counter_map.lock().unwrap();
+        *counter_map.entry(key).or_insert(0) += increment_amount as usize;
+    }
+}
+
+/// Values of the `lc` and `ef` fields recorded on a span, if any.
+#[derive(Default)]
+pub struct LayerSearchFields {
+    lc: Option<u64>,
+    ef: Option<u64>,
+}
+
+impl Visit for LayerSearchFields {
+    fn record_u64(&mut self, field: &Field, value: u64) {
+        match field.name() {
+            "lc" => self.lc = Some(value),
+            "ef" => self.ef = Some(value),
+            _ => {}
+        }
+    }
+
+    // Required by `Visit`; fields of other types are deliberately ignored.
+    fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {}
+}
diff --git a/iris-mpc-cpu/src/hnsw/mod.rs b/iris-mpc-cpu/src/hnsw/mod.rs
index 8cc4cecef..edcae0c1c 100644
--- a/iris-mpc-cpu/src/hnsw/mod.rs
+++ b/iris-mpc-cpu/src/hnsw/mod.rs
@@ -1,3 +1,4 @@
+pub mod metrics;
 pub mod searcher;
 
 pub use searcher::HnswSearcher;
diff --git a/iris-mpc-cpu/src/hnsw/searcher.rs b/iris-mpc-cpu/src/hnsw/searcher.rs
index 1dd72a1b5..b3041e618 100644
--- a/iris-mpc-cpu/src/hnsw/searcher.rs
+++ b/iris-mpc-cpu/src/hnsw/searcher.rs
@@ -4,6 +4,7 @@
 //*
 //* https://github.com/Inversed-Tech/hawk-pack/
 
+use super::metrics;
 pub use hawk_pack::data_structures::queue::{
     FurthestQueue, FurthestQueueV, NearestQueue, NearestQueueV,
 };
@@ -12,6 +13,7 @@ use rand::RngCore;
 use rand_distr::{Distribution, Geometric};
 use serde::{Deserialize, Serialize};
 use std::collections::HashSet;
+use tracing::{info, instrument};
 
 // Specify construction and search parameters by layer up to this value minus 1
 // any higher layers will use the last set of parameters
@@ -242,6 +244,7 @@ impl HnswSearcher {
     /// given layer using depth-first graph traversal, Terminates when `W`
     /// contains vectors which are the nearest to `q` among all traversed
     /// vertices and their neighbors.
+    #[instrument(skip(self, vector_store, graph_store, W))]
     #[allow(non_snake_case)]
     async fn search_layer<V: VectorStore, G: GraphStore<V>>(
         &self,
@@ -252,6 +255,8 @@ impl HnswSearcher {
         ef: usize,
         lc: usize,
     ) {
+        info!(event_type = metrics::LAYER_SEARCH_EVENT);
+
         // v: The set of already visited vectors.
         let mut v = HashSet::<V::VectorRef>::from_iter(W.iter().map(|(e, _eq)| e.clone()));
 
@@ -271,6 +276,7 @@ impl HnswSearcher {
             }
 
             // Open the node c and explore its neighbors.
+            info!(event_type = metrics::OPEN_NODE_EVENT);
 
             // Visit all neighbors of c.
             let c_links = graph_store.get_links(&c, lc).await;