Skip to content

Commit

Permalink
adds more hard-coded tests for WeightedShuffle (#4517)
Browse files Browse the repository at this point in the history
Adding more hard-coded tests to ensure that changes to the code or the
dependencies (e.g. rand or rand_chacha::ChaChaRng) won't change the
deterministic shuffle.
  • Loading branch information
behzadnouri authored Jan 17, 2025
1 parent 9ea581a commit 8ea8a37
Showing 1 changed file with 108 additions and 1 deletion.
109 changes: 108 additions & 1 deletion gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,46 @@ fn get_num_nodes_and_tree_size(count: usize) -> (/*num_nodes:*/ usize, /*tree_si
mod tests {
use {
super::*,
itertools::Itertools,
rand::SeedableRng,
rand_chacha::ChaChaRng,
std::{convert::TryInto, iter::repeat_with},
solana_sdk::hash::Hash,
std::{
convert::TryInto,
iter::{repeat_with, successors, Sum},
str::FromStr,
},
test_case::test_case,
};

fn verify_shuffle<T>(shuffle: &[usize], weights: &[T], mut mask: Vec<bool>)
where
T: ConstZero + Copy + PartialEq + PartialOrd + Sum<T>,
{
assert_eq!(weights.len(), mask.len());
let num_dropped = mask.iter().copied().map(usize::from).sum::<usize>();
// Assert that only the indices which were not dropped appear in
// the shuffle.
assert_eq!(shuffle.len(), weights.len() - num_dropped);
assert!(shuffle.iter().all(|&index| {
let out = !mask[index];
mask[index] = true;
out
}));
assert!(mask.iter().all(|&x| x));
// Assert that the random shuffle is weighted.
assert!(shuffle
.chunks(shuffle.len() / 10)
.map(|chunk| chunk.iter().map(|&i| weights[i]).sum::<T>())
.tuple_windows()
.all(|(a, b)| a > b));
// Assert that zero weights only appear at the end of the shuffle.
assert!(shuffle
.iter()
.tuple_windows()
.all(|(&i, &j)| weights[i] != T::ZERO || weights[j] == T::ZERO));
}

fn weighted_shuffle_slow<R>(rng: &mut R, mut weights: Vec<u64>) -> Vec<usize>
where
R: Rng,
Expand Down Expand Up @@ -411,6 +446,78 @@ mod tests {
assert_eq!(shuffle.first(&mut rng), Some(17));
}

// Verifies that changes to the code or dependencies (e.g. rand or
// rand_chacha::ChaChaRng) do not change the deterministic shuffle.
#[test_case(0x587c27258191c66d, "84jN8bvnp6mvtngzt42SW8AtRf5fcv3VBerKkUsYrCVG")]
#[test_case(0x7dad2afc68808779, "25oFhs9sR3WYfB6ohy752JrbLqpBjw6X4Eszbcsoxon4")]
#[test_case(0xfdd71c99c936736c, "7H9H8V7ccmpBhC3i5vEeFfiUwvRSAvRWadZhFH5ecSD7")]
#[test_case(0xe2a4d9fdd186636c, "Nxe6X7f74kEPrJFycKFcxByDRWKJtx1J3vsdbum9VPv")]
#[test_case(0x19a0a360e9f3094d, "Ec6wiaqDuVc5AzZpq4GAZ6GLsRJvw9mAVWVrCpDoGaRm")]
#[test_case(0xc5e0204894ca50dc, "BqxDzSFw8rJRHnTZmsPRzF77G3xgfK4hD8JyYeAFfxuZ")]
#[test_case(0xf1336cf933eeda07, "3Ux2vciDFdgNqULpsQpXfpaxZykWmBFCseqX9dwpGnyH")]
#[test_case(0xe666e7514f37c7a1, "Fc3gAUgh2mD1se3kkhPnLMKpQCiARd2PSdGf7b2fDS2n")]
fn test_weighted_shuffle_hard_coded_paranoid(seed: u64, expected_hash: &str) {
let expected_hash = Hash::from_str(expected_hash).unwrap();
let mut rng = <[u8; 32]>::try_from(
successors(Some(seed), |seed| Some(seed + 1))
.map(u64::to_le_bytes)
.take(32 / 8)
.flatten()
.collect::<Vec<u8>>(),
)
.map(ChaChaRng::from_seed)
.unwrap();
let num_weights = rng.gen_range(1..=100_000);
assert!((8143..=85348).contains(&num_weights), "{num_weights}");
let weights: Vec<u64> = repeat_with(|| {
if rng.gen_ratio(1, 100) {
0u64 // 1% zero weights.
} else {
rng.gen_range(0..=(u64::MAX / num_weights as u64))
}
})
.take(num_weights)
.collect();
let num_zeros = weights.iter().filter(|&&w| w == 0).count();
assert!((72..=846).contains(&num_zeros), "{num_zeros}");
// Assert that the sum of weights does not overflow.
assert_eq!(
weights.iter().fold(0u64, |a, &b| a.checked_add(b).unwrap()),
weights.iter().sum::<u64>()
);
let mut shuffle = WeightedShuffle::new("", &weights);
let shuffle1 = shuffle.clone().shuffle(&mut rng).collect::<Vec<_>>();
// Assert that all indices appear in the shuffle.
assert_eq!(shuffle1.len(), num_weights);
verify_shuffle(&shuffle1, &weights, vec![false; num_weights]);
// Drop some of the weights and re-shuffle.
let num_drops = rng.gen_range(1..1_000);
assert!((253..=981).contains(&num_drops), "{num_drops}");
let mut mask = vec![false; num_weights];
repeat_with(|| rng.gen_range(0..num_weights))
.filter(|&index| {
if mask[index] {
false
} else {
mask[index] = true;
true
}
})
.take(num_drops)
.for_each(|index| shuffle.remove_index(index));
let shuffle2 = shuffle.shuffle(&mut rng).collect::<Vec<_>>();
assert_eq!(shuffle2.len(), num_weights - num_drops);
verify_shuffle(&shuffle2, &weights, mask);
// Assert that code or dependencies updates do not change the shuffle.
let bytes = shuffle1
.into_iter()
.chain(shuffle2)
.map(usize::to_le_bytes)
.collect::<Vec<_>>();
let bytes = bytes.iter().map(AsRef::as_ref).collect::<Vec<_>>();
assert_eq!(solana_sdk::hash::hashv(&bytes[..]), expected_hash);
}

#[test]
fn test_weighted_shuffle_match_slow() {
let mut rng = rand::thread_rng();
Expand Down

0 comments on commit 8ea8a37

Please sign in to comment.