Skip to content

Commit

Permalink
Basic cpu iris search using HNSW (#327)
Browse files Browse the repository at this point in the history
* added plaintext store; basic cpu iris search using hnsw

* updated git reference to hawkpack

* added prototype for hnsw using rep3 only

* added common dependencies into main workspace

* some renaming

* fix tiny bug

* update dependencies

* removed un-necessary from

* using IrisDB methods to generate database for testing

* replace non-blocking names with default names

* added expect instead of eyre error handling
  • Loading branch information
rdragos authored Sep 11, 2024
1 parent 1d3571b commit 6ec26a3
Show file tree
Hide file tree
Showing 14 changed files with 1,294 additions and 29 deletions.
46 changes: 44 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ eyre = "0.6"
futures = "0.3.30"
hex = "0.4.3"
itertools = "0.13"
num-traits = "0.2"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
sqlx = { version = "0.7", features = ["runtime-tokio-native-tls", "postgres"] }
Expand All @@ -35,6 +36,7 @@ rayon = "1.5.1"
reqwest = { version = "0.12", features = ["blocking", "json"] }
static_assertions = "1.1"
telemetry-batteries = { git = "https://github.com/worldcoin/telemetry-batteries.git", rev = "802a4f39f358e077b11c8429b4c65f3e45b85959" }
thiserror = "1"
tokio = { version = "1.40", features = ["full", "rt-multi-thread"] }
uuid = { version = "1", features = ["v4"] }

Expand Down
2 changes: 1 addition & 1 deletion iris-mpc-common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ clap.workspace = true
rand.workspace = true
bytemuck.workspace = true
eyre.workspace = true
thiserror = "1"
thiserror.workspace = true
rayon.workspace = true
itertools.workspace = true
base64.workspace = true
Expand Down
14 changes: 11 additions & 3 deletions iris-mpc-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,20 @@ bytes = "1.7"
bytemuck.workspace = true
eyre.workspace = true
futures.workspace = true
hawk-pack = { git = "https://github.com/Inversed-Tech/hawk-pack.git" }
iris-mpc-common = { path = "../iris-mpc-common" }
num-traits = "0.2"
itertools.workspace = true
num-traits.workspace = true
rand.workspace = true
rayon.workspace = true
serde.workspace = true
static_assertions.workspace = true
thiserror = "1.0"
tokio.workspace = true
tracing.workspace = true
tracing-test = "0.2.5"

[dev-dependencies]
criterion = { version = "0.5.1", features = ["async_tokio"] }

[[bench]]
name = "hnsw"
harness = false
102 changes: 102 additions & 0 deletions iris-mpc-cpu/benches/hnsw.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
use aes_prng::AesRng;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, SamplingMode};
use hawk_pack::{graph_store::GraphMem, hnsw_db::HawkSearcher, VectorStore};
use iris_mpc_common::iris_db::{db::IrisDB, iris::IrisCode};
use iris_mpc_cpu::{
database_generators::generate_iris_shares,
hawkers::{aby3_store::create_ready_made_hawk_searcher, plaintext_store::PlaintextStore},
};
use rand::SeedableRng;

fn bench_plaintext_hnsw(c: &mut Criterion) {
let mut group = c.benchmark_group("plaintext_hnsw");
group.sample_size(10);
group.sampling_mode(SamplingMode::Flat);

for database_size in [100_usize, 1000, 10000] {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();

let plain_searcher = rt.block_on(async move {
let mut rng = AesRng::seed_from_u64(0_u64);
let vector_store = PlaintextStore::default();
let graph_store = GraphMem::new();
let mut plain_searcher = HawkSearcher::new(vector_store, graph_store, &mut rng);

for _ in 0..database_size {
let raw_query = IrisCode::random_rng(&mut rng);
let query = plain_searcher.vector_store.prepare_query(raw_query.clone());
let neighbors = plain_searcher.search_to_insert(&query).await;
let inserted = plain_searcher.vector_store.insert(&query).await;
plain_searcher
.insert_from_search_results(inserted, neighbors)
.await;
}
plain_searcher
});

group.bench_function(BenchmarkId::new("insert", database_size), |b| {
b.to_async(&rt).iter_batched(
|| plain_searcher.clone(),
|mut my_db| async move {
let mut rng = AesRng::seed_from_u64(0_u64);
let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone();
let query = my_db.vector_store.prepare_query(on_the_fly_query);
let neighbors = my_db.search_to_insert(&query).await;
my_db.insert_from_search_results(query, neighbors).await;
},
criterion::BatchSize::SmallInput,
)
});
}
group.finish();
}

