Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Initial HNSW algorithm metrics using tracing library #911

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions iris-mpc-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ name = "hnsw-ex"
name = "local_hnsw"
path = "bin/local_hnsw.rs"

[[bin]]
name = "hnsw_algorithm_metrics"
path = "bin/hnsw_algorithm_metrics.rs"

[[bin]]
name = "generate_benchmark_data"
path = "bin/generate_benchmark_data.rs"
129 changes: 129 additions & 0 deletions iris-mpc-cpu/bin/hnsw_algorithm_metrics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
use aes_prng::AesRng;
use clap::Parser;
use hawk_pack::graph_store::GraphMem;
use iris_mpc_common::iris_db::iris::IrisCode;
use iris_mpc_cpu::{
hawkers::plaintext_store::PlaintextStore,
hnsw::{
metrics::{
EventCounter, HnswEventCounterLayer, VertexOpeningsLayer, COMPARE_DIST_EVENT,
EVAL_DIST_EVENT, LAYER_SEARCH_EVENT, OPEN_NODE_EVENT,
},
searcher::{HnswParams, HnswSearcher},
},
};
use rand::SeedableRng;
use std::{
collections::HashMap,
error::Error,
sync::{Arc, Mutex},
};
use tracing_subscriber::prelude::*;

// Command-line arguments for the HNSW algorithm metrics binary.
// NOTE: regular `//` comments are used (not `///`) so clap's generated
// --help output is unchanged; doc comments on derive(Parser) fields
// become help text.
#[derive(Parser)]
#[allow(non_snake_case)]
struct Args {
    // HNSW M parameter: max neighbors per vertex in the graph.
    #[clap(default_value = "64")]
    M: usize,
    // ef parameter used during index construction (insertion searches).
    #[clap(default_value = "128")]
    ef_constr: usize,
    // ef parameter used during queries.
    #[clap(default_value = "64")]
    ef_search: usize,
    // Number of random iris codes to insert into the index.
    #[clap(default_value = "10000")]
    database_size: usize,
    // Optional geometric layer-assignment probability; when absent the
    // default from HnswParams::new is used.
    layer_probability: Option<f64>,
}

#[tokio::main]
#[allow(non_snake_case)]
async fn main() -> Result<(), Box<dyn Error>> {
    use std::io::Write;

    let args = Args::parse();
    let M = args.M;
    let ef_constr = args.ef_constr;
    let ef_search = args.ef_search;
    let database_size = args.database_size;
    let layer_probability = args.layer_probability;

    // Install the metrics-collecting tracing layers and keep shared
    // handles so counters can be read while insertions run.
    let (counters, counter_map) = configure_tracing();

    // Fixed seed: runs are reproducible so metrics are comparable across
    // parameter settings.
    let mut rng = AesRng::seed_from_u64(42_u64);
    let mut vector = PlaintextStore::default();
    let mut graph = GraphMem::new();
    let params = if let Some(p) = layer_probability {
        HnswParams::new_with_layer_probability(ef_constr, ef_search, M, p)
    } else {
        HnswParams::new(ef_constr, ef_search, M)
    };
    let searcher = HnswSearcher { params };

    for idx in 0..database_size {
        let raw_query = IrisCode::random_rng(&mut rng);
        let query = vector.prepare_query(raw_query.clone());
        searcher
            .insert(&mut vector, &mut graph, &query, &mut rng)
            .await;

        // Emit a compact progress row every 1000 insertions.
        if idx % 1000 == 999 {
            print!("{}, ", idx + 1);
            print_stats(&counters, false);
            // `print!` does not flush stdout; flush explicitly so progress
            // rows appear promptly instead of sitting in the buffer.
            std::io::stdout().flush()?;
        }
    }

    println!("Final counts:");
    print_stats(&counters, true);

    println!("Layer search counts:");
    for ((lc, ef), value) in counter_map.lock().unwrap().iter() {
        println!(" lc={lc},ef={ef}: {value}");
    }

    Ok(())
}

/// Print the current values of the shared HNSW event counters.
///
/// With `verbose = true`, prints one labeled line per counter; otherwise
/// prints a compact `opened, evaluated, compared` triple suitable for
/// CSV-style progress rows (layer search counts are omitted there).
fn print_stats(counters: &Arc<EventCounter>, verbose: bool) {
    use std::sync::atomic::Ordering;

    // The *_EVENT constants are all < NUM_EVENT_TYPES, so direct indexing
    // cannot panic; this replaces `.get(..).unwrap()`. Loading the values
    // also prints plain integers instead of Debug-formatting atomics.
    let layer_searches = counters.counters[LAYER_SEARCH_EVENT as usize].load(Ordering::Relaxed);
    let opened_nodes = counters.counters[OPEN_NODE_EVENT as usize].load(Ordering::Relaxed);
    let distance_evals = counters.counters[EVAL_DIST_EVENT as usize].load(Ordering::Relaxed);
    let distance_comps = counters.counters[COMPARE_DIST_EVENT as usize].load(Ordering::Relaxed);

    if verbose {
        println!(" Layer search events: {}", layer_searches);
        println!(" Open node events: {}", opened_nodes);
        println!(" Evaluate distance events: {}", distance_evals);
        println!(" Compare distance events: {}", distance_comps);
    } else {
        println!(
            "{}, {}, {}",
            opened_nodes, distance_evals, distance_comps
        );
    }
}

