diff --git a/iris-mpc-cpu/src/network/local.rs b/iris-mpc-cpu/src/network/local.rs index 91269795f..210cd0deb 100644 --- a/iris-mpc-cpu/src/network/local.rs +++ b/iris-mpc-cpu/src/network/local.rs @@ -103,8 +103,7 @@ impl Networking for LocalNetworking { #[cfg(test)] mod tests { use super::*; - use crate::network::value::NetworkValue; - use std::num::Wrapping; + use crate::{network::value::NetworkValue, shares::ring_impl::RingElement}; #[tokio::test] async fn test_network_send_receive() { @@ -118,11 +117,11 @@ mod tests { let recv = bob.receive(&"alice".into(), &1_u64.into()).await; assert_eq!( NetworkValue::from_network(recv).unwrap(), - NetworkValue::Ring16(Wrapping::(777)) + NetworkValue::RingElement16(RingElement(777)) ); }); let task2 = tokio::spawn(async move { - let value = NetworkValue::Ring16(Wrapping::(777)); + let value = NetworkValue::RingElement16(RingElement(777)); alice .send(value.to_network(), &"bob".into(), &1_u64.into()) .await diff --git a/iris-mpc-cpu/src/network/value.rs b/iris-mpc-cpu/src/network/value.rs index 43be3b76b..634fbb840 100644 --- a/iris-mpc-cpu/src/network/value.rs +++ b/iris-mpc-cpu/src/network/value.rs @@ -1,13 +1,13 @@ use crate::shares::{bit::Bit, ring_impl::RingElement}; use eyre::eyre; -use serde::{Deserialize, Serialize}; + +/// Size of a PRF key in bytes +const PRF_KEY_SIZE: usize = 16; /// Value sent over the network -#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)] +#[derive(PartialEq, Clone, Debug)] pub enum NetworkValue { - PrfKey([u8; 16]), - Ring16(std::num::Wrapping), - Ring32(std::num::Wrapping), + PrfKey([u8; PRF_KEY_SIZE]), RingElementBit(RingElement), RingElement16(RingElement), RingElement32(RingElement), @@ -18,20 +18,250 @@ pub enum NetworkValue { } impl NetworkValue { + fn get_descriptor_byte(&self) -> u8 { + match self { + NetworkValue::PrfKey(_) => 0x01, + NetworkValue::RingElementBit(bit) => { + if bit.convert().convert() { + 0x12 + } else { + 0x02 + } + } + NetworkValue::RingElement16(_) => 0x03, + NetworkValue::RingElement32(_) => 0x04, + NetworkValue::RingElement64(_) => 0x05, + NetworkValue::VecRing16(_) => 0x06, + NetworkValue::VecRing32(_) => 0x07, + NetworkValue::VecRing64(_) => 0x08, + } + } + + fn byte_len(&self) -> usize { + match self { + NetworkValue::PrfKey(_) => 1 + PRF_KEY_SIZE, + NetworkValue::RingElementBit(_) => 1, + NetworkValue::RingElement16(_) => 3, + NetworkValue::RingElement32(_) => 5, + NetworkValue::RingElement64(_) => 9, + NetworkValue::VecRing16(v) => 5 + 2 * v.len(), + NetworkValue::VecRing32(v) => 5 + 4 * v.len(), + NetworkValue::VecRing64(v) => 5 + 8 * v.len(), + } + } + + fn to_network_inner(&self, res: &mut Vec) { + res.push(self.get_descriptor_byte()); + + match self { + NetworkValue::PrfKey(key) => res.extend_from_slice(key), + NetworkValue::RingElementBit(_) => { + // Do nothing, the descriptor byte already contains the bit + // value + } + NetworkValue::RingElement16(x) => res.extend_from_slice(&x.convert().to_le_bytes()), + NetworkValue::RingElement32(x) => res.extend_from_slice(&x.convert().to_le_bytes()), + NetworkValue::RingElement64(x) => res.extend_from_slice(&x.convert().to_le_bytes()), + NetworkValue::VecRing16(v) => { + res.extend_from_slice(&(v.len() as u32).to_le_bytes()); + for x in v { + res.extend_from_slice(&x.convert().to_le_bytes()); + } + } + NetworkValue::VecRing32(v) => { + res.extend_from_slice(&(v.len() as u32).to_le_bytes()); + for x in v { + res.extend_from_slice(&x.convert().to_le_bytes()); + } + } + NetworkValue::VecRing64(v) => { + res.extend_from_slice(&(v.len() as u32).to_le_bytes()); + for x in v { + res.extend_from_slice(&x.convert().to_le_bytes()); + } + } + } + } + pub fn to_network(&self) -> Vec { - bincode::serialize(self).unwrap() + let mut res = Vec::with_capacity(self.byte_len()); + self.to_network_inner(&mut res); + res } pub fn from_network(serialized: eyre::Result>) -> eyre::Result { - bincode::deserialize::(&serialized?).map_err(|_e| eyre!("Failed to parse value")) + let serialized = serialized?; + let descriptor = serialized[0]; + match descriptor { + 0x01 => { + if serialized.len() != 1 + PRF_KEY_SIZE { + return Err(eyre!("Invalid length for PrfKey")); + } + Ok(NetworkValue::PrfKey(<[u8; PRF_KEY_SIZE]>::try_from( + &serialized[1..1 + PRF_KEY_SIZE], + )?)) + } + 0x02 | 0x12 => { + if serialized.len() != 1 { + return Err(eyre!("Invalid length for RingElementBit")); + } + let bit = if descriptor == 0x12 { + Bit::new(true) + } else { + Bit::new(false) + }; + Ok(NetworkValue::RingElementBit(RingElement(bit))) + } + 0x03 => { + if serialized.len() != 3 { + return Err(eyre!("Invalid length for RingElement16")); + } + Ok(NetworkValue::RingElement16(RingElement( + u16::from_le_bytes(<[u8; 2]>::try_from(&serialized[1..3])?), + ))) + } + 0x04 => { + if serialized.len() != 5 { + return Err(eyre!("Invalid length for RingElement32")); + } + Ok(NetworkValue::RingElement32(RingElement( + u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?), + ))) + } + 0x05 => { + if serialized.len() != 9 { + return Err(eyre!("Invalid length for RingElement64")); + } + Ok(NetworkValue::RingElement64(RingElement( + u64::from_le_bytes(<[u8; 8]>::try_from(&serialized[1..9])?), + ))) + } + 0x06 => { + if serialized.len() < 5 { + return Err(eyre!( + "Invalid length for VecRing16: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?) as usize; + if serialized.len() != 5 + 2 * len { + return Err(eyre!("Invalid length for VecRing16")); + } + let mut res = Vec::with_capacity(len); + for i in 0..len { + res.push(RingElement(u16::from_le_bytes(<[u8; 2]>::try_from( + &serialized[5 + 2 * i..5 + 2 * (i + 1)], + )?))); + } + Ok(NetworkValue::VecRing16(res)) + } + 0x07 => { + if serialized.len() < 5 { + return Err(eyre!( + "Invalid length for VecRing32: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?) as usize; + if serialized.len() != 5 + 4 * len { + return Err(eyre!("Invalid length for VecRing32")); + } + let mut res = Vec::with_capacity(len); + for i in 0..len { + res.push(RingElement(u32::from_le_bytes(<[u8; 4]>::try_from( + &serialized[5 + 4 * i..5 + 4 * (i + 1)], + )?))); + } + Ok(NetworkValue::VecRing32(res)) + } + 0x08 => { + if serialized.len() < 5 { + return Err(eyre!( + "Invalid length for VecRing64: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?) as usize; + if serialized.len() != 5 + 8 * len { + return Err(eyre!("Invalid length for VecRing64")); + } + let mut res = Vec::with_capacity(len); + for i in 0..len { + res.push(RingElement(u64::from_le_bytes(<[u8; 8]>::try_from( + &serialized[5 + 8 * i..5 + 8 * (i + 1)], + )?))); + } + Ok(NetworkValue::VecRing64(res)) + } + _ => Err(eyre!("Invalid network value type")), + } } pub fn vec_to_network(values: &Vec) -> Vec { - bincode::serialize(&values).unwrap() + // 4 extra bytes for the length of the vector + let len = values.iter().map(|v| v.byte_len()).sum::() + 4; + let mut res = Vec::with_capacity(len); + res.extend_from_slice(&(values.len() as u32).to_le_bytes()); + for value in values { + value.to_network_inner(&mut res); + } + res } pub fn vec_from_network(serialized: eyre::Result>) -> eyre::Result> { - bincode::deserialize::>(&serialized?).map_err(|_e| eyre!("Failed to parse value")) + let serialized = serialized?; + if serialized.len() < 4 { + return Err(eyre!("Can't parse vector length")); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[0..4])?) as usize; + let mut res = Vec::with_capacity(len); + let mut offset = 4; + for _ in 0..len { + let descriptor = serialized[offset]; + let value_len = match descriptor { + 0x01 => 1 + PRF_KEY_SIZE, + 0x02 | 0x12 => 1, // RingElementBit + 0x03 => 3, // RingElement16 + 0x04 => 5, // RingElement32 + 0x05 => 9, // RingElement64 + 0x06 => { + if serialized.len() < offset + 5 { + return Err(eyre!( + "Invalid length for VecRing16: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from( + &serialized[offset + 1..offset + 5], + )?) as usize; + 5 + 2 * len + } + 0x07 => { + if serialized.len() < offset + 5 { + return Err(eyre!( + "Invalid length for VecRing32: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from( + &serialized[offset + 1..offset + 5], + )?) as usize; + 5 + 4 * len + } + 0x08 => { + if serialized.len() < offset + 5 { + return Err(eyre!( + "Invalid length for VecRing64: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from( + &serialized[offset + 1..offset + 5], + )?) as usize; + 5 + 8 * len + } + _ => return Err(eyre!("Invalid network value type")), + }; + res.push(NetworkValue::from_network(Ok(serialized + [offset..offset + value_len] + .to_vec()))?); + offset += value_len; + } + Ok(res) } } diff --git a/iris-mpc-cpu/src/protocol/binary.rs b/iris-mpc-cpu/src/protocol/binary.rs index 49ddccd88..75e05de79 100644 --- a/iris-mpc-cpu/src/protocol/binary.rs +++ b/iris-mpc-cpu/src/protocol/binary.rs @@ -98,12 +98,13 @@ where let network = session.network().clone(); let sid = session.session_id(); let message = shares_a.clone(); + let message = if message.len() == 1 { + NetworkValue::RingElement64(message[0]) + } else { + NetworkValue::VecRing64(message) + }; network - .send( - NetworkValue::VecRing64(message).to_network(), - &next_party, - &sid, - ) + .send(message.to_network(), &next_party, &sid) .await?; Ok(shares_a) } @@ -118,6 +119,7 @@ pub(crate) async fn and_many_receive( let shares_b = { let serialized_other_share = network.receive(&prev_party, &sid).await; match NetworkValue::from_network(serialized_other_share) { + Ok(NetworkValue::RingElement64(message)) => Ok(vec![message]), Ok(NetworkValue::VecRing64(message)) => Ok(message), _ => Err(eyre!("Error in receiving in and_many operation")), }