Skip to content

Commit

Permalink
Major bugfix (#144)
Browse files Browse the repository at this point in the history
Trivial bug critical bug in the deserialization of the chitchat digest.
This PR adds unit test, and adds serialization to the channel transport.
  • Loading branch information
fulmicoton authored Mar 29, 2024
1 parent f783620 commit d039699
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 20 deletions.
77 changes: 58 additions & 19 deletions chitchat/src/digest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,30 @@ pub(crate) struct NodeDigest {
pub(crate) max_version: Version,
}

impl NodeDigest {
pub(crate) fn new(
heartbeat: Heartbeat,
last_gc_version: Version,
max_version: Version,
) -> Self {
Self {
impl Serializable for NodeDigest {
fn serialize(&self, buf: &mut Vec<u8>) {
self.heartbeat.serialize(buf);
self.last_gc_version.serialize(buf);
self.max_version.serialize(buf);
}

fn serialized_len(&self) -> usize {
self.heartbeat.serialized_len()
+ self.last_gc_version.serialized_len()
+ self.max_version.serialized_len()
}
}

impl Deserializable for NodeDigest {
fn deserialize(buf: &mut &[u8]) -> anyhow::Result<Self> {
let heartbeat = Heartbeat::deserialize(buf)?;
let last_gc_version = Version::deserialize(buf)?;
let max_version = Version::deserialize(buf)?;
Ok(NodeDigest {
heartbeat,
last_gc_version,
max_version,
}
})
}
}

Expand All @@ -43,7 +56,11 @@ impl Digest {
last_gc_version: Version,
max_version: Version,
) {
let node_digest = NodeDigest::new(heartbeat, last_gc_version, max_version);
let node_digest = NodeDigest {
heartbeat,
last_gc_version,
max_version,
};
self.node_digests.insert(node, node_digest);
}
}
Expand All @@ -53,18 +70,14 @@ impl Serializable for Digest {
(self.node_digests.len() as u16).serialize(buf);
for (chitchat_id, node_digest) in &self.node_digests {
chitchat_id.serialize(buf);
node_digest.heartbeat.serialize(buf);
node_digest.last_gc_version.serialize(buf);
node_digest.max_version.serialize(buf);
node_digest.serialize(buf);
}
}
fn serialized_len(&self) -> usize {
let mut len = (self.node_digests.len() as u16).serialized_len();
for (chitchat_id, node_digest) in &self.node_digests {
len += chitchat_id.serialized_len();
len += node_digest.heartbeat.serialized_len();
len += node_digest.last_gc_version.serialized_len();
len += node_digest.max_version.serialized_len();
len += node_digest.serialized_len();
}
len
}
Expand All @@ -77,12 +90,38 @@ impl Deserializable for Digest {

for _ in 0..num_nodes {
let chitchat_id = ChitchatId::deserialize(buf)?;
let heartbeat = Heartbeat::deserialize(buf)?;
let max_version = u64::deserialize(buf)?;
let last_gc_version = u64::deserialize(buf)?;
let node_digest = NodeDigest::new(heartbeat, last_gc_version, max_version);
let node_digest = NodeDigest::deserialize(buf)?;
node_digests.insert(chitchat_id, node_digest);
}
Ok(Digest { node_digests })
}
}

#[cfg(test)]
mod tests {
use crate::digest::{Digest, NodeDigest};
use crate::serialize::test_serdeser_aux;
use crate::{ChitchatId, Heartbeat};

#[test]
fn test_node_digest_serialization() {
let node_digest = NodeDigest {
heartbeat: crate::Heartbeat(100u64),
last_gc_version: 2,
max_version: 3,
};
test_serdeser_aux(&node_digest, 24);
}

#[test]
fn test_digest_serialization() {
let mut digest = Digest::default();
let node1 = ChitchatId::for_local_test(10_001);
let node2 = ChitchatId::for_local_test(10_002);
let node3 = ChitchatId::for_local_test(10_002);
digest.add_node(node1, Heartbeat(101), 1, 11);
digest.add_node(node2, Heartbeat(102), 20, 12);
digest.add_node(node3, Heartbeat(103), 0, 13);
test_serdeser_aux(&digest, 104);
}
}
14 changes: 13 additions & 1 deletion chitchat/src/transport/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use async_trait::async_trait;
use tokio::sync::mpsc::{Receiver, Sender};
use tracing::info;

use crate::serialize::Serializable;
use crate::serialize::{Deserializable, Serializable};
use crate::transport::{Socket, Transport};
use crate::ChitchatMessage;

Expand Down Expand Up @@ -56,6 +56,16 @@ impl Transport for ChannelTransport {
}
}

fn serialize_deserialize_chitchat_message(message: ChitchatMessage) -> ChitchatMessage {
let buf = message.serialize_to_vec();
assert_eq!(buf.len(), message.serialized_len());
let mut read_cursor: &[u8] = &buf[..];
let message_ser_deser = ChitchatMessage::deserialize(&mut read_cursor).unwrap();
assert_eq!(message, message_ser_deser);
assert!(read_cursor.is_empty());
message
}

impl ChannelTransport {
pub fn with_mtu(mtu: usize) -> Self {
Self {
Expand Down Expand Up @@ -92,6 +102,8 @@ impl ChannelTransport {
to_addr: SocketAddr,
message: ChitchatMessage,
) -> anyhow::Result<()> {
// We serialize/deserialize message to get closer to the real world.
let message = serialize_deserialize_chitchat_message(message);
let num_bytes = message.serialized_len();
if let Some(mtu) = self.mtu_opt {
if num_bytes > mtu {
Expand Down

0 comments on commit d039699

Please sign in to comment.