Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better serialization of binary values and singletons #844

Merged
merged 8 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions iris-mpc-cpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ async-stream = "0.3.6"
async-trait = "~0.1"
backoff = {version="0.4.0", features = ["tokio"]}
bincode.workspace = true
bitcode = { version="0.6.3", features = ["serde"] }
iliailia marked this conversation as resolved.
Show resolved Hide resolved
bytes = "1.7"
bytemuck.workspace = true
clap.workspace = true
Expand Down
7 changes: 3 additions & 4 deletions iris-mpc-cpu/src/network/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,7 @@ impl Networking for LocalNetworking {
#[cfg(test)]
mod tests {
use super::*;
use crate::network::value::NetworkValue;
use std::num::Wrapping;
use crate::{network::value::NetworkValue, shares::ring_impl::RingElement};

#[tokio::test]
async fn test_network_send_receive() {
Expand All @@ -118,11 +117,11 @@ mod tests {
let recv = bob.receive(&"alice".into(), &1_u64.into()).await;
assert_eq!(
NetworkValue::from_network(recv).unwrap(),
NetworkValue::Ring16(Wrapping::<u16>(777))
NetworkValue::RingElement16(RingElement(777))
);
});
let task2 = tokio::spawn(async move {
let value = NetworkValue::Ring16(Wrapping::<u16>(777));
let value = NetworkValue::RingElement16(RingElement(777));
alice
.send(value.to_network(), &"bob".into(), &1_u64.into())
.await
Expand Down
245 changes: 237 additions & 8 deletions iris-mpc-cpu/src/network/value.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use crate::shares::{bit::Bit, ring_impl::RingElement};
use eyre::eyre;
use serde::{Deserialize, Serialize};

const PRF_KEY_SIZE: usize = 16;
iliailia marked this conversation as resolved.
Show resolved Hide resolved

/// Value sent over the network
#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)]
#[derive(PartialEq, Clone, Debug)]
pub enum NetworkValue {
dkales marked this conversation as resolved.
Show resolved Hide resolved
PrfKey([u8; 16]),
iliailia marked this conversation as resolved.
Show resolved Hide resolved
Ring16(std::num::Wrapping<u16>),
Ring32(std::num::Wrapping<u32>),
RingElementBit(RingElement<Bit>),
RingElement16(RingElement<u16>),
RingElement32(RingElement<u32>),
Expand All @@ -18,20 +17,250 @@ pub enum NetworkValue {
}

impl NetworkValue {
fn get_descriptor_byte(&self) -> u8 {
match self {
NetworkValue::PrfKey(_) => 0x01,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these descriptor bytes start at 1 because the network byte stream looks for 0 to indicate EOF or something similar?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking to leave it for some special message, e.g., "abort"

NetworkValue::RingElementBit(bit) => {
if bit.convert().convert() {
0x12
} else {
0x02
}
}
NetworkValue::RingElement16(_) => 0x03,
NetworkValue::RingElement32(_) => 0x04,
NetworkValue::RingElement64(_) => 0x05,
NetworkValue::VecRing16(_) => 0x06,
NetworkValue::VecRing32(_) => 0x07,
NetworkValue::VecRing64(_) => 0x08,
}
}

fn byte_len(&self) -> usize {
match self {
NetworkValue::PrfKey(_) => 1 + PRF_KEY_SIZE,
NetworkValue::RingElementBit(_) => 1,
NetworkValue::RingElement16(_) => 3,
NetworkValue::RingElement32(_) => 5,
NetworkValue::RingElement64(_) => 9,
NetworkValue::VecRing16(v) => 5 + 2 * v.len(),
NetworkValue::VecRing32(v) => 5 + 4 * v.len(),
NetworkValue::VecRing64(v) => 5 + 8 * v.len(),
}
}

fn to_network_inner(&self, res: &mut Vec<u8>) {
res.push(self.get_descriptor_byte());

match self {
NetworkValue::PrfKey(key) => res.extend_from_slice(key),
NetworkValue::RingElementBit(_) => {
// Do nothing, the descriptor byte already contains the bit
// value
}
NetworkValue::RingElement16(x) => res.extend_from_slice(&x.convert().to_le_bytes()),
NetworkValue::RingElement32(x) => res.extend_from_slice(&x.convert().to_le_bytes()),
NetworkValue::RingElement64(x) => res.extend_from_slice(&x.convert().to_le_bytes()),
NetworkValue::VecRing16(v) => {
res.extend_from_slice(&(v.len() as u32).to_le_bytes());
for x in v {
res.extend_from_slice(&x.convert().to_le_bytes());
}
}
NetworkValue::VecRing32(v) => {
res.extend_from_slice(&(v.len() as u32).to_le_bytes());
for x in v {
res.extend_from_slice(&x.convert().to_le_bytes());
}
}
NetworkValue::VecRing64(v) => {
res.extend_from_slice(&(v.len() as u32).to_le_bytes());
for x in v {
res.extend_from_slice(&x.convert().to_le_bytes());
}
}
}
}

pub fn to_network(&self) -> Vec<u8> {
bincode::serialize(self).unwrap()
let mut res = Vec::with_capacity(self.byte_len());
self.to_network_inner(&mut res);
res
}

pub fn from_network(serialized: eyre::Result<Vec<u8>>) -> eyre::Result<Self> {
bincode::deserialize::<Self>(&serialized?).map_err(|_e| eyre!("Failed to parse value"))
let serialized = serialized?;
let descriptor = serialized[0];
match descriptor {
0x01 => {
if serialized.len() != 1 + PRF_KEY_SIZE {
return Err(eyre!("Invalid length for PrfKey"));
}
Ok(NetworkValue::PrfKey(<[u8; PRF_KEY_SIZE]>::try_from(
&serialized[1..1 + PRF_KEY_SIZE],
)?))
}
0x02 | 0x12 => {
if serialized.len() != 1 {
return Err(eyre!("Invalid length for RingElementBit"));
}
let bit = if descriptor == 0x12 {
Bit::new(true)
} else {
Bit::new(false)
};
Ok(NetworkValue::RingElementBit(RingElement(bit)))
}
0x03 => {
if serialized.len() != 3 {
return Err(eyre!("Invalid length for RingElement16"));
}
Ok(NetworkValue::RingElement16(RingElement(
u16::from_le_bytes(<[u8; 2]>::try_from(&serialized[1..3])?),
)))
}
0x04 => {
if serialized.len() != 5 {
return Err(eyre!("Invalid length for RingElement32"));
}
Ok(NetworkValue::RingElement32(RingElement(
u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?),
)))
}
0x05 => {
if serialized.len() != 9 {
return Err(eyre!("Invalid length for RingElement64"));
}
Ok(NetworkValue::RingElement64(RingElement(
u64::from_le_bytes(<[u8; 8]>::try_from(&serialized[1..9])?),
)))
}
0x06 => {
if serialized.len() < 5 {
return Err(eyre!(
"Invalid length for VecRing16: can't parse vector length"
));
}
let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?) as usize;
if serialized.len() != 5 + 2 * len {
return Err(eyre!("Invalid length for VecRing16"));
}
let mut res = Vec::with_capacity(len);
for i in 0..len {
res.push(RingElement(u16::from_le_bytes(<[u8; 2]>::try_from(
&serialized[5 + 2 * i..5 + 2 * (i + 1)],
)?)));
}
Ok(NetworkValue::VecRing16(res))
}
0x07 => {
if serialized.len() < 5 {
return Err(eyre!(
"Invalid length for VecRing32: can't parse vector length"
));
}
let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?) as usize;
if serialized.len() != 5 + 4 * len {
return Err(eyre!("Invalid length for VecRing32"));
}
let mut res = Vec::with_capacity(len);
for i in 0..len {
res.push(RingElement(u32::from_le_bytes(<[u8; 4]>::try_from(
&serialized[5 + 4 * i..5 + 4 * (i + 1)],
)?)));
}
Ok(NetworkValue::VecRing32(res))
}
0x08 => {
if serialized.len() < 5 {
return Err(eyre!(
"Invalid length for VecRing64: can't parse vector length"
));
}
let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?) as usize;
if serialized.len() != 5 + 8 * len {
return Err(eyre!("Invalid length for VecRing64"));
}
let mut res = Vec::with_capacity(len);
for i in 0..len {
res.push(RingElement(u64::from_le_bytes(<[u8; 8]>::try_from(
&serialized[5 + 8 * i..5 + 8 * (i + 1)],
)?)));
}
Ok(NetworkValue::VecRing64(res))
}
_ => Err(eyre!("Invalid network value type")),
}
}

pub fn vec_to_network(values: &Vec<Self>) -> Vec<u8> {
bincode::serialize(&values).unwrap()
// 4 extra bytes for the length of the vector
let len = values.iter().map(|v| v.byte_len()).sum::<usize>() + 4;
let mut res = Vec::with_capacity(len);
res.extend_from_slice(&(values.len() as u32).to_le_bytes());
for value in values {
value.to_network_inner(&mut res);
}
res
}

pub fn vec_from_network(serialized: eyre::Result<Vec<u8>>) -> eyre::Result<Vec<Self>> {
bincode::deserialize::<Vec<Self>>(&serialized?).map_err(|_e| eyre!("Failed to parse value"))
let serialized = serialized?;
if serialized.len() < 4 {
return Err(eyre!("Can't parse vector length"));
}
let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[0..4])?) as usize;
let mut res = Vec::with_capacity(len);
let mut offset = 4;
for _ in 0..len {
let descriptor = serialized[offset];
let value_len = match descriptor {
0x01 => 1 + PRF_KEY_SIZE,
0x02 | 0x12 => 1, // RingElementBit
0x03 => 3, // RingElement16
0x04 => 5, // RingElement32
0x05 => 9, // RingElement64
0x06 => {
if serialized.len() < offset + 5 {
return Err(eyre!(
"Invalid length for VecRing16: can't parse vector length"
));
}
let len = u32::from_le_bytes(<[u8; 4]>::try_from(
&serialized[offset + 1..offset + 5],
)?) as usize;
5 + 2 * len
}
0x07 => {
if serialized.len() < offset + 5 {
return Err(eyre!(
"Invalid length for VecRing32: can't parse vector length"
));
}
let len = u32::from_le_bytes(<[u8; 4]>::try_from(
&serialized[offset + 1..offset + 5],
)?) as usize;
5 + 4 * len
}
0x08 => {
if serialized.len() < offset + 5 {
return Err(eyre!(
"Invalid length for VecRing64: can't parse vector length"
));
}
let len = u32::from_le_bytes(<[u8; 4]>::try_from(
&serialized[offset + 1..offset + 5],
)?) as usize;
5 + 8 * len
}
_ => return Err(eyre!("Invalid network value type")),
};
res.push(NetworkValue::from_network(Ok(serialized
[offset..offset + value_len]
.to_vec()))?);
offset += value_len;
}
Ok(res)
}
}

Expand Down
12 changes: 7 additions & 5 deletions iris-mpc-cpu/src/protocol/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,13 @@ where
let network = session.network().clone();
let sid = session.session_id();
let message = shares_a.clone();
let message = if message.len() == 1 {
NetworkValue::RingElement64(message[0])
} else {
NetworkValue::VecRing64(message)
};
network
.send(
NetworkValue::VecRing64(message).to_network(),
&next_party,
&sid,
)
.send(message.to_network(), &next_party, &sid)
.await?;
Ok(shares_a)
}
Expand All @@ -118,6 +119,7 @@ pub(crate) async fn and_many_receive(
let shares_b = {
let serialized_other_share = network.receive(&prev_party, &sid).await;
match NetworkValue::from_network(serialized_other_share) {
Ok(NetworkValue::RingElement64(message)) => Ok(vec![message]),
Ok(NetworkValue::VecRing64(message)) => Ok(message),
_ => Err(eyre!("Error in receiving in and_many operation")),
}
Expand Down
Loading