Skip to content

Commit

Permalink
convert to binary
Browse files Browse the repository at this point in the history
  • Loading branch information
eaypek-tfh committed Dec 9, 2024
1 parent 29bf6d5 commit 2a5c7a2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 74 deletions.
39 changes: 39 additions & 0 deletions iris-mpc-store/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![feature(int_roundings)]

mod s3_importer;

use bytemuck::cast_slice;
Expand Down Expand Up @@ -74,6 +76,43 @@ impl StoredIris {
pub fn id(&self) -> i64 {
self.id
}

pub fn from_bytes(bytes: &[u8]) -> Result<Self, eyre::Error> {
let mut cursor = 0;

// Helper closure to extract a slice of a given size
let extract_slice =
|bytes: &[u8], cursor: &mut usize, size: usize| -> Result<Vec<u8>, eyre::Error> {
if *cursor + size > bytes.len() {
return Err(eyre!("Exceeded total bytes while extracting slice",));
}
let slice = &bytes[*cursor..*cursor + size];
*cursor += size;
Ok(slice.to_vec())
};

// Parse `id` (i64)
let id_bytes = extract_slice(bytes, &mut cursor, 4)?;
let id = i64::from_be_bytes(
id_bytes
.try_into()
.map_err(|_| eyre!("Failed to convert id bytes to i64"))?,
);

// parse codes and masks
let left_code = extract_slice(bytes, &mut cursor, 25_600)?;
let left_mask = extract_slice(bytes, &mut cursor, 12_800)?;
let right_code = extract_slice(bytes, &mut cursor, 25_600)?;
let right_mask = extract_slice(bytes, &mut cursor, 12_800)?;

Ok(StoredIris {
id,
left_code,
left_mask,
right_code,
right_mask,
})
}
}

#[derive(Clone)]
Expand Down
93 changes: 19 additions & 74 deletions iris-mpc-store/src/s3_importer.rs
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
use crate::StoredIris;
use async_trait::async_trait;
use aws_sdk_s3::Client;
use bytes::Bytes;
use futures::{stream, Stream, StreamExt};
use iris_mpc_common::{IRIS_CODE_LENGTH, MASK_CODE_LENGTH};
use rayon::{iter::ParallelIterator, prelude::ParallelBridge};
use serde::Deserialize;
use std::{io::Cursor, mem, pin::Pin, sync::Arc, time, time::Instant};
use std::{mem, pin::Pin, sync::Arc, time::Instant};
use tokio::task;

const SINGLE_ELEMENT_SIZE: usize = IRIS_CODE_LENGTH * mem::size_of::<u16>() * 2
+ MASK_CODE_LENGTH * mem::size_of::<u16>() * 2
+ mem::size_of::<u32>(); // 75 KB
const CSV_BUFFER_CAPACITY: usize = SINGLE_ELEMENT_SIZE * 10;

#[async_trait]
pub trait ObjectStore: Send + Sync + 'static {
async fn get_object(&self, key: &str) -> eyre::Result<Bytes>;
async fn get_object(&self, key: &str) -> eyre::Result<Vec<u8>>;
async fn list_objects(&self) -> eyre::Result<Vec<String>>;
}

