Skip to content

Commit

Permalink
parallel-load (#213)
Browse files Browse the repository at this point in the history
Co-authored-by: Aurélien Nicolas <[email protected]>
  • Loading branch information
naure and Aurélien Nicolas authored Aug 7, 2024
1 parent 4402c57 commit 8cb8d93
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 15 deletions.
1 change: 1 addition & 0 deletions .env.mpc1.dist
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ SMPC__SERVICE__SERVICE_NAME=smpcv2-server
SMPC__DATABASE__URL=postgres://postgres:postgres@localhost/postgres
SMPC__DATABASE__MIGRATE=true
SMPC__DATABASE__CREATE=false
SMPC__DATABASE__LOAD_PARALLELISM=8

# AWS Configuration
SMPC__AWS__REGION=eu-north-1
Expand Down
1 change: 1 addition & 0 deletions .env.mpc2.dist
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ SMPC__SERVICE__SERVICE_NAME=smpcv2-server
SMPC__DATABASE__URL=postgres://postgres:postgres@localhost/postgres
SMPC__DATABASE__MIGRATE=true
SMPC__DATABASE__CREATE=false
SMPC__DATABASE__LOAD_PARALLELISM=8

# AWS Configuration
SMPC__AWS__REGION=eu-north-1
Expand Down
1 change: 1 addition & 0 deletions .env.mpc3.dist
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ SMPC__SERVICE__SERVICE_NAME=smpcv2-server
SMPC__DATABASE__URL=postgres://postgres:postgres@localhost/postgres
SMPC__DATABASE__MIGRATE=true
SMPC__DATABASE__CREATE=false
SMPC__DATABASE__LOAD_PARALLELISM=8

# AWS Configuration
SMPC__AWS__REGION=eu-north-1
Expand Down
3 changes: 3 additions & 0 deletions deploy/stage/mpc1-stage/values-gpu-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ env:
- name: SMPC__DATABASE__CREATE
value: "true"

- name: SMPC__DATABASE__LOAD_PARALLELISM
value: "8"

- name: SMPC__AWS__REGION
value: "eu-north-1"

Expand Down
3 changes: 3 additions & 0 deletions deploy/stage/mpc2-stage/values-gpu-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ env:
- name: SMPC__DATABASE__CREATE
value: "true"

- name: SMPC__DATABASE__LOAD_PARALLELISM
value: "8"

- name: SMPC__AWS__REGION
value: "eu-north-1"

Expand Down
3 changes: 3 additions & 0 deletions deploy/stage/mpc3-stage/values-gpu-iris-mpc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ env:
- name: SMPC__DATABASE__CREATE
value: "true"

- name: SMPC__DATABASE__LOAD_PARALLELISM
value: "8"

- name: SMPC__AWS__REGION
value: "eu-north-1"

Expand Down
34 changes: 25 additions & 9 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ async fn initialize_chacha_seeds(
async fn initialize_iris_dbs(
party_id: usize,
store: &Store,
config: &Config,
) -> eyre::Result<(Vec<u16>, Vec<u16>, usize)> {
// Generate or load DB
let (mut codes_db, mut masks_db) = {
Expand Down Expand Up @@ -262,22 +263,36 @@ async fn initialize_iris_dbs(

(codes_db, masks_db)
};
let fake_len = codes_db.len();

let count_irises = store.count_irises().await?;
codes_db.reserve(count_irises * IRIS_CODE_LENGTH);
masks_db.reserve(count_irises * IRIS_CODE_LENGTH);

tracing::info!("Initialize iris db: Loading from DB");
codes_db.resize(fake_len + count_irises * IRIS_CODE_LENGTH, 0);
masks_db.resize(fake_len + count_irises * IRIS_CODE_LENGTH, 0);

let parallelism = config
.database
.as_ref()
.ok_or(eyre!("Missing database config"))?
.load_parallelism;

tracing::info!(
"Initialize iris db: Loading from DB (parallelism: {})",
parallelism
);
// Load DB from persistent storage.
let mut store_len = 0;
while let Some(iris) = store.stream_irises().await.next().await {
while let Some(iris) = store.stream_irises_par(parallelism).await.next().await {
let iris = iris?;
if iris.index() >= count_irises {
tracing::warn!("Inconsistent iris index {}", iris.index());
continue;
}

codes_db.extend(iris.left_code());
masks_db.extend(iris.left_mask());
let start = fake_len + iris.index() * IRIS_CODE_LENGTH;
codes_db[start..start + IRIS_CODE_LENGTH].copy_from_slice(iris.left_code());
masks_db[start..start + IRIS_CODE_LENGTH].copy_from_slice(iris.left_mask());

store_len += 1;

if (store_len % 10000) == 0 {
tracing::info!("Initialize iris db: Loaded {} entries from DB", store_len);
}
Expand Down Expand Up @@ -337,7 +352,8 @@ async fn main() -> eyre::Result<()> {
.await?;

tracing::info!("Initialize iris db");
let (mut codes_db, mut masks_db, store_len) = initialize_iris_dbs(party_id, &store).await?;
let (mut codes_db, mut masks_db, store_len) =
initialize_iris_dbs(party_id, &store, &config).await?;

let my_state = SyncState {
db_len: store_len as u64,
Expand Down
13 changes: 12 additions & 1 deletion src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,14 @@ pub struct Config {
#[serde(default)]
pub aws: Option<AwsConfig>,

#[serde(default)]
#[serde(default = "default_processing_timeout_secs")]
pub processing_timeout_secs: u64,
}

/// Fallback used by serde for `Config::processing_timeout_secs` when the
/// setting is not provided.
fn default_processing_timeout_secs() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 60;
    DEFAULT_TIMEOUT_SECS
}

impl Config {
pub fn load_config(prefix: &str) -> eyre::Result<Config> {
let settings = config::Config::builder();
Expand Down Expand Up @@ -87,6 +91,13 @@ pub struct DbConfig {

#[serde(default)]
pub create: bool,

#[serde(default = "default_load_parallelism")]
pub load_parallelism: usize,
}

/// Fallback used by serde for `DbConfig::load_parallelism` when the setting
/// is not provided.
fn default_load_parallelism() -> usize {
    const DEFAULT_PARALLELISM: usize = 8;
    DEFAULT_PARALLELISM
}

impl fmt::Debug for DbConfig {
Expand Down
49 changes: 44 additions & 5 deletions src/store/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@ pub mod sync;
use crate::config::Config;
use bytemuck::cast_slice;
use eyre::{eyre, Result};
use futures::Stream;
use futures::{
stream::{self},
Stream,
};
use sqlx::{migrate::Migrator, postgres::PgPoolOptions, Executor, PgPool, Postgres, Transaction};
use std::ops::DerefMut;
use std::{ops::DerefMut, pin::Pin};

const APP_NAME: &str = "SMPC";
const POOL_SIZE: u32 = 5;
const MAX_CONNECTIONS: u32 = 100;

static MIGRATOR: Migrator = sqlx::migrate!("./migrations");

Expand All @@ -23,7 +26,7 @@ fn sql_switch_schema(schema_name: &str) -> Result<String> {
))
}

#[derive(sqlx::FromRow, Debug, Default)]
#[derive(sqlx::FromRow, Debug, Default, PartialEq, Eq)]
pub struct StoredIris {
#[allow(dead_code)]
id: i64, // BIGSERIAL
Expand All @@ -34,6 +37,11 @@ pub struct StoredIris {
}

impl StoredIris {
/// Position of this iris, taken directly from the row `id`.
///
/// The index is expected to be contiguous and to start from 0.
/// NOTE(review): the `as` cast wraps for a negative `id` — confirm ids are
/// never negative.
pub fn index(&self) -> usize {
    let row_id = self.id;
    row_id as usize
}

/// Borrow the stored left code, converted to `u16` values by
/// `cast_u8_to_u16`.
pub fn left_code(&self) -> &[u16] {
    let raw_bytes = &self.left_code;
    cast_u8_to_u16(raw_bytes)
}
Expand Down Expand Up @@ -87,7 +95,7 @@ impl Store {
let connect_sql = sql_switch_schema(schema_name)?;

let pool = PgPoolOptions::new()
.max_connections(POOL_SIZE)
.max_connections(MAX_CONNECTIONS)
.after_connect(move |conn, _meta| {
// Switch to the given schema in every connection.
let connect_sql = connect_sql.clone();
Expand Down Expand Up @@ -118,10 +126,37 @@ impl Store {
Ok(count.0 as usize)
}

/// Stream every row of the `irises` table, sorted by `id`.
///
/// Rows are fetched through the store's connection pool; each item is either
/// a decoded `StoredIris` or the `sqlx::Error` produced while fetching it.
pub async fn stream_irises(&self) -> impl Stream<Item = Result<StoredIris, sqlx::Error>> + '_ {
    let ordered_query = sqlx::query_as::<_, StoredIris>("SELECT * FROM irises ORDER BY id");
    ordered_query.fetch(&self.pool)
}

/// Stream irises in parallel, without a particular order.
///
/// The id space is split into `partitions` contiguous ranges, one
/// `BETWEEN` query is issued per range against the connection pool, and the
/// resulting streams are merged with `select_all`. Items therefore arrive
/// interleaved; callers that need ordering must sort by id themselves.
///
/// A `partitions` value of 0 is treated as 1. An empty table yields an
/// empty stream.
///
/// # Panics
///
/// Panics if counting the rows fails, because the returned stream type has
/// no way to carry that error.
pub async fn stream_irises_par(
    &self,
    partitions: usize,
) -> impl Stream<Item = Result<StoredIris, sqlx::Error>> + '_ {
    let count = self.count_irises().await.expect("Failed count_irises");

    // Zero partitions would divide by zero below; fall back to one stream.
    let partitions = partitions.max(1);

    // Cover ids 0..=count. This reaches every row whether the contiguous ids
    // start at 0 or at 1, and keeps `partition_size` >= 1 so the range
    // arithmetic below cannot underflow, even for an empty table. Ids in a
    // range that do not exist simply match no rows.
    let partition_size = (count + 1).div_ceil(partitions);

    let mut partition_streams = Vec::with_capacity(partitions);
    for i in 0..partitions {
        // Inclusive id range handled by this partition.
        let start_id = partition_size * i;
        let end_id = start_id + partition_size - 1;

        let partition_stream =
            sqlx::query_as::<_, StoredIris>("SELECT * FROM irises WHERE id BETWEEN $1 AND $2")
                .bind(start_id as i64)
                .bind(end_id as i64)
                .fetch(&self.pool);

        // Box each stream to a common type so they can be merged.
        partition_streams.push(Box::pin(partition_stream)
            as Pin<Box<dyn Stream<Item = Result<StoredIris, sqlx::Error>> + Send>>);
    }

    stream::select_all(partition_streams)
}

pub async fn insert_irises(
&self,
tx: &mut Transaction<'_, Postgres>,
Expand Down Expand Up @@ -348,6 +383,10 @@ mod tests {
assert_eq!(got.len(), count);
assert_contiguous_id(&got);

let mut got_par: Vec<StoredIris> = store.stream_irises_par(5).await.try_collect().await?;
got_par.sort_by_key(|iris| iris.id);
assert_eq!(got, got_par);

let got = store.last_results(count).await?;
assert_eq!(got, result_events);

Expand Down

0 comments on commit 8cb8d93

Please sign in to comment.