/// Install the global tracing subscriber with two custom layers: one that
/// counts HNSW events by type, and one that tallies vertex openings per
/// `(lc, ef)` pair.
///
/// Returns shared handles to both sets of counters so the caller can read
/// them while tracing remains active.
///
/// NOTE(review): `init()` panics if a global subscriber was already set —
/// this must only be called once per process.
fn configure_tracing() -> (
    Arc<EventCounter>,
    Arc<Mutex<HashMap<(usize, usize), usize>>>,
) {
    let counters = Arc::new(EventCounter::default());
    let counting_layer = HnswEventCounterLayer {
        counters: counters.clone(),
    };

    let counter_map: Arc<Mutex<HashMap<(usize, usize), usize>>> =
        Arc::new(Mutex::new(HashMap::new()));
    let vertex_openings_layer = VertexOpeningsLayer {
        counter_map: counter_map.clone(),
    };

    tracing_subscriber::registry()
        .with(counting_layer)
        .with(vertex_openings_layer)
        .init();

    (counters, counter_map)
}
7 changes: 6 additions & 1 deletion iris-mpc-cpu/src/hawkers/plaintext_store.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use crate::hnsw::HnswSearcher;
use crate::hnsw::{
metrics::{COMPARE_DIST_EVENT, EVAL_DIST_EVENT},
HnswSearcher,
};
use aes_prng::AesRng;
use hawk_pack::{graph_store::GraphMem, VectorStore};
use iris_mpc_common::iris_db::{
Expand Down Expand Up @@ -117,6 +120,7 @@ impl VectorStore for PlaintextStore {
query: &Self::QueryRef,
vector: &Self::VectorRef,
) -> Self::DistanceRef {
tracing::info!(event_type = EVAL_DIST_EVENT);
let query_code = &self.points[*query];
let vector_code = &self.points[*vector];
query_code.data.distance_fraction(&vector_code.data)
Expand All @@ -132,6 +136,7 @@ impl VectorStore for PlaintextStore {
distance1: &Self::DistanceRef,
distance2: &Self::DistanceRef,
) -> bool {
tracing::info!(event_type = COMPARE_DIST_EVENT);
let (a, b) = *distance1; // a/b
let (c, d) = *distance2; // c/d
(a as i32) * (d as i32) - (b as i32) * (c as i32) < 0
Expand Down
153 changes: 153 additions & 0 deletions iris-mpc-cpu/src/hnsw/metrics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
use std::{
collections::HashMap,
fmt::Debug,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
use tracing::{
field::{Field, Visit},
span::Attributes,
Event, Id, Subscriber,
};
use tracing_subscriber::{
layer::{Context, Layer},
registry::LookupSpan,
};

/// Event id: one search pass over a single HNSW graph layer.
pub const LAYER_SEARCH_EVENT: u64 = 0;
/// Event id: a graph vertex was opened (its neighbors explored).
pub const OPEN_NODE_EVENT: u64 = 1;
/// Event id: one distance evaluation between a query and a vector.
pub const EVAL_DIST_EVENT: u64 = 2;
/// Event id: one comparison between two previously evaluated distances.
pub const COMPARE_DIST_EVENT: u64 = 3;

// Size of the counter array; must stay one greater than the largest
// event id above.
const NUM_EVENT_TYPES: usize = 4;

/// Thread-safe counters for HNSW algorithm events, one per event type.
#[derive(Default)]
pub struct EventCounter {
    // Indexed by the *_EVENT constants; incremented with relaxed ordering.
    pub counters: [AtomicUsize; NUM_EVENT_TYPES],
}

/// Tracing library Layer that increments a counter for every event
/// carrying an `event_type` field.
pub struct HnswEventCounterLayer {
    // Shared with the caller so counts can be read while tracing is active.
    pub counters: Arc<EventCounter>,
}

impl<S> Layer<S> for HnswEventCounterLayer
where
    S: Subscriber + for<'a> LookupSpan<'a>,
{
    /// Increment the counter matching the event's `event_type` field by its
    /// `increment_amount` (default 1). Events without an `event_type` field
    /// are ignored; an out-of-range event type is a bug and panics.
    fn on_event(&self, event: &Event<'_>, _ctx: Context<'_, S>) {
        let mut fields = EventVisitor::default();
        event.record(&mut fields);

        let Some(event_type) = fields.event else {
            // Not an HNSW metrics event.
            return;
        };
        match self.counters.counters.get(event_type as usize) {
            Some(counter) => {
                let increment = fields.amount.unwrap_or(1) as usize;
                counter.fetch_add(increment, Ordering::Relaxed);
            }
            None => panic!("Invalid event type specified: {:?}", event_type),
        }
    }
}

/// Field visitor that extracts the `event_type` and `increment_amount`
/// u64 fields from a tracing event, when present.
#[derive(Default)]
struct EventVisitor {
    // which event was encountered
    event: Option<u64>,

    // how much to increment the associated counter
    amount: Option<u64>,
}

impl Visit for EventVisitor {
    /// Capture the two u64 fields this layer cares about; all other field
    /// names are ignored.
    fn record_u64(&mut self, field: &Field, value: u64) {
        let name = field.name();
        if name == "event_type" {
            self.event = Some(value);
        } else if name == "increment_amount" {
            self.amount = Some(value);
        }
    }

    /// Non-u64 fields are not used by this visitor.
    fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {}
}

/// Tracing library Layer for counting detailed HNSW layer search operations
pub struct VertexOpeningsLayer {
    // Measure number of vertex openings for different lc and ef values;
    // key is (lc, ef), value is the accumulated opening count.
    pub counter_map: Arc<Mutex<HashMap<(usize, usize), usize>>>,
}

impl<S> Layer<S> for VertexOpeningsLayer
where
S: Subscriber + for<'a> LookupSpan<'a>,
{
// fn register_callsite(&self, _metadata: &'static Metadata<'static>) ->
// Interest { Interest::sometimes()
// }

// fn enabled(&self, metadata: &Metadata<'_>, _ctx: Context<'_, S>) -> bool {
// let is_search_layer_span = metadata.is_span() && metadata.name() ==
// "search_layer"; is_search_layer_span || metadata.is_event()
// }

fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
let span = ctx.span(id).unwrap();
let mut visitor = LayerSearchFields::default();
attrs.record(&mut visitor);
span.extensions_mut().insert(visitor);
}

fn on_event(&self, event: &Event<'_>, ctx: Context<'_, S>) {
let mut visitor = EventVisitor::default();
event.record(&mut visitor);

if let Some(OPEN_NODE_EVENT) = visitor.event {
// open node event must have parent span representing open node function
let current_span = ctx.current_span();
let span_id = current_span.id().unwrap();
if let Some(LayerSearchFields {
lc: Some(lc),
ef: Some(ef),
}) = ctx
.span(span_id)
.unwrap()
.extensions()
.get::<LayerSearchFields>()
{
let mut counter_map = self.counter_map.lock().unwrap();
let increment_amount = visitor.amount.unwrap_or(1);
*counter_map
.entry((*lc as usize, *ef as usize))
.or_insert(0usize) += increment_amount as usize;
} else {
panic!("Open node event is missing associated span fields");
}
}
}
}

/// Span fields recorded by a layer search: the layer index `lc` and the
/// search width `ef`. Stored in span extensions by `VertexOpeningsLayer`.
#[derive(Default)]
pub struct LayerSearchFields {
    // Graph layer being searched.
    lc: Option<u64>,
    // ef (queue size) parameter in effect for this search.
    ef: Option<u64>,
}

impl Visit for LayerSearchFields {
    /// Capture the span's `lc` and `ef` u64 fields; other names are ignored.
    fn record_u64(&mut self, field: &Field, value: u64) {
        match field.name() {
            "lc" => self.lc = Some(value),
            "ef" => self.ef = Some(value),
            _ => {}
        }
    }

    /// Non-u64 fields are not used by this visitor.
    fn record_debug(&mut self, _field: &Field, _value: &dyn Debug) {}
}
1 change: 1 addition & 0 deletions iris-mpc-cpu/src/hnsw/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Metrics collection for HNSW operations via the tracing library.
pub mod metrics;
pub mod searcher;

pub use searcher::HnswSearcher;
6 changes: 6 additions & 0 deletions iris-mpc-cpu/src/hnsw/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//*
//* https://github.com/Inversed-Tech/hawk-pack/

use super::metrics;
pub use hawk_pack::data_structures::queue::{
FurthestQueue, FurthestQueueV, NearestQueue, NearestQueueV,
};
Expand All @@ -12,6 +13,7 @@ use rand::RngCore;
use rand_distr::{Distribution, Geometric};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use tracing::{info, instrument};

// Specify construction and search parameters by layer up to this value minus 1
// any higher layers will use the last set of parameters
Expand Down Expand Up @@ -242,6 +244,7 @@ impl HnswSearcher {
/// given layer using depth-first graph traversal, Terminates when `W`
/// contains vectors which are the nearest to `q` among all traversed
/// vertices and their neighbors.
#[instrument(skip(self, vector_store, graph_store, W))]
#[allow(non_snake_case)]
async fn search_layer<V: VectorStore, G: GraphStore<V>>(
&self,
Expand All @@ -252,6 +255,8 @@ impl HnswSearcher {
ef: usize,
lc: usize,
) {
info!(event_type = metrics::LAYER_SEARCH_EVENT);

// v: The set of already visited vectors.
let mut v = HashSet::<V::VectorRef>::from_iter(W.iter().map(|(e, _eq)| e.clone()));

Expand All @@ -271,6 +276,7 @@ impl HnswSearcher {
}

// Open the node c and explore its neighbors.
info!(event_type = metrics::OPEN_NODE_EVENT);

// Visit all neighbors of c.
let c_links = graph_store.get_links(&c, lc).await;
Expand Down
Loading