Expand All @@ -33,7 +29,7 @@ impl S3Store {

#[async_trait]
impl ObjectStore for S3Store {
async fn get_object(&self, key: &str) -> eyre::Result<Bytes> {
async fn get_object(&self, key: &str) -> eyre::Result<Vec<u8>> {
let result = self
.client
.get_object()
Expand All @@ -43,7 +39,7 @@ impl ObjectStore for S3Store {
.await?;

let data = result.body.collect().await?;
Ok(data.into_bytes())
Ok(data.to_vec())
}

async fn list_objects(&self) -> eyre::Result<Vec<String>> {
Expand Down Expand Up @@ -76,24 +72,6 @@ impl ObjectStore for S3Store {
}
}

#[derive(Debug, Deserialize)]
struct CsvIrisRecord {
id: String,
left_code: String,
left_mask: String,
right_code: String,
right_mask: String,
}

fn hex_to_bytes(hex: &str, byte_len: usize) -> eyre::Result<Vec<u8>> {
if hex.is_empty() {
return Ok(vec![]);
}
let mut bytes = vec![0; byte_len];
hex::decode_to_slice(hex, &mut bytes)?;
Ok(bytes)
}

pub async fn last_snapshot_timestamp(store: &impl ObjectStore) -> eyre::Result<i64> {
store
.list_objects()
Expand All @@ -115,12 +93,10 @@ pub async fn fetch_and_parse_chunks(
concurrency: usize,
) -> Pin<Box<dyn Stream<Item = eyre::Result<StoredIris>> + Send + '_>> {
let chunks = store.list_objects().await.unwrap();
let mut total_get_object_time = time::Duration::from_secs(0);
let mut total_csv_parse_time = time::Duration::from_secs(0);

let result_stream = stream::iter(chunks)
.filter_map(|chunk| async move {
if chunk.ends_with(".csv") {
if chunk.ends_with(".bin") {
tracing::info!("Processing chunk: {}", chunk);
Some(chunk)
} else {
Expand All @@ -132,50 +108,25 @@ pub async fn fetch_and_parse_chunks(
let result = store.get_object(&chunk).await?;
let get_object_time = now.elapsed();
tracing::info!("Got chunk object: {} in {:?}", chunk, get_object_time,);
total_get_object_time += get_object_time;

now = Instant::now();
let task = task::spawn_blocking(move || {
let cursor = Cursor::new(result);
let reader = csv::ReaderBuilder::new()
.has_headers(true)
.buffer_capacity(CSV_BUFFER_CAPACITY)
.from_reader(cursor);

let records: Vec<eyre::Result<StoredIris>> = reader
.into_deserialize()
.par_bridge()
.map(|r: Result<CsvIrisRecord, _>| {
let raw = r.map_err(|e| eyre::eyre!("CSV parse error: {}", e))?;

Ok(StoredIris {
id: raw.id.parse()?,
left_code: hex_to_bytes(
&raw.left_code,
IRIS_CODE_LENGTH * mem::size_of::<u16>(),
)?,
left_mask: hex_to_bytes(
&raw.left_mask,
MASK_CODE_LENGTH * mem::size_of::<u16>(),
)?,
right_code: hex_to_bytes(
&raw.right_code,
IRIS_CODE_LENGTH * mem::size_of::<u16>(),
)?,
right_mask: hex_to_bytes(
&raw.right_mask,
MASK_CODE_LENGTH * mem::size_of::<u16>(),
)?,
})
})
.collect();
let n_records = result.len().div_floor(SINGLE_ELEMENT_SIZE);

let mut records = Vec::with_capacity(n_records);
for i in 0..n_records {
let start = i * SINGLE_ELEMENT_SIZE;
let end = (i + 1) * SINGLE_ELEMENT_SIZE;
let chunk = &result[start..end];
let iris = StoredIris::from_bytes(chunk);
records.push(iris);
}

Ok::<_, eyre::Error>(stream::iter(records))
})
.await?;
let csv_parse_time = now.elapsed();
tracing::info!("Parsed csv chunk: {} in {:?}", chunk, csv_parse_time,);
total_csv_parse_time += csv_parse_time;
let parse_time = now.elapsed();
tracing::info!("Parsed chunk: {} in {:?}", chunk, parse_time,);
task
})
.buffer_unordered(concurrency)
Expand All @@ -184,11 +135,6 @@ pub async fn fetch_and_parse_chunks(
Err(e) => stream::once(async move { Err(e) }).boxed(),
})
.boxed();
tracing::info!(
"fetch_and_parse_chunks summary => Total get_object time: {:?}, Total csv parse time: {:?}",
total_get_object_time,
total_csv_parse_time,
);

result_stream
}
Expand Down Expand Up @@ -235,12 +181,11 @@ mod tests {

#[async_trait]
impl ObjectStore for MockStore {
async fn get_object(&self, key: &str) -> eyre::Result<Bytes> {
async fn get_object(&self, key: &str) -> eyre::Result<Vec<u8>> {
self.objects
.get(key)
.cloned()
.map(Bytes::from)
.ok_or_else(|| eyre::eyre!("Object not found: {}", key))
.cloned()
}

async fn list_objects(&self) -> eyre::Result<Vec<String>> {
Expand Down

0 comments on commit 2a5c7a2

Please sign in to comment.