fn bench_ready_made_hnsw(c: &mut Criterion) {
let mut group = c.benchmark_group("ready_made_hnsw");
group.sample_size(10);

for database_size in [1, 10, 100, 1000] {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();

let (_, secret_searcher) = rt.block_on(async move {
let mut rng = AesRng::seed_from_u64(0_u64);
create_ready_made_hawk_searcher(&mut rng, database_size)
.await
.unwrap()
});

group.bench_function(
BenchmarkId::new("big-hnsw-insertions", database_size),
|b| {
b.to_async(&rt).iter_batched(
|| secret_searcher.clone(),
|mut my_db| async move {
let mut rng = AesRng::seed_from_u64(0_u64);
let on_the_fly_query = IrisDB::new_random_rng(1, &mut rng).db[0].clone();
let raw_query = generate_iris_shares(&mut rng, on_the_fly_query);

let query = my_db.vector_store.prepare_query(raw_query);
let neighbors = my_db.search_to_insert(&query).await;
my_db.insert_from_search_results(query, neighbors).await;
},
criterion::BatchSize::SmallInput,
)
},
);
}
group.finish();
}

criterion_group! {
hnsw,
bench_plaintext_hnsw,
bench_ready_made_hnsw
}

criterion_main!(hnsw);
110 changes: 110 additions & 0 deletions iris-mpc-cpu/src/database_generators.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
use crate::shares::{ring_impl::RingElement, share::Share, vecshare::VecShare};
use iris_mpc_common::iris_db::iris::{IrisCode, IrisCodeArray};
use rand::{Rng, RngCore};
use std::sync::Arc;

type ShareRing = u16;
type ShareType = Share<ShareRing>;
type VecShareType = VecShare<u16>;
type ShareRingPlain = RingElement<ShareRing>;
// type ShareType = Share<u16>;

#[derive(PartialEq, Eq, Debug, Default, Clone)]
pub struct SharedIris {
pub shares: VecShareType,
pub mask: IrisCodeArray,
}

#[derive(Clone)]
pub struct SharedDB {
pub shares: Arc<Vec<VecShareType>>,
pub masks: Arc<Vec<IrisCodeArray>>,
}

pub struct RawSharedDatabase {
pub player0_shares: Vec<SharedIris>,
pub player1_shares: Vec<SharedIris>,
pub player2_shares: Vec<SharedIris>,
}

/// This one is taken from iris-mpc-semi/iris.rs
pub struct IrisShare {}
impl IrisShare {
pub fn get_shares<R: RngCore>(input: bool, mask: bool, rng: &mut R) -> Vec<ShareType> {
let val = RingElement((input & mask) as ShareRing);
let to_share = RingElement(mask as ShareRing) - val - val;

let a = rng.gen::<ShareRingPlain>();
let b = rng.gen::<ShareRingPlain>();
let c = to_share - a - b;

let share1 = Share::new(a, c);
let share2 = Share::new(b, a);
let share3 = Share::new(c, b);

vec![share1, share2, share3]
}
}

pub(crate) fn create_shared_database_raw<R: RngCore>(
rng: &mut R,
in_mem: &[IrisCode],
) -> eyre::Result<RawSharedDatabase> {
let mut shared_irises = (0..3)
.map(|_| Vec::with_capacity(in_mem.len()))
.collect::<Vec<_>>();
for code in in_mem.iter() {
let shared_code: Vec<_> = (0..IrisCode::IRIS_CODE_SIZE)
.map(|i| IrisShare::get_shares(code.code.get_bit(i), code.mask.get_bit(i), rng))
.collect();

let shared_3_n: Vec<_> = (0..3)
.map(|p_id| {
let shared_n: Vec<Share<u16>> = (0..IrisCode::IRIS_CODE_SIZE)
.map(|iris_index| shared_code[iris_index][p_id].clone())
.collect();
shared_n
})
.collect();
// We simulate the parties already knowing the shares of the code.
for party_id in 0..3 {
shared_irises[party_id].push(SharedIris {
shares: VecShareType::new_vec(shared_3_n[party_id].clone()),
mask: code.mask,
});
}
}
let player2_shares = shared_irises
.pop()
.expect("error popping shared iris for player 2");
let player1_shares = shared_irises
.pop()
.expect("error popping shared iris for player 1");
let player0_shares = shared_irises
.pop()
.expect("error popping shared iris for player 0");
Ok(RawSharedDatabase {
player0_shares,
player1_shares,
player2_shares,
})
}

pub fn generate_iris_shares<R: Rng>(rng: &mut R, iris: IrisCode) -> Vec<SharedIris> {
let mut res = vec![SharedIris::default(); 3];
for res_i in res.iter_mut() {
res_i
.mask
.as_raw_mut_slice()
.copy_from_slice(iris.mask.as_raw_slice());
}

for i in 0..IrisCode::IRIS_CODE_SIZE {
// We simulate the parties already knowing the shares of the code.
let shares = IrisShare::get_shares(iris.code.get_bit(i), iris.mask.get_bit(i), rng);
for party_id in 0..3 {
res[party_id].shares.push(shares[party_id].to_owned());
}
}
res
}
Loading

0 comments on commit 6ec26a3

Please sign in to comment.