diff --git a/iris-mpc-cpu/src/network/value.rs b/iris-mpc-cpu/src/network/value.rs index 22079c51c..3a5c02c98 100644 --- a/iris-mpc-cpu/src/network/value.rs +++ b/iris-mpc-cpu/src/network/value.rs @@ -1,9 +1,10 @@ use crate::shares::{bit::Bit, ring_impl::RingElement}; use eyre::eyre; -use serde::{Deserialize, Serialize}; + +const PRF_KEY_SIZE: usize = 16; /// Value sent over the network -#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)] +#[derive(PartialEq, Clone, Debug)] pub enum NetworkValue { PrfKey([u8; 16]), RingElementBit(RingElement), @@ -16,20 +17,231 @@ pub enum NetworkValue { } impl NetworkValue { + fn get_descriptor_byte(&self) -> u8 { + match self { + NetworkValue::PrfKey(_) => 0x01, + NetworkValue::RingElementBit(bit) => { + if bit.convert().convert() { + 0x12 + } else { + 0x02 + } + } + NetworkValue::RingElement16(_) => 0x03, + NetworkValue::RingElement32(_) => 0x04, + NetworkValue::RingElement64(_) => 0x05, + NetworkValue::VecRing16(_) => 0x06, + NetworkValue::VecRing32(_) => 0x07, + NetworkValue::VecRing64(_) => 0x08, + } + } + pub fn to_network(&self) -> Vec { - bitcode::serialize(self).unwrap() + let mut res = vec![]; + res.push(self.get_descriptor_byte()); + + match self { + NetworkValue::PrfKey(key) => res.extend_from_slice(key), + NetworkValue::RingElementBit(_) => { + // Do nothing, the descriptor byte already contains the bit + // value + } + NetworkValue::RingElement16(x) => res.extend_from_slice(&x.convert().to_le_bytes()), + NetworkValue::RingElement32(x) => res.extend_from_slice(&x.convert().to_le_bytes()), + NetworkValue::RingElement64(x) => res.extend_from_slice(&x.convert().to_le_bytes()), + NetworkValue::VecRing16(v) => { + res.extend_from_slice(&(v.len() as u32).to_le_bytes()); + for x in v { + res.extend_from_slice(&x.convert().to_le_bytes()); + } + } + NetworkValue::VecRing32(v) => { + res.extend_from_slice(&(v.len() as u32).to_le_bytes()); + for x in v { + res.extend_from_slice(&x.convert().to_le_bytes()); + } + } + NetworkValue::VecRing64(v) => { + res.extend_from_slice(&(v.len() as u32).to_le_bytes()); + for x in v { + res.extend_from_slice(&x.convert().to_le_bytes()); + } + } + } + res } pub fn from_network(serialized: eyre::Result>) -> eyre::Result { - bitcode::deserialize::(&serialized?).map_err(|_e| eyre!("Failed to parse value")) + let serialized = serialized?; + let descriptor = serialized[0]; + match descriptor { + 0x01 => { + if serialized.len() != 1 + PRF_KEY_SIZE { + return Err(eyre!("Invalid length for PrfKey")); + } + Ok(NetworkValue::PrfKey(<[u8; PRF_KEY_SIZE]>::try_from( + &serialized[1..1 + PRF_KEY_SIZE], + )?)) + } + 0x02 | 0x12 => { + if serialized.len() != 1 { + return Err(eyre!("Invalid length for RingElementBit")); + } + let bit = if descriptor == 0x12 { + Bit::new(true) + } else { + Bit::new(false) + }; + Ok(NetworkValue::RingElementBit(RingElement(bit))) + } + 0x03 => { + if serialized.len() != 3 { + return Err(eyre!("Invalid length for RingElement16")); + } + Ok(NetworkValue::RingElement16(RingElement( + u16::from_le_bytes(<[u8; 2]>::try_from(&serialized[1..3])?), + ))) + } + 0x04 => { + if serialized.len() != 5 { + return Err(eyre!("Invalid length for RingElement32")); + } + Ok(NetworkValue::RingElement32(RingElement( + u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?), + ))) + } + 0x05 => { + if serialized.len() != 9 { + return Err(eyre!("Invalid length for RingElement64")); + } + Ok(NetworkValue::RingElement64(RingElement( + u64::from_le_bytes(<[u8; 8]>::try_from(&serialized[1..9])?), + ))) + } + 0x06 => { + if serialized.len() < 5 { + return Err(eyre!( + "Invalid length for VecRing16: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?) as usize; + if serialized.len() != 5 + 2 * len { + return Err(eyre!("Invalid length for VecRing16")); + } + let mut res = Vec::with_capacity(len); + for i in 0..len { + res.push(RingElement(u16::from_le_bytes(<[u8; 2]>::try_from( + &serialized[5 + 2 * i..5 + 2 * (i + 1)], + )?))); + } + Ok(NetworkValue::VecRing16(res)) + } + 0x07 => { + if serialized.len() < 5 { + return Err(eyre!( + "Invalid length for VecRing32: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?) as usize; + if serialized.len() != 5 + 4 * len { + return Err(eyre!("Invalid length for VecRing32")); + } + let mut res = Vec::with_capacity(len); + for i in 0..len { + res.push(RingElement(u32::from_le_bytes(<[u8; 4]>::try_from( + &serialized[5 + 4 * i..5 + 4 * (i + 1)], + )?))); + } + Ok(NetworkValue::VecRing32(res)) + } + 0x08 => { + if serialized.len() < 5 { + return Err(eyre!( + "Invalid length for VecRing64: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[1..5])?) as usize; + if serialized.len() != 5 + 8 * len { + return Err(eyre!("Invalid length for VecRing64")); + } + let mut res = Vec::with_capacity(len); + for i in 0..len { + res.push(RingElement(u64::from_le_bytes(<[u8; 8]>::try_from( + &serialized[5 + 8 * i..5 + 8 * (i + 1)], + )?))); + } + Ok(NetworkValue::VecRing64(res)) + } + _ => Err(eyre!("Invalid network value type")), + } } pub fn vec_to_network(values: &Vec) -> Vec { - bitcode::serialize(&values).unwrap() + let mut res = vec![]; + res.extend_from_slice(&(values.len() as u32).to_le_bytes()); + for value in values { + res.extend_from_slice(&value.to_network()); + } + res } pub fn vec_from_network(serialized: eyre::Result>) -> eyre::Result> { - bitcode::deserialize::>(&serialized?).map_err(|_e| eyre!("Failed to parse value")) + let serialized = serialized?; + if serialized.len() < 4 { + return Err(eyre!("Can't parse vector length")); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from(&serialized[0..4])?) as usize; + let mut res = Vec::with_capacity(len); + let mut offset = 4; + for _ in 0..len { + let descriptor = serialized[offset]; + let value_len = match descriptor { + 0x01 => 1 + PRF_KEY_SIZE, + 0x02 | 0x12 => 1, // RingElementBit + 0x03 => 3, // RingElement16 + 0x04 => 5, // RingElement32 + 0x05 => 9, // RingElement64 + 0x06 => { + if serialized.len() < offset + 5 { + return Err(eyre!( + "Invalid length for VecRing16: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from( + &serialized[offset + 1..offset + 5], + )?) as usize; + 5 + 2 * len + } + 0x07 => { + if serialized.len() < offset + 5 { + return Err(eyre!( + "Invalid length for VecRing32: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from( + &serialized[offset + 1..offset + 5], + )?) as usize; + 5 + 4 * len + } + 0x08 => { + if serialized.len() < offset + 5 { + return Err(eyre!( + "Invalid length for VecRing64: can't parse vector length" + )); + } + let len = u32::from_le_bytes(<[u8; 4]>::try_from( + &serialized[offset + 1..offset + 5], + )?) as usize; + 5 + 8 * len + } + _ => return Err(eyre!("Invalid network value type")), + }; + res.push(NetworkValue::from_network(Ok(serialized + [offset..offset + value_len] + .to_vec()))?); + offset += value_len; + } + Ok(res) } }