Skip to content

Commit

Permalink
removes the vector allocation in shred::merkle::make_merkle_proof (#4481
Browse files Browse the repository at this point in the history
)

The commit returns an iterator from shred::merkle::make_merkle_proof.
The Merkle proof entries are then directly written to shred payload.
  • Loading branch information
behzadnouri authored Jan 17, 2025
1 parent 09ecef4 commit 3904356
Showing 1 changed file with 72 additions and 37 deletions.
109 changes: 72 additions & 37 deletions ledger/src/shred/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,29 @@ impl Shred {
dispatch!(fn payload(&self) -> &Vec<u8>);
dispatch!(fn sanitize(&self) -> Result<(), Error>);
dispatch!(fn set_chained_merkle_root(&mut self, chained_merkle_root: &Hash) -> Result<(), Error>);
dispatch!(fn set_merkle_proof(&mut self, proof: &[&MerkleProofEntry]) -> Result<(), Error>);
dispatch!(fn set_retransmitter_signature(&mut self, signature: &Signature) -> Result<(), Error>);
dispatch!(fn set_signature(&mut self, signature: Signature));
dispatch!(fn signed_data(&self) -> Result<Hash, Error>);

#[inline]
fn merkle_proof(&self) -> Result<impl Iterator<Item = &MerkleProofEntry>, Error> {
match self {
Self::ShredCode(shred) => shred.merkle_proof().map(Either::Left),
Self::ShredData(shred) => shred.merkle_proof().map(Either::Right),
}
}

#[inline]
fn set_merkle_proof<'a, I>(&mut self, proof: I) -> Result<(), Error>
where
I: IntoIterator<Item = Result<&'a MerkleProofEntry, Error>>,
{
match self {
Self::ShredCode(shred) => shred.set_merkle_proof(proof),
Self::ShredData(shred) => shred.set_merkle_proof(proof),
}
}

#[must_use]
fn verify(&self, pubkey: &Pubkey) -> bool {
match self.signed_data() {
Expand Down Expand Up @@ -341,21 +352,29 @@ macro_rules! impl_merkle_shred {
get_merkle_node(&self.payload, SIZE_OF_SIGNATURE..proof_offset)
}

fn set_merkle_proof(&mut self, proof: &[&MerkleProofEntry]) -> Result<(), Error> {
fn set_merkle_proof<'a, I>(&mut self, proof: I) -> Result<(), Error>
where
I: IntoIterator<Item = Result<&'a MerkleProofEntry, Error>>,
{
let proof_size = self.proof_size()?;
if proof.len() != usize::from(proof_size) {
return Err(Error::InvalidMerkleProof);
}
let proof_offset = self.proof_offset()?;
let mut cursor = Cursor::new(
self.payload
.get_mut(proof_offset..)
.ok_or(Error::InvalidProofSize(proof_size))?,
);
proof
.iter()
.try_for_each(|entry| cursor.write_all(&entry[..]))
.map_err(Error::from)
let proof_size = usize::from(proof_size);
proof.into_iter().enumerate().try_for_each(|(k, entry)| {
if k >= proof_size {
return Err(Error::InvalidMerkleProof);
}
Ok(cursor.write_all(&entry?[..])?)
})?;
// Verify that exactly proof_size many entries are written.
if cursor.position() as usize != proof_size * SIZE_OF_MERKLE_PROOF_ENTRY {
return Err(Error::InvalidMerkleProof);
}
Ok(())
}

pub(super) fn retransmitter_signature(&self) -> Result<Signature, Error> {
Expand Down Expand Up @@ -687,21 +706,29 @@ fn make_merkle_proof(
mut index: usize, // leaf index ~ shred's erasure shard index.
mut size: usize, // number of leaves ~ erasure batch size.
tree: &[Hash],
) -> Option<Vec<&MerkleProofEntry>> {
if index >= size {
return None;
}
) -> impl Iterator<Item = Result<&MerkleProofEntry, Error>> {
let mut offset = 0;
let mut proof = Vec::<&MerkleProofEntry>::new();
while size > 1 {
let node = tree.get(offset + (index ^ 1).min(size - 1))?;
let entry = &node.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
proof.push(<&MerkleProofEntry>::try_from(entry).unwrap());
offset += size;
size = (size + 1) >> 1;
index >>= 1;
if index >= size {
// Force below iterator to return Error.
(size, offset) = (0, tree.len());
}
(offset + 1 == tree.len()).then_some(proof)
std::iter::from_fn(move || {
if size > 1 {
let Some(node) = tree.get(offset + (index ^ 1).min(size - 1)) else {
return Some(Err(Error::InvalidMerkleProof));
};
offset += size;
size = (size + 1) >> 1;
index >>= 1;
let entry = &node.as_ref()[..SIZE_OF_MERKLE_PROOF_ENTRY];
let entry = <&MerkleProofEntry>::try_from(entry).unwrap();
Some(Ok(entry))
} else if offset + 1 == tree.len() {
None
} else {
Some(Err(Error::InvalidMerkleProof))
}
})
}

pub(super) fn recover(
Expand Down Expand Up @@ -878,16 +905,13 @@ pub(super) fn recover(
return Err(Error::InvalidMerkleRoot);
}
for (index, (shred, mask)) in shreds.iter_mut().zip(&mask).enumerate() {
let proof = make_merkle_proof(index, num_shards, &tree).ok_or(Error::InvalidMerkleProof)?;
if proof.len() != usize::from(proof_size) {
return Err(Error::InvalidMerkleProof);
}
let proof = make_merkle_proof(index, num_shards, &tree);
if *mask {
if shred.merkle_proof()?.ne(proof) {
if shred.merkle_proof()?.map(Some).ne(proof.map(Result::ok)) {
return Err(Error::InvalidMerkleProof);
}
} else {
shred.set_merkle_proof(&proof)?;
shred.set_merkle_proof(proof)?;
// Already sanitized after reconstruct.
debug_assert_matches!(shred.sanitize(), Ok(()));
// Assert that shred payload is fully populated.
Expand Down Expand Up @@ -1335,10 +1359,8 @@ fn make_erasure_batch(
let signature = keypair.sign_message(root.as_ref());
// Populate merkle proof for all shreds and attach signature.
for (index, shred) in shreds.iter_mut().enumerate() {
let proof =
make_merkle_proof(index, erasure_batch_size, &tree).ok_or(Error::InvalidMerkleProof)?;
debug_assert_eq!(proof.len(), usize::from(proof_size));
shred.set_merkle_proof(&proof)?;
let proof = make_merkle_proof(index, erasure_batch_size, &tree);
shred.set_merkle_proof(proof)?;
shred.set_signature(signature);
debug_assert!(shred.verify(&keypair.pubkey()));
debug_assert_matches!(shred.sanitize(), Ok(()));
Expand Down Expand Up @@ -1455,15 +1477,29 @@ mod test {
}
}

#[test]
fn test_make_merkle_proof_error() {
let mut rng = rand::thread_rng();
let nodes = repeat_with(|| rng.gen::<[u8; 32]>()).map(Hash::from);
let nodes: Vec<_> = nodes.take(5).collect();
let size = nodes.len();
let tree = make_merkle_tree(nodes.into_iter().map(Ok)).unwrap();
for index in size..size + 3 {
assert_matches!(
make_merkle_proof(index, size, &tree).next(),
Some(Err(Error::InvalidMerkleProof))
);
}
}

fn run_merkle_tree_round_trip<R: Rng>(rng: &mut R, size: usize) {
let nodes = repeat_with(|| rng.gen::<[u8; 32]>()).map(Hash::from);
let nodes: Vec<_> = nodes.take(size).collect();
let tree = make_merkle_tree(nodes.iter().cloned().map(Ok)).unwrap();
let root = tree.last().copied().unwrap();
for index in 0..size {
let proof = make_merkle_proof(index, size, &tree).unwrap();
for (k, &node) in nodes.iter().enumerate() {
let proof = proof.iter().copied();
let proof = make_merkle_proof(index, size, &tree).map(Result::unwrap);
if k == index {
assert_eq!(root, get_merkle_root(k, node, proof).unwrap());
} else {
Expand Down Expand Up @@ -1615,9 +1651,8 @@ mod test {
let nodes = shreds.iter().map(Shred::merkle_node);
let tree = make_merkle_tree(nodes).unwrap();
for (index, shred) in shreds.iter_mut().enumerate() {
let proof = make_merkle_proof(index, num_shreds, &tree).unwrap();
assert_eq!(proof.len(), usize::from(proof_size));
shred.set_merkle_proof(&proof).unwrap();
let proof = make_merkle_proof(index, num_shreds, &tree);
shred.set_merkle_proof(proof).unwrap();
let data = shred.signed_data().unwrap();
let signature = keypair.sign_message(data.as_ref());
shred.set_signature(signature);
Expand Down

0 comments on commit 3904356

Please sign in to comment.