Skip to content

Commit

Permalink
manually manage serial id (#702)
Browse files Browse the repository at this point in the history
* manually manage serial id

* fix

* .

* .

* .

* .

* .

* make i64
  • Loading branch information
philsippl authored Nov 21, 2024
1 parent 00b4e61 commit e1ff515
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 115 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ALTER TABLE irises ALTER COLUMN id DROP IDENTITY IF EXISTS;
115 changes: 50 additions & 65 deletions iris-mpc-store/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ impl StoredIris {

#[derive(Clone)]
pub struct StoredIrisRef<'a> {
pub id: i64,
pub left_code: &'a [u16],
pub left_mask: &'a [u16],
pub right_code: &'a [u16],
Expand Down Expand Up @@ -190,9 +191,10 @@ impl Store {
return Ok(vec![]);
}
let mut query = sqlx::QueryBuilder::new(
"INSERT INTO irises (left_code, left_mask, right_code, right_mask)",
"INSERT INTO irises (id, left_code, left_mask, right_code, right_mask)",
);
query.push_values(codes_and_masks, |mut query, iris| {
query.push_bind(iris.id);
query.push_bind(cast_slice::<u16, u8>(iris.left_code));
query.push_bind(cast_slice::<u16, u8>(iris.left_mask));
query.push_bind(cast_slice::<u16, u8>(iris.right_code));
Expand Down Expand Up @@ -290,27 +292,6 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask
Ok(())
}

async fn set_sequence_id(
&self,
id: usize,
executor: impl sqlx::Executor<'_, Database = Postgres>,
) -> Result<()> {
if id == 0 {
// If requested id is 0 (only used in tests), reset the sequence to 1 with
// advance_nextval set to false. This is because serial id starts from 1.
sqlx::query("SELECT setval(pg_get_serial_sequence('irises', 'id'), 1, false)")
.execute(executor)
.await?;
} else {
sqlx::query("SELECT setval(pg_get_serial_sequence('irises', 'id'), $1, true)")
.bind(id as i64)
.execute(executor)
.await?;
}

Ok(())
}

pub async fn rollback(&self, db_len: usize) -> Result<()> {
let mut tx = self.pool.begin().await?;

Expand All @@ -319,18 +300,12 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask
.execute(&mut *tx)
.await?;

self.set_sequence_id(db_len, &mut *tx).await?;

tx.commit().await?;
Ok(())
}

pub async fn set_irises_sequence_id(&self, id: usize) -> Result<()> {
self.set_sequence_id(id, &self.pool).await
}

pub async fn get_irises_sequence_id(&self) -> Result<usize> {
let id: (i64,) = sqlx::query_as("SELECT last_value FROM irises_id_seq")
pub async fn get_max_serial_id(&self) -> Result<usize> {
let id: (i64,) = sqlx::query_as("SELECT MAX(id) FROM irises")
.fetch_one(&self.pool)
.await?;
Ok(id.0 as usize)
Expand All @@ -353,16 +328,6 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask
Ok(())
}

pub async fn update_iris_id_sequence(&self) -> Result<()> {
sqlx::query(
"SELECT setval(pg_get_serial_sequence('irises', 'id'), COALESCE(MAX(id), 0), true) \
FROM irises",
)
.execute(&self.pool)
.await?;
Ok(())
}

pub async fn last_results(&self, count: usize) -> Result<Vec<String>> {
let mut result_events: Vec<String> =
sqlx::query_scalar("SELECT result_event FROM results ORDER BY id DESC LIMIT $1")
Expand Down Expand Up @@ -442,6 +407,7 @@ DO UPDATE SET right_code = EXCLUDED.right_code, right_mask = EXCLUDED.right_mask
// inserting shares and masks in the db. Reusing the same share and mask for
// left and right
self.insert_irises(&mut tx, &[StoredIrisRef {
id: (i + 1) as i64,
left_code: &share.coefs,
left_mask: &mask.coefs,
right_code: &share.coefs,
Expand Down Expand Up @@ -505,18 +471,21 @@ mod tests {

let codes_and_masks = &[
StoredIrisRef {
id: 1,
left_code: &[1, 2, 3, 4],
left_mask: &[5, 6, 7, 8],
right_code: &[9, 10, 11, 12],
right_mask: &[13, 14, 15, 16],
},
StoredIrisRef {
id: 2,
left_code: &[1117, 18, 19, 20],
left_mask: &[21, 1122, 23, 24],
right_code: &[25, 26, 1127, 28],
right_mask: &[29, 30, 31, 1132],
},
StoredIrisRef {
id: 3,
left_code: &[17, 18, 19, 20],
left_mask: &[21, 22, 23, 24],
// Empty is allowed until stereo is implemented.
Expand Down Expand Up @@ -568,18 +537,23 @@ mod tests {

#[tokio::test]
async fn test_insert_many() -> Result<()> {
let count = 1 << 3;
let count: usize = 1 << 3;

let schema_name = temporary_name();
let store = Store::new(&test_db_url()?, &schema_name).await?;

let iris = StoredIrisRef {
left_code: &[123_u16; 12800],
left_mask: &[456_u16; 12800],
right_code: &[789_u16; 12800],
right_mask: &[101_u16; 12800],
};
let codes_and_masks = vec![iris; count];
let mut codes_and_masks = vec![];

for i in 0..count {
let iris = StoredIrisRef {
id: (i + 1) as i64,
left_code: &[123_u16; 12800],
left_mask: &[456_u16; 12800],
right_code: &[789_u16; 12800],
right_mask: &[101_u16; 12800],
};
codes_and_masks.push(iris);
}

let result_event = serde_json::to_string(&UniquenessResult::new(
0,
Expand Down Expand Up @@ -641,15 +615,20 @@ mod tests {
let schema_name = temporary_name();
let store = Store::new(&test_db_url()?, &schema_name).await?;

let iris = StoredIrisRef {
left_code: &[123_u16; 12800],
left_mask: &[456_u16; 12800],
right_code: &[789_u16; 12800],
right_mask: &[101_u16; 12800],
};
let mut irises = vec![];
for i in 0..10 {
let iris = StoredIrisRef {
id: (i + 1) as i64,
left_code: &[123_u16; 12800],
left_mask: &[456_u16; 12800],
right_code: &[789_u16; 12800],
right_mask: &[101_u16; 12800],
};
irises.push(iris);
}

let mut tx = store.tx().await?;
store.insert_irises(&mut tx, &vec![iris; 10]).await?;
store.insert_irises(&mut tx, &irises).await?;
tx.commit().await?;
store.rollback(5).await?;

Expand Down Expand Up @@ -779,31 +758,37 @@ mod tests {
let store = Store::new(&test_db_url()?, &schema_name).await?;

// insert two irises into db
let iris = StoredIrisRef {
let iris1 = StoredIrisRef {
id: 1,
left_code: &[123_u16; 12800],
left_mask: &[456_u16; 6400],
right_code: &[789_u16; 12800],
right_mask: &[101_u16; 6400],
};
let mut iris2 = iris1.clone();
iris2.id = 2;

let mut tx = store.tx().await?;
store.insert_irises(&mut tx, &vec![iris.clone(); 2]).await?;
store
.insert_irises(&mut tx, &[iris1, iris2.clone()])
.await?;
tx.commit().await?;

// update iris with id 1 in db
let updated_left_code = GaloisRingIrisCodeShare {
id: 0,
id: 1,
coefs: [666_u16; 12800],
};
let updated_left_mask = GaloisRingTrimmedMaskCodeShare {
id: 0,
id: 1,
coefs: [777_u16; 6400],
};
let updated_right_code = GaloisRingIrisCodeShare {
id: 0,
id: 1,
coefs: [888_u16; 12800],
};
let updated_right_mask = GaloisRingTrimmedMaskCodeShare {
id: 0,
id: 1,
coefs: [999_u16; 6400],
};
store
Expand All @@ -825,10 +810,10 @@ mod tests {
assert_eq!(cast_u8_to_u16(&got[0].right_mask), updated_right_mask.coefs);

// assert the other iris in db is not updated
assert_eq!(cast_u8_to_u16(&got[1].left_code), iris.left_code);
assert_eq!(cast_u8_to_u16(&got[1].left_mask), iris.left_mask);
assert_eq!(cast_u8_to_u16(&got[1].right_code), iris.right_code);
assert_eq!(cast_u8_to_u16(&got[1].right_mask), iris.right_mask);
assert_eq!(cast_u8_to_u16(&got[1].left_code), iris2.left_code);
assert_eq!(cast_u8_to_u16(&got[1].left_mask), iris2.left_mask);
assert_eq!(cast_u8_to_u16(&got[1].right_code), iris2.right_code);
assert_eq!(cast_u8_to_u16(&got[1].right_mask), iris2.right_mask);

cleanup(&store, &schema_name).await?;
Ok(())
Expand Down
8 changes: 0 additions & 8 deletions iris-mpc-upgrade/src/bin/tcp_upgrade_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,6 @@ async fn main() -> eyre::Result<()> {
client_stream1.write_u8(FINAL_BATCH_SUCCESSFUL_ACK).await?;
tracing::info!("Sent final ACK to client1");

tracing::info!("Updating iris id sequence");
sink.update_iris_id_sequence().await?;
tracing::info!("Iris id sequence updated");

Ok(())
}

Expand Down Expand Up @@ -252,8 +248,4 @@ impl NewIrisShareSink for IrisShareDbSink {
}
}
}

async fn update_iris_id_sequence(&self) -> eyre::Result<()> {
self.store.update_iris_id_sequence().await
}
}
6 changes: 0 additions & 6 deletions iris-mpc-upgrade/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ pub trait NewIrisShareSink {
code_share: &[u16; IRIS_CODE_LENGTH],
mask_share: &[u16; MASK_CODE_LENGTH],
) -> Result<()>;

async fn update_iris_id_sequence(&self) -> Result<()>;
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -83,10 +81,6 @@ impl NewIrisShareSink for IrisShareTestFileSink {
file.flush()?;
Ok(())
}

async fn update_iris_id_sequence(&self) -> Result<()> {
Ok(())
}
}

#[derive(Clone)]
Expand Down
51 changes: 15 additions & 36 deletions iris-mpc/src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,36 +660,19 @@ async fn server_main(config: Config) -> eyre::Result<()> {
tracing::info!("Size of the database after init: {}", store_len);

// Check if the sequence id is consistent with the number of irises
let iris_sequence_id = store.get_irises_sequence_id().await?;
if iris_sequence_id != store_len {
tracing::warn!(
"Detected inconsistent iris sequence id {} != {}, resetting...",
iris_sequence_id,
let max_serial_id = store.get_max_serial_id().await?;
if max_serial_id != store_len {
tracing::error!(
"Detected inconsistency between max serial id {} and db size {}.",
max_serial_id,
store_len
);

// Reset the sequence id
store.set_irises_sequence_id(store_len).await?;

// Fetch again and check that the sequence id is consistent now
let store_len = store.count_irises().await?;
let iris_sequence_id = store.get_irises_sequence_id().await?;

// If db is empty, we set the sequence id to 1 with advance_nextval false
let empty_db_sequence_ok = store_len == 0 && iris_sequence_id == 1;

if iris_sequence_id != store_len && !empty_db_sequence_ok {
tracing::error!(
"Iris sequence id is still inconsistent: {} != {}",
iris_sequence_id,
store_len
);
eyre::bail!(
"Iris sequence id is still inconsistent: {} != {}",
iris_sequence_id,
store_len
);
}
eyre::bail!(
"Detected inconsistency between max serial id {} and db size {}.",
max_serial_id,
store_len
);
}

if store_len > config.max_db_size {
Expand Down Expand Up @@ -912,16 +895,18 @@ async fn server_main(config: Config) -> eyre::Result<()> {
.collect::<eyre::Result<Vec<_>>>()?;

// Insert non-matching queries into the persistent store.
let (memory_serial_ids, codes_and_masks): (Vec<u32>, Vec<StoredIrisRef>) = matches
let (memory_serial_ids, codes_and_masks): (Vec<i64>, Vec<StoredIrisRef>) = matches
.iter()
.enumerate()
.filter_map(
// Find the indices of non-matching queries in the batch.
|(query_idx, is_match)| if !is_match { Some(query_idx) } else { None },
)
.map(|query_idx| {
let serial_id = (merged_results[query_idx] + 1) as i64;
// Get the original vectors from `receive_batch`.
(merged_results[query_idx] + 1, StoredIrisRef {
(serial_id, StoredIrisRef {
id: serial_id,
left_code: &store_left.code[query_idx].coefs[..],
left_mask: &store_left.mask[query_idx].coefs[..],
right_code: &store_right.code[query_idx].coefs[..],
Expand All @@ -937,13 +922,7 @@ async fn server_main(config: Config) -> eyre::Result<()> {
.await?;

if !codes_and_masks.is_empty() && !config_bg.disable_persistence {
let db_serial_ids = store_bg
.insert_irises(&mut tx, &codes_and_masks)
.await
.wrap_err("failed to persist queries")?
.iter()
.map(|&x| x as u32)
.collect::<Vec<_>>();
let db_serial_ids = store_bg.insert_irises(&mut tx, &codes_and_masks).await?;

// Check if the serial_ids match between memory and db.
if memory_serial_ids != db_serial_ids {
Expand Down

0 comments on commit e1ff515

Please sign in to comment.