diff --git a/chitchat/Cargo.toml b/chitchat/Cargo.toml index a3a79a8..dc68655 100644 --- a/chitchat/Cargo.toml +++ b/chitchat/Cargo.toml @@ -16,9 +16,16 @@ bytes = "1" itertools = "0.12" rand = { version = "0.8", features = ["small_rng"] } serde = { version = "1", features = ["derive"] } -tokio = { version = "1.28.0", features = ["net", "sync", "rt-multi-thread", "macros", "time"] } +tokio = { version = "1.28.0", features = [ + "net", + "sync", + "rt-multi-thread", + "macros", + "time", +] } tokio-stream = { version = "0.1", features = ["sync"] } tracing = "0.1" +zstd = "0.13" [dev-dependencies] assert-json-diff = "2" diff --git a/chitchat/src/delta.rs b/chitchat/src/delta.rs index ced1fe6..9da3fdf 100644 --- a/chitchat/src/delta.rs +++ b/chitchat/src/delta.rs @@ -10,6 +10,98 @@ pub struct Delta { pub(crate) nodes_to_reset: HashSet, } +enum DeltaOp { + Node(ChitchatId), + KeyValue { + key: String, + versioned_value: VersionedValue, + }, + NodesToReset(ChitchatId), +} + +impl Deserializable for DeltaOp { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let tag_bytes: [u8; 1] = Deserializable::deserialize(buf)?; + let tag = tag_bytes[0]; + match tag { + 0 => { + let chitchat_id = ChitchatId::deserialize(buf)?; + Ok(DeltaOp::Node(chitchat_id)) + } + 1 => { + let key = String::deserialize(buf)?; + let value = String::deserialize(buf)?; + let version = u64::deserialize(buf)?; + let tombstone = Option::::deserialize(buf)?; + let versioned_value: VersionedValue = VersionedValue { + value, + version, + tombstone, + }; + Ok(DeltaOp::KeyValue { + key, + versioned_value, + }) + } + 2 => { + let chitchat_id = ChitchatId::deserialize(buf)?; + Ok(DeltaOp::NodesToReset(chitchat_id)) + } + _ => Err(anyhow::anyhow!("Invalid tag: {}", tag)), + } + } +} + +enum DeltaOpRef<'a> { + Node(&'a ChitchatId), + KeyValue { + key: &'a str, + versioned_value: &'a VersionedValue, + }, + NodesToReset(&'a ChitchatId), +} + +impl<'a> Serializable for DeltaOpRef<'a> { + fn serialize(&self, buf: &mut Vec) { + match self { + DeltaOpRef::Node(chitchat_id) => { + buf.push(0u8); + chitchat_id.serialize(buf); + } + DeltaOpRef::KeyValue { + key, + versioned_value, + } => { + buf.push(1u8); + key.serialize(buf); + versioned_value.value.serialize(buf); + versioned_value.version.serialize(buf); + versioned_value.tombstone.serialize(buf); + } + DeltaOpRef::NodesToReset(chitchat_id) => { + buf.push(2u8); + chitchat_id.serialize(buf); + } + } + } + + fn serialized_len(&self) -> usize { + 1 + match self { + DeltaOpRef::Node(chitchat_id) => chitchat_id.serialized_len(), + DeltaOpRef::KeyValue { + key, + versioned_value, + } => { + key.serialized_len() + + versioned_value.value.serialized_len() + + versioned_value.version.serialized_len() + + versioned_value.tombstone.serialized_len() + } + DeltaOpRef::NodesToReset(chitchat_id) => chitchat_id.serialized_len(), + } + } +} + impl Serializable for Delta { fn serialize(&self, buf: &mut Vec) { (self.node_deltas.len() as u16).serialize(buf); @@ -23,6 +115,21 @@ impl Serializable for Delta { } } + fn serialized_len(&self) -> usize { + let mut len = 2; + for (chitchat_id, node_delta) in &self.node_deltas { + len += chitchat_id.serialized_len(); + len += node_delta.serialized_len(); + } + len += 2; + for chitchat_id in &self.nodes_to_reset { + len += chitchat_id.serialized_len(); + } + len + } +} + +impl Deserializable for Delta { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let mut node_deltas: BTreeMap = Default::default(); let num_nodes = u16::deserialize(buf)?; @@ -42,19 +149,6 @@ impl Serializable for Delta { nodes_to_reset, }) } - - fn serialized_len(&self) -> usize { - let mut len = 2; - for (chitchat_id, node_delta) in &self.node_deltas { - len += chitchat_id.serialized_len(); - len += node_delta.serialized_len(); - } - len += 2; - for chitchat_id in &self.nodes_to_reset { - len += chitchat_id.serialized_len(); - } - len - } } #[cfg(test)] @@ -234,7 +328,29 @@ impl Serializable for NodeDelta { tombstone.serialize(buf); } } + fn serialized_len(&self) -> usize { + let mut len = 2; + len += self.heartbeat.serialized_len(); + for ( + key, + VersionedValue { + value, + version, + tombstone, + }, + ) in &self.key_values + { + len += key.serialized_len(); + len += value.serialized_len(); + len += version.serialized_len(); + len += tombstone.serialized_len(); + } + len + } +} + +impl Deserializable for NodeDelta { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let heartbeat = Heartbeat::deserialize(buf)?; let mut key_values: BTreeMap = Default::default(); @@ -261,27 +377,6 @@ impl Serializable for NodeDelta { max_version, }) } - - fn serialized_len(&self) -> usize { - let mut len = 2; - len += self.heartbeat.serialized_len(); - - for ( - key, - VersionedValue { - value, - version, - tombstone, - }, - ) in &self.key_values - { - len += key.serialized_len(); - len += value.serialized_len(); - len += version.serialized_len(); - len += tombstone.serialized_len(); - } - len - } } #[cfg(test)] diff --git a/chitchat/src/digest.rs b/chitchat/src/digest.rs index 2371dd7..c369515 100644 --- a/chitchat/src/digest.rs +++ b/chitchat/src/digest.rs @@ -45,7 +45,18 @@ impl Serializable for Digest { node_digest.max_version.serialize(buf); } } + fn serialized_len(&self) -> usize { + let mut len = (self.node_digests.len() as u16).serialized_len(); + for (chitchat_id, node_digest) in &self.node_digests { + len += chitchat_id.serialized_len(); + len += node_digest.heartbeat.serialized_len(); + len += node_digest.max_version.serialized_len(); + } + len + } +} +impl Deserializable for Digest { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let num_nodes = u16::deserialize(buf)?; let mut node_digests: BTreeMap = Default::default(); @@ -59,14 +70,4 @@ impl Serializable for Digest { } Ok(Digest { node_digests }) } - - fn serialized_len(&self) -> usize { - let mut len = (self.node_digests.len() as u16).serialized_len(); - for (chitchat_id, node_digest) in &self.node_digests { - len += chitchat_id.serialized_len(); - len += node_digest.heartbeat.serialized_len(); - len += node_digest.max_version.serialized_len(); - } - len - } } diff --git a/chitchat/src/message.rs b/chitchat/src/message.rs index 863b040..f3c2dfe 100644 --- a/chitchat/src/message.rs +++ b/chitchat/src/message.rs @@ -4,7 +4,7 @@ use anyhow::Context; use crate::delta::Delta; use crate::digest::Digest; -use crate::serialize::Serializable; +use crate::serialize::{Deserializable, Serializable}; /// Chitchat message. /// @@ -73,6 +73,19 @@ impl Serializable for ChitchatMessage { } } + fn serialized_len(&self) -> usize { + match self { + ChitchatMessage::Syn { cluster_id, digest } => { + 1 + cluster_id.serialized_len() + digest.serialized_len() + } + ChitchatMessage::SynAck { digest, delta } => syn_ack_serialized_len(digest, delta), + ChitchatMessage::Ack { delta } => 1 + delta.serialized_len(), + ChitchatMessage::BadCluster => 1, + } + } +} + +impl Deserializable for ChitchatMessage { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let code = buf .first() @@ -98,17 +111,6 @@ impl Serializable for ChitchatMessage { MessageType::BadCluster => Ok(Self::BadCluster), } } - - fn serialized_len(&self) -> usize { - match self { - ChitchatMessage::Syn { cluster_id, digest } => { - 1 + cluster_id.serialized_len() + digest.serialized_len() - } - ChitchatMessage::SynAck { digest, delta } => syn_ack_serialized_len(digest, delta), - ChitchatMessage::Ack { delta } => 1 + delta.serialized_len(), - ChitchatMessage::BadCluster => 1, - } - } } pub(crate) fn syn_ack_serialized_len(digest: &Digest, delta: &Delta) -> usize { diff --git a/chitchat/src/serialize.rs b/chitchat/src/serialize.rs index 5536383..43b011d 100644 --- a/chitchat/src/serialize.rs +++ b/chitchat/src/serialize.rs @@ -1,7 +1,10 @@ use std::io::BufRead; +use std::marker::PhantomData; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -use anyhow::bail; +use anyhow::{bail, Context}; +use bytes::Buf; +use zstd; use crate::{ChitchatId, Heartbeat}; @@ -10,7 +13,7 @@ use crate::{ChitchatId, Heartbeat}; /// Chitchat uses a custom binary serialization format. /// The point of this format is to make it possible /// to truncate the delta payload to a given mtu. -pub trait Serializable: Sized { +pub trait Serializable { fn serialize(&self, buf: &mut Vec); fn serialize_to_vec(&self) -> Vec { @@ -19,19 +22,33 @@ pub trait Serializable: Sized { buf } - fn deserialize(buf: &mut &[u8]) -> anyhow::Result; - fn serialized_len(&self) -> usize; } -impl Serializable for u16 { +pub trait Deserializable: Sized { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result; +} + +impl Serializable for u8 { fn serialize(&self, buf: &mut Vec) { - self.to_le_bytes().serialize(buf); + buf.push(*self) + } + + fn serialized_len(&self) -> usize { + 1 } +} +impl Deserializable for u8 { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let u16_bytes: [u8; 2] = Serializable::deserialize(buf)?; - Ok(Self::from_le_bytes(u16_bytes)) + let byte: [u8; 1] = Deserializable::deserialize(buf)?; + Ok(byte[0]) + } +} + +impl Serializable for u16 { + fn serialize(&self, buf: &mut Vec) { + self.to_le_bytes().serialize(buf); } fn serialized_len(&self) -> usize { @@ -39,19 +56,27 @@ impl Serializable for u16 { } } +impl Deserializable for u16 { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let u16_bytes: [u8; 2] = Deserializable::deserialize(buf)?; + Ok(Self::from_le_bytes(u16_bytes)) + } +} + impl Serializable for u64 { fn serialize(&self, buf: &mut Vec) { self.to_le_bytes().serialize(buf); } + fn serialized_len(&self) -> usize { + 8 + } +} +impl Deserializable for u64 { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let u64_bytes: [u8; 8] = Serializable::deserialize(buf)?; + let u64_bytes: [u8; 8] = Deserializable::deserialize(buf)?; Ok(Self::from_le_bytes(u64_bytes)) } - - fn serialized_len(&self) -> usize { - 8 - } } impl Serializable for Option { @@ -61,16 +86,6 @@ impl Serializable for Option { tombstone.serialize(buf); } } - - fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let is_some: bool = Serializable::deserialize(buf)?; - if is_some { - let u64_value = Serializable::deserialize(buf)?; - return Ok(Some(u64_value)); - } - Ok(None) - } - fn serialized_len(&self) -> usize { if self.is_some() { 9 @@ -80,19 +95,31 @@ impl Serializable for Option { } } +impl Deserializable for Option { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let is_some: bool = Deserializable::deserialize(buf)?; + if is_some { + let u64_value = Deserializable::deserialize(buf)?; + return Ok(Some(u64_value)); + } + Ok(None) + } +} + impl Serializable for bool { fn serialize(&self, buf: &mut Vec) { buf.push(*self as u8); } + fn serialized_len(&self) -> usize { + 1 + } +} +impl Deserializable for bool { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let bool_byte: [u8; 1] = Serializable::deserialize(buf)?; + let bool_byte: [u8; 1] = Deserializable::deserialize(buf)?; Ok(bool_byte[0] != 0) } - - fn serialized_len(&self) -> usize { - 1 - } } #[repr(u8)] @@ -129,42 +156,55 @@ impl Serializable for IpAddr { } } + fn serialized_len(&self) -> usize { + 1 + match self { + IpAddr::V4(_) => 4, + IpAddr::V6(_) => 16, + } + } +} + +impl Deserializable for IpAddr { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { - let ip_version_byte: [u8; 1] = Serializable::deserialize(buf)?; + let ip_version_byte: [u8; 1] = Deserializable::deserialize(buf)?; let ip_version = IpVersion::try_from(ip_version_byte[0])?; - match ip_version { IpVersion::V4 => { - let bytes: [u8; 4] = Serializable::deserialize(buf)?; + let bytes: [u8; 4] = Deserializable::deserialize(buf)?; Ok(Ipv4Addr::from(bytes).into()) } IpVersion::V6 => { - let bytes: [u8; 16] = Serializable::deserialize(buf)?; + let bytes: [u8; 16] = Deserializable::deserialize(buf)?; Ok(Ipv6Addr::from(bytes).into()) } } } - - fn serialized_len(&self) -> usize { - 1 + match self { - IpAddr::V4(_) => 4, - IpAddr::V6(_) => 16, - } - } } impl Serializable for String { fn serialize(&self, buf: &mut Vec) { - (self.len() as u16).serialize(buf); - buf.extend(self.as_bytes()) + self.as_str().serialize(buf) } + fn serialized_len(&self) -> usize { + self.as_str().serialized_len() + } +} + +impl Deserializable for String { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let len: usize = u16::deserialize(buf)? as usize; let s = std::str::from_utf8(&buf[..len])?.to_string(); buf.consume(len); Ok(s) } +} + +impl<'a> Serializable for &'a str { + fn serialize(&self, buf: &mut Vec) { + (self.len() as u16).serialize(buf); + buf.extend(self.as_bytes()) + } fn serialized_len(&self) -> usize { 2 + self.len() @@ -175,7 +215,12 @@ impl Serializable for [u8; N] { fn serialize(&self, buf: &mut Vec) { buf.extend_from_slice(&self[..]); } + fn serialized_len(&self) -> usize { + N + } +} +impl Deserializable for [u8; N] { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { if buf.len() < N { bail!("Buffer too short"); @@ -184,10 +229,6 @@ impl Serializable for [u8; N] { buf.consume(N); Ok(val_bytes) } - - fn serialized_len(&self) -> usize { - N - } } impl Serializable for SocketAddr { @@ -196,15 +237,17 @@ impl Serializable for SocketAddr { self.port().serialize(buf); } + fn serialized_len(&self) -> usize { + self.ip().serialized_len() + self.port().serialized_len() + } +} + +impl Deserializable for SocketAddr { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let ip_addr = IpAddr::deserialize(buf)?; let port = u16::deserialize(buf)?; Ok(SocketAddr::new(ip_addr, port)) } - - fn serialized_len(&self) -> usize { - self.ip().serialized_len() + self.port().serialized_len() - } } impl Serializable for ChitchatId { @@ -214,6 +257,14 @@ impl Serializable for ChitchatId { self.gossip_advertise_addr.serialize(buf) } + fn serialized_len(&self) -> usize { + self.node_id.serialized_len() + + self.generation_id.serialized_len() + + self.gossip_advertise_addr.serialized_len() + } +} + +impl Deserializable for ChitchatId { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let node_id = String::deserialize(buf)?; let generation_id = u64::deserialize(buf)?; @@ -224,12 +275,6 @@ impl Serializable for ChitchatId { gossip_advertise_addr, }) } - - fn serialized_len(&self) -> usize { - self.node_id.serialized_len() - + self.generation_id.serialized_len() - + self.gossip_advertise_addr.serialized_len() - } } impl Serializable for Heartbeat { @@ -237,19 +282,215 @@ impl Serializable for Heartbeat { self.0.serialize(buf); } + fn serialized_len(&self) -> usize { + self.0.serialized_len() + } +} + +impl Deserializable for Heartbeat { fn deserialize(buf: &mut &[u8]) -> anyhow::Result { let heartbeat = u64::deserialize(buf)?; Ok(Self(heartbeat)) } +} + +pub struct CompressedStreamWriter { + output: Vec, + // number of blocks written in output. + num_blocks: u16, + + // temporary buffer used for block compression. + uncompressed_block: Vec, + // ongoing block being serialized. + compressed_block: Vec, + block_threshold: usize, + _phantom: PhantomData, +} + +impl CompressedStreamWriter { + pub fn with_block_threshold(block_threshold: u16) -> CompressedStreamWriter { + let block_threshold = block_threshold as usize; + let output = Vec::with_capacity(block_threshold); + CompressedStreamWriter { + output, + uncompressed_block: Vec::with_capacity(block_threshold * 2), + compressed_block: Vec::with_capacity(block_threshold), + block_threshold, + num_blocks: 0, + _phantom: PhantomData, + } + } + + /// Returns an upperbound of the serialized len after appending `s` + pub fn serialized_len_upperbound_after(&self, item: &S) -> usize { + self.output.len() + // already serialized block + 3 + // current block len + self.uncompressed_block.len() + + 3 + // possibly another block that will be created + item.serialized_len() + // the new item. This assume no compression will be possible. + 1 // End of stream flag + } + + pub fn append(&mut self, item: S) { + let item_len = item.serialized_len(); + assert!(item_len <= u16::MAX as usize); + if self.uncompressed_block.len() + item_len >= self.block_threshold { + // time to flush our current block. + self.flush_block(); + } + item.serialize(&mut self.uncompressed_block); + if self.uncompressed_block.len() >= self.block_threshold {} + } + + /// Flush the ongoing block as compressed or an uncompressed block (whichever is the smallest). + /// If the ongoing block is empty, this function is no op. + fn flush_block(&mut self) { + if self.uncompressed_block.is_empty() { + return; + } + let uncompressed_len = self.uncompressed_block.len(); + let uncompressed_len_u16 = + u16::try_from(uncompressed_len).expect("uncompressed block too big"); + self.compressed_block.resize(uncompressed_len, 0u8); + match zstd::bulk::compress_to_buffer( + &self.uncompressed_block, + &mut self.compressed_block[..], + 0, + ) { + Ok(compressed_len) => { + let compressed_len_u16 = u16::try_from(compressed_len).unwrap(); + let block_meta = BlockMeta::CompressedBlock { + len: compressed_len as u16, + }; + block_meta.serialize(&mut self.output); + self.output.extend(&self.compressed_block[..compressed_len]); + } + // The compressed version was actually longer than the decompressed one. + // Let's keep the block uncomopressed + Err(_) => { + let block_meta = BlockMeta::UncompressedBlock { + len: uncompressed_len_u16, + }; + block_meta.serialize(&mut self.output); + self.output.extend(&self.uncompressed_block); + } + } + self.num_blocks += 1; + self.uncompressed_block.clear(); + self.compressed_block.clear(); + } + + pub fn finalize(mut self) -> Vec { + self.flush_block(); + BlockMeta::NoMoreBlocks.serialize(&mut self.output); + self.output + } +} + +struct CompressedStream(Vec); + +impl Deserializable for CompressedStream { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let mut items: Vec = Vec::new(); + let mut decompression_buffer = vec![0; u16::MAX as usize]; + while !buf.is_empty() { + let block_meta = BlockMeta::deserialize(buf)?; + match block_meta { + BlockMeta::CompressedBlock { len } => { + let len = len as usize; + let compressed_block_bytes = &buf[..len]; + let uncompressed_len = zstd::bulk::decompress_to_buffer( + compressed_block_bytes, + &mut decompression_buffer[..u16::MAX as usize], + ) + .context("failed to decompress block")?; + buf.advance(len as usize); + let mut block_bytes = &decompression_buffer[..uncompressed_len]; + while !block_bytes.is_empty() { + let item = D::deserialize(&mut block_bytes)?; + items.push(item); + } + } + BlockMeta::UncompressedBlock { len } => { + let len = len as usize; + let mut block_bytes = &buf[..len]; + buf.advance(len as usize); + while !block_bytes.is_empty() { + let item = D::deserialize(&mut block_bytes)?; + items.push(item); + } + } + BlockMeta::NoMoreBlocks => { + return Ok(CompressedStream(items)); + } + }; + } + anyhow::bail!("compressed streams error: reached end of buffer without NoMoreBlock tag"); + } +} + +#[derive(Eq, PartialEq, Debug)] +enum BlockMeta { + CompressedBlock { len: u16 }, + UncompressedBlock { len: u16 }, + NoMoreBlocks, +} + +const NO_MORE_BLOCKS_TAG: u8 = 0u8; +const COMPRESSED_BLOCK_TAG: u8 = 1u8; +const UNCOMPRESSED_BLOCK_TAG: u8 = 2u8; + +impl Serializable for BlockMeta { + fn serialize(&self, buf: &mut Vec) { + match self { + BlockMeta::CompressedBlock { len } => { + COMPRESSED_BLOCK_TAG.serialize(buf); + len.serialize(buf); + } + BlockMeta::UncompressedBlock { len } => { + UNCOMPRESSED_BLOCK_TAG.serialize(buf); + len.serialize(buf); + } + BlockMeta::NoMoreBlocks => { + NO_MORE_BLOCKS_TAG.serialize(buf); + } + } + } fn serialized_len(&self) -> usize { - self.0.serialized_len() + match self { + BlockMeta::CompressedBlock { .. } | BlockMeta::UncompressedBlock { .. } => 3, + BlockMeta::NoMoreBlocks => 1, + } + } +} + +impl Deserializable for BlockMeta { + fn deserialize(buf: &mut &[u8]) -> anyhow::Result { + let tag = u8::deserialize(buf)?; + match tag { + UNCOMPRESSED_BLOCK_TAG => { + let len = u16::deserialize(buf)?; + Ok(BlockMeta::UncompressedBlock { len }) + } + COMPRESSED_BLOCK_TAG => { + let len = u16::deserialize(buf)?; + Ok(BlockMeta::CompressedBlock { len }) + } + NO_MORE_BLOCKS_TAG => Ok(BlockMeta::NoMoreBlocks), + _ => { + anyhow::bail!("Unknown block meta tag: {tag}") + } + } } } #[cfg(test)] #[track_caller] -pub fn test_serdeser_aux(obj: &T, num_bytes: usize) { +pub fn test_serdeser_aux( + obj: &T, + num_bytes: usize, +) { let mut buf = Vec::new(); obj.serialize(&mut buf); assert_eq!(buf.len(), obj.serialized_len()); @@ -296,4 +537,47 @@ mod tests { test_serdeser_aux(&Some(1), 9); test_serdeser_aux(&None, 1); } + + #[test] + fn test_serialize_block_meta() { + test_serdeser_aux(&BlockMeta::CompressedBlock { len: 10u16 }, 3); + test_serdeser_aux(&BlockMeta::UncompressedBlock { len: 18u16 }, 3); + test_serdeser_aux(&BlockMeta::NoMoreBlocks, 1); + } + + // An array of 10 small sentences for tests. + const TEXT_SAMPLES: [&str; 10] = [ + "I'm happy.", + "She exercises every morning.", + "His dog barks loudly.", + "My school starts at 8:00.", + "We always eat dinner together.", + "They take the bus to work.", + "He doesn't like vegetables.", + "I don't want anything to drink.", + "hello Happy tax payer", + "do you like tea?", + ]; + + #[test] + fn test_compressed_serialized_stream() { + let mut compressed_stream_writer: CompressedStreamWriter<&str> = + CompressedStreamWriter::with_block_threshold(1000); + let mut uncompressed_len = 0; + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + compressed_stream_writer.append(sentence); + uncompressed_len += sentence.len(); + } + let buf = compressed_stream_writer.finalize(); + let mut cursor = &buf[..]; + assert!(buf.len() * 3 < uncompressed_len); + let vals: CompressedStream = Deserializable::deserialize(&mut cursor).unwrap(); + assert_eq!(vals.0.len(), 100); + for i in 0..100 { + let sentence = TEXT_SAMPLES[i % TEXT_SAMPLES.len()]; + assert_eq!(&vals.0[i], sentence); + } + assert!(cursor.is_empty()); + } } diff --git a/chitchat/src/transport/udp.rs b/chitchat/src/transport/udp.rs index 327602d..cc9f795 100644 --- a/chitchat/src/transport/udp.rs +++ b/chitchat/src/transport/udp.rs @@ -4,7 +4,7 @@ use anyhow::Context; use async_trait::async_trait; use tracing::warn; -use crate::serialize::Serializable; +use crate::serialize::{Deserializable, Serializable}; use crate::transport::{Socket, Transport}; use crate::{ChitchatMessage, MAX_UDP_DATAGRAM_PAYLOAD_SIZE};