diff --git a/Cargo.lock b/Cargo.lock index cc42e6da02df20..540b5632db8627 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5828,6 +5828,7 @@ dependencies = [ "solana-streamer", "solana-svm", "solana-tpu-client", + "solana-transaction-metrics-tracker", "solana-transaction-status", "solana-turbine", "solana-unified-scheduler-pool", @@ -7190,9 +7191,11 @@ dependencies = [ "rand 0.8.5", "rustls", "solana-logger", + "solana-measure", "solana-metrics", "solana-perf", "solana-sdk", + "solana-transaction-metrics-tracker", "thiserror", "tokio", "x509-parser", @@ -7359,6 +7362,20 @@ dependencies = [ "solana-version", ] +[[package]] +name = "solana-transaction-metrics-tracker" +version = "1.19.0" +dependencies = [ + "Inflector", + "base64 0.21.7", + "bincode", + "lazy_static", + "log", + "rand 0.8.5", + "solana-perf", + "solana-sdk", +] + [[package]] name = "solana-transaction-status" version = "1.19.0" diff --git a/Cargo.toml b/Cargo.toml index 8cc38b69144d3d..8d467818a23b0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,6 +106,7 @@ members = [ "tokens", "tpu-client", "transaction-dos", + "transaction-metrics-tracker", "transaction-status", "turbine", "udp-client", @@ -378,6 +379,7 @@ solana-system-program = { path = "programs/system", version = "=1.19.0" } solana-test-validator = { path = "test-validator", version = "=1.19.0" } solana-thin-client = { path = "thin-client", version = "=1.19.0" } solana-tpu-client = { path = "tpu-client", version = "=1.19.0", default-features = false } +solana-transaction-metrics-tracker = { path = "transaction-metrics-tracker", version = "=1.19.0" } solana-transaction-status = { path = "transaction-status", version = "=1.19.0" } solana-turbine = { path = "turbine", version = "=1.19.0" } solana-udp-client = { path = "udp-client", version = "=1.19.0" } diff --git a/core/Cargo.toml b/core/Cargo.toml index e2a936cdabc4c1..1fd25ec38a8d3b 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -67,6 +67,7 @@ solana-send-transaction-service = { workspace = true } solana-streamer = { workspace = true } solana-svm = { workspace = true } solana-tpu-client = { workspace = true } +solana-transaction-metrics-tracker = { workspace = true } solana-transaction-status = { workspace = true } solana-turbine = { workspace = true } solana-unified-scheduler-pool = { workspace = true } diff --git a/core/src/banking_stage/consumer.rs b/core/src/banking_stage/consumer.rs index f4ac6c6040eda8..9d7cdca17f35dd 100644 --- a/core/src/banking_stage/consumer.rs +++ b/core/src/banking_stage/consumer.rs @@ -208,6 +208,32 @@ impl Consumer { .slot_metrics_tracker .increment_retryable_packets_count(retryable_transaction_indexes.len() as u64); + // Now we track the performance for the interested transactions which is not in the retryable_transaction_indexes + // We assume the retryable_transaction_indexes is already sorted. + let mut retryable_idx = 0; + for (index, packet) in packets_to_process.iter().enumerate() { + if packet.original_packet().meta().is_perf_track_packet() { + if let Some(start_time) = packet.start_time() { + if retryable_idx >= retryable_transaction_indexes.len() + || retryable_transaction_indexes[retryable_idx] != index + { + let duration = Instant::now().duration_since(*start_time); + + debug!( + "Banking stage processing took {duration:?} for transaction {:?}", + packet.transaction().get_signatures().first() + ); + payload + .slot_metrics_tracker + .increment_process_sampled_packets_us(duration.as_micros() as u64); + } else { + // This packet is retried, advance the retry index to the next, as the next packet's index will + // certainly be > than this. + retryable_idx += 1; + } + } + } + } Some(retryable_transaction_indexes) } diff --git a/core/src/banking_stage/immutable_deserialized_packet.rs b/core/src/banking_stage/immutable_deserialized_packet.rs index 26ede7045d3480..6eb5d68ecaaca5 100644 --- a/core/src/banking_stage/immutable_deserialized_packet.rs +++ b/core/src/banking_stage/immutable_deserialized_packet.rs @@ -13,7 +13,7 @@ use { VersionedTransaction, }, }, - std::{cmp::Ordering, mem::size_of, sync::Arc}, + std::{cmp::Ordering, mem::size_of, sync::Arc, time::Instant}, thiserror::Error, }; @@ -41,10 +41,16 @@ pub struct ImmutableDeserializedPacket { message_hash: Hash, is_simple_vote: bool, compute_budget_details: ComputeBudgetDetails, + banking_stage_start_time: Option, } impl ImmutableDeserializedPacket { pub fn new(packet: Packet) -> Result { + let banking_stage_start_time = packet + .meta() + .is_perf_track_packet() + .then_some(Instant::now()); + let versioned_transaction: VersionedTransaction = packet.deserialize_slice(..)?; let sanitized_transaction = SanitizedVersionedTransaction::try_from(versioned_transaction)?; let message_bytes = packet_message(&packet)?; @@ -67,6 +73,7 @@ impl ImmutableDeserializedPacket { message_hash, is_simple_vote, compute_budget_details, + banking_stage_start_time, }) } @@ -98,6 +105,10 @@ impl ImmutableDeserializedPacket { self.compute_budget_details.clone() } + pub fn start_time(&self) -> &Option { + &self.banking_stage_start_time + } + // This function deserializes packets into transactions, computes the blake3 hash of transaction // messages, and verifies secp256k1 instructions. pub fn build_sanitized_transaction( diff --git a/core/src/banking_stage/leader_slot_metrics.rs b/core/src/banking_stage/leader_slot_metrics.rs index 88ea6b5ee340cf..1c255ca019bfe7 100644 --- a/core/src/banking_stage/leader_slot_metrics.rs +++ b/core/src/banking_stage/leader_slot_metrics.rs @@ -936,6 +936,17 @@ impl LeaderSlotMetricsTracker { ); } } + + pub(crate) fn increment_process_sampled_packets_us(&mut self, us: u64) { + if let Some(leader_slot_metrics) = &mut self.leader_slot_metrics { + leader_slot_metrics + .timing_metrics + .process_packets_timings + .process_sampled_packets_us_hist + .increment(us) + .unwrap(); + } + } } #[cfg(test)] diff --git a/core/src/banking_stage/leader_slot_timing_metrics.rs b/core/src/banking_stage/leader_slot_timing_metrics.rs index 7727b6cf6c6563..34ce64b31c34f3 100644 --- a/core/src/banking_stage/leader_slot_timing_metrics.rs +++ b/core/src/banking_stage/leader_slot_timing_metrics.rs @@ -244,6 +244,9 @@ pub(crate) struct ProcessPacketsTimings { // Time spent running the cost model in processing transactions before executing // transactions pub cost_model_us: u64, + + // banking stage processing time histogram for sampled packets + pub process_sampled_packets_us_hist: histogram::Histogram, } impl ProcessPacketsTimings { @@ -264,6 +267,28 @@ impl ProcessPacketsTimings { i64 ), ("cost_model_us", self.cost_model_us, i64), + ( + "process_sampled_packets_us_90pct", + self.process_sampled_packets_us_hist + .percentile(90.0) + .unwrap_or(0), + i64 + ), + ( + "process_sampled_packets_us_min", + self.process_sampled_packets_us_hist.minimum().unwrap_or(0), + i64 + ), + ( + "process_sampled_packets_us_max", + self.process_sampled_packets_us_hist.maximum().unwrap_or(0), + i64 + ), + ( + "process_sampled_packets_us_mean", + self.process_sampled_packets_us_hist.mean().unwrap_or(0), + i64 + ), ); } } diff --git a/core/src/banking_stage/unprocessed_transaction_storage.rs b/core/src/banking_stage/unprocessed_transaction_storage.rs index fcc68050b72d4c..52706f8c2bf63b 100644 --- a/core/src/banking_stage/unprocessed_transaction_storage.rs +++ b/core/src/banking_stage/unprocessed_transaction_storage.rs @@ -924,6 +924,7 @@ impl ThreadLocalUnprocessedPackets { .iter() .map(|p| (*p).clone()) .collect_vec(); + let retryable_packets = if let Some(retryable_transaction_indexes) = processing_function(&packets_to_process, payload) { diff --git a/core/src/sigverify_stage.rs b/core/src/sigverify_stage.rs index e5e06a3bc701c9..f41d2b1d192f16 100644 --- a/core/src/sigverify_stage.rs +++ b/core/src/sigverify_stage.rs @@ -18,8 +18,9 @@ use { count_discarded_packets, count_packets_in_batches, count_valid_packets, shrink_batches, }, }, - solana_sdk::timing, + solana_sdk::{signature::Signature, timing}, solana_streamer::streamer::{self, StreamerError}, + solana_transaction_metrics_tracker::get_signature_from_packet, std::{ thread::{self, Builder, JoinHandle}, time::Instant, @@ -78,8 +79,9 @@ struct SigVerifierStats { verify_batches_pp_us_hist: histogram::Histogram, // per-packet time to call verify_batch discard_packets_pp_us_hist: histogram::Histogram, // per-packet time to call verify_batch dedup_packets_pp_us_hist: histogram::Histogram, // per-packet time to call verify_batch - batches_hist: histogram::Histogram, // number of packet batches per verify call - packets_hist: histogram::Histogram, // number of packets per verify call + process_sampled_packets_us_hist: histogram::Histogram, // per-packet time do do overall verify for sampled packets + batches_hist: histogram::Histogram, // number of packet batches per verify call + packets_hist: histogram::Histogram, // number of packets per verify call num_deduper_saturations: usize, total_batches: usize, total_packets: usize, @@ -93,6 +95,7 @@ struct SigVerifierStats { total_discard_random_time_us: usize, total_verify_time_us: usize, total_shrink_time_us: usize, + perf_track_overhead_us: usize, } impl SigVerifierStats { @@ -181,6 +184,28 @@ impl SigVerifierStats { self.dedup_packets_pp_us_hist.mean().unwrap_or(0), i64 ), + ( + "process_sampled_packets_us_90pct", + self.process_sampled_packets_us_hist + .percentile(90.0) + .unwrap_or(0), + i64 + ), + ( + "process_sampled_packets_us_min", + self.process_sampled_packets_us_hist.minimum().unwrap_or(0), + i64 + ), + ( + "process_sampled_packets_us_max", + self.process_sampled_packets_us_hist.maximum().unwrap_or(0), + i64 + ), + ( + "process_sampled_packets_us_mean", + self.process_sampled_packets_us_hist.mean().unwrap_or(0), + i64 + ), ( "batches_90pct", self.batches_hist.percentile(90.0).unwrap_or(0), @@ -214,6 +239,7 @@ impl SigVerifierStats { ), ("total_verify_time_us", self.total_verify_time_us, i64), ("total_shrink_time_us", self.total_shrink_time_us, i64), + ("perf_track_overhead_us", self.perf_track_overhead_us, i64), ); } } @@ -296,8 +322,26 @@ impl SigVerifyStage { verifier: &mut T, stats: &mut SigVerifierStats, ) -> Result<(), T::SendType> { + let mut packet_perf_measure: Vec<([u8; 64], std::time::Instant)> = Vec::default(); + let (mut batches, num_packets, recv_duration) = streamer::recv_packet_batches(recvr)?; + let mut start_perf_track_measure = Measure::start("start_perf_track"); + // track sigverify start time for interested packets + for batch in &batches { + for packet in batch.iter() { + if packet.meta().is_perf_track_packet() { + let signature = get_signature_from_packet(packet); + if let Ok(signature) = signature { + packet_perf_measure.push((*signature, Instant::now())); + } + } + } + } + start_perf_track_measure.stop(); + + stats.perf_track_overhead_us = start_perf_track_measure.as_us() as usize; + let batches_len = batches.len(); debug!( "@{:?} verifier: verifying: {}", @@ -370,6 +414,22 @@ impl SigVerifyStage { (num_packets as f32 / verify_time.as_s()) ); + let mut perf_track_end_measure = Measure::start("perf_track_end"); + for (signature, start_time) in packet_perf_measure.iter() { + let duration = Instant::now().duration_since(*start_time); + debug!( + "Sigverify took {duration:?} for transaction {:?}", + Signature::from(*signature) + ); + stats + .process_sampled_packets_us_hist + .increment(duration.as_micros() as u64) + .unwrap(); + } + + perf_track_end_measure.stop(); + stats.perf_track_overhead_us += perf_track_end_measure.as_us() as usize; + stats .recv_batches_us_hist .increment(recv_duration.as_micros() as u64) diff --git a/programs/sbf/Cargo.lock b/programs/sbf/Cargo.lock index 93e2a243e2004d..f466c7e879476c 100644 --- a/programs/sbf/Cargo.lock +++ b/programs/sbf/Cargo.lock @@ -4908,6 +4908,7 @@ dependencies = [ "solana-streamer", "solana-svm", "solana-tpu-client", + "solana-transaction-metrics-tracker", "solana-transaction-status", "solana-turbine", "solana-unified-scheduler-pool", @@ -6264,9 +6265,11 @@ dependencies = [ "quinn-proto", "rand 0.8.5", "rustls", + "solana-measure", "solana-metrics", "solana-perf", "solana-sdk", + "solana-transaction-metrics-tracker", "thiserror", "tokio", "x509-parser", @@ -6369,6 +6372,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "solana-transaction-metrics-tracker" +version = "1.19.0" +dependencies = [ + "Inflector", + "base64 0.21.7", + "bincode", + "lazy_static", + "log", + "rand 0.8.5", + "solana-perf", + "solana-sdk", +] + [[package]] name = "solana-transaction-status" version = "1.19.0" diff --git a/sdk/src/packet.rs b/sdk/src/packet.rs index faea9ab4753c67..8300b57218c696 100644 --- a/sdk/src/packet.rs +++ b/sdk/src/packet.rs @@ -33,6 +33,8 @@ bitflags! { /// the packet is built. /// This field can be removed when the above feature gate is adopted by mainnet-beta. const ROUND_COMPUTE_UNIT_PRICE = 0b0010_0000; + /// For tracking performance + const PERF_TRACK_PACKET = 0b0100_0000; } } @@ -228,6 +230,12 @@ impl Meta { self.flags.set(PacketFlags::TRACER_PACKET, is_tracer); } + #[inline] + pub fn set_track_performance(&mut self, is_performance_track: bool) { + self.flags + .set(PacketFlags::PERF_TRACK_PACKET, is_performance_track); + } + #[inline] pub fn set_simple_vote(&mut self, is_simple_vote: bool) { self.flags.set(PacketFlags::SIMPLE_VOTE_TX, is_simple_vote); @@ -261,6 +269,11 @@ impl Meta { self.flags.contains(PacketFlags::TRACER_PACKET) } + #[inline] + pub fn is_perf_track_packet(&self) -> bool { + self.flags.contains(PacketFlags::PERF_TRACK_PACKET) + } + #[inline] pub fn round_compute_unit_price(&self) -> bool { self.flags.contains(PacketFlags::ROUND_COMPUTE_UNIT_PRICE) diff --git a/sdk/src/transaction/versioned/sanitized.rs b/sdk/src/transaction/versioned/sanitized.rs index 61ecdfea56bb2a..b6311d5886b0e3 100644 --- a/sdk/src/transaction/versioned/sanitized.rs +++ b/sdk/src/transaction/versioned/sanitized.rs @@ -33,6 +33,10 @@ impl SanitizedVersionedTransaction { &self.message } + pub fn get_signatures(&self) -> &Vec { + &self.signatures + } + /// Consumes the SanitizedVersionedTransaction, returning the fields individually. pub fn destruct(self) -> (Vec, SanitizedVersionedMessage) { (self.signatures, self.message) diff --git a/streamer/Cargo.toml b/streamer/Cargo.toml index 8e1eb12dff1d42..55d0030e734607 100644 --- a/streamer/Cargo.toml +++ b/streamer/Cargo.toml @@ -26,9 +26,11 @@ quinn = { workspace = true } quinn-proto = { workspace = true } rand = { workspace = true } rustls = { workspace = true, features = ["dangerous_configuration"] } +solana-measure = { workspace = true } solana-metrics = { workspace = true } solana-perf = { workspace = true } solana-sdk = { workspace = true } +solana-transaction-metrics-tracker = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["full"] } x509-parser = { workspace = true } diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs index 225412dd08b315..3485e4fe585d06 100644 --- a/streamer/src/nonblocking/quic.rs +++ b/streamer/src/nonblocking/quic.rs @@ -17,6 +17,7 @@ use { quinn::{Connecting, Connection, Endpoint, EndpointConfig, TokioRuntime, VarInt}, quinn_proto::VarIntBoundsExceeded, rand::{thread_rng, Rng}, + solana_measure::measure::Measure, solana_perf::packet::{PacketBatch, PACKETS_PER_BATCH}, solana_sdk::{ packet::{Meta, PACKET_DATA_SIZE}, @@ -27,9 +28,10 @@ use { QUIC_MIN_STAKED_CONCURRENT_STREAMS, QUIC_MIN_STAKED_RECEIVE_WINDOW_RATIO, QUIC_TOTAL_STAKED_CONCURRENT_STREAMS, QUIC_UNSTAKED_RECEIVE_WINDOW_RATIO, }, - signature::Keypair, + signature::{Keypair, Signature}, timing, }, + solana_transaction_metrics_tracker::signature_if_should_track_packet, std::{ iter::repeat_with, net::{IpAddr, SocketAddr, UdpSocket}, @@ -81,6 +83,7 @@ struct PacketChunk { struct PacketAccumulator { pub meta: Meta, pub chunks: Vec, + pub start_time: Instant, } #[derive(Copy, Clone, Debug)] @@ -628,6 +631,7 @@ async fn packet_batch_sender( trace!("enter packet_batch_sender"); let mut batch_start_time = Instant::now(); loop { + let mut packet_perf_measure: Vec<([u8; 64], std::time::Instant)> = Vec::default(); let mut packet_batch = PacketBatch::with_capacity(PACKETS_PER_BATCH); let mut total_bytes: usize = 0; @@ -647,6 +651,8 @@ async fn packet_batch_sender( || (!packet_batch.is_empty() && elapsed >= coalesce) { let len = packet_batch.len(); + track_streamer_fetch_packet_performance(&mut packet_perf_measure, &stats); + if let Err(e) = packet_sender.send(packet_batch) { stats .total_packet_batch_send_err @@ -692,6 +698,14 @@ async fn packet_batch_sender( total_bytes += packet_batch[i].meta().size; + if let Some(signature) = signature_if_should_track_packet(&packet_batch[i]) + .ok() + .flatten() + { + packet_perf_measure.push((*signature, packet_accumulator.start_time)); + // we set the PERF_TRACK_PACKET on + packet_batch[i].meta_mut().set_track_performance(true); + } stats .total_chunks_processed_by_batcher .fetch_add(num_chunks, Ordering::Relaxed); @@ -700,6 +714,32 @@ async fn packet_batch_sender( } } +fn track_streamer_fetch_packet_performance( + packet_perf_measure: &mut [([u8; 64], Instant)], + stats: &Arc, +) { + if packet_perf_measure.is_empty() { + return; + } + let mut measure = Measure::start("track_perf"); + let mut process_sampled_packets_us_hist = stats.process_sampled_packets_us_hist.lock().unwrap(); + + for (signature, start_time) in packet_perf_measure.iter() { + let duration = Instant::now().duration_since(*start_time); + debug!( + "QUIC streamer fetch stage took {duration:?} for transaction {:?}", + Signature::from(*signature) + ); + process_sampled_packets_us_hist + .increment(duration.as_micros() as u64) + .unwrap(); + } + measure.stop(); + stats + .perf_track_overhead_us + .fetch_add(measure.as_us(), Ordering::Relaxed); +} + async fn handle_connection( connection: Connection, remote_addr: SocketAddr, @@ -854,6 +894,7 @@ async fn handle_chunk( *packet_accum = Some(PacketAccumulator { meta, chunks: Vec::new(), + start_time: Instant::now(), }); } @@ -1453,6 +1494,7 @@ pub mod test { offset, end_of_chunk: size, }], + start_time: Instant::now(), }; ptk_sender.send(packet_accum).await.unwrap(); } diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs index 69a75532b8ca68..3c9d95b2333c42 100644 --- a/streamer/src/quic.rs +++ b/streamer/src/quic.rs @@ -16,8 +16,8 @@ use { std::{ net::UdpSocket, sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, - Arc, RwLock, + atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering}, + Arc, Mutex, RwLock, }, thread, time::{Duration, SystemTime}, @@ -175,10 +175,13 @@ pub struct StreamStats { pub(crate) stream_load_ema: AtomicUsize, pub(crate) stream_load_ema_overflow: AtomicUsize, pub(crate) stream_load_capacity_overflow: AtomicUsize, + pub(crate) process_sampled_packets_us_hist: Mutex, + pub(crate) perf_track_overhead_us: AtomicU64, } impl StreamStats { pub fn report(&self, name: &'static str) { + let process_sampled_packets_us_hist = self.process_sampled_packets_us_hist.lock().unwrap(); datapoint_info!( name, ( @@ -425,6 +428,33 @@ impl StreamStats { self.stream_load_capacity_overflow.load(Ordering::Relaxed), i64 ), + ( + "process_sampled_packets_us_90pct", + process_sampled_packets_us_hist + .percentile(90.0) + .unwrap_or(0), + i64 + ), + ( + "process_sampled_packets_us_min", + process_sampled_packets_us_hist.minimum().unwrap_or(0), + i64 + ), + ( + "process_sampled_packets_us_max", + process_sampled_packets_us_hist.maximum().unwrap_or(0), + i64 + ), + ( + "process_sampled_packets_us_mean", + process_sampled_packets_us_hist.mean().unwrap_or(0), + i64 + ), + ( + "perf_track_overhead_us", + self.perf_track_overhead_us.swap(0, Ordering::Relaxed), + i64 + ), ); } } diff --git a/transaction-metrics-tracker/Cargo.toml b/transaction-metrics-tracker/Cargo.toml new file mode 100644 index 00000000000000..9bd82702a3ebb4 --- /dev/null +++ b/transaction-metrics-tracker/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "solana-transaction-metrics-tracker" +description = "Solana transaction metrics tracker" +documentation = "https://docs.rs/solana-transaction-metrics-tracker" +version = { workspace = true } +authors = { workspace = true } +repository = { workspace = true } +homepage = { workspace = true } +license = { workspace = true } +edition = { workspace = true } +publish = false + +[dependencies] +Inflector = { workspace = true } +base64 = { workspace = true } +bincode = { workspace = true } +# Update this borsh dependency to the workspace version once +lazy_static = { workspace = true } +log = { workspace = true } +rand = { workspace = true } +solana-perf = { workspace = true } +solana-sdk = { workspace = true } + +[package.metadata.docs.rs] +targets = ["x86_64-unknown-linux-gnu"] diff --git a/transaction-metrics-tracker/src/lib.rs b/transaction-metrics-tracker/src/lib.rs new file mode 100644 index 00000000000000..2baec195de9b84 --- /dev/null +++ b/transaction-metrics-tracker/src/lib.rs @@ -0,0 +1,157 @@ +use { + lazy_static::lazy_static, + log::*, + rand::Rng, + solana_perf::sigverify::PacketError, + solana_sdk::{packet::Packet, short_vec::decode_shortu16_len, signature::SIGNATURE_BYTES}, +}; + +// The mask is 12 bits long (1<<12 = 4096), it means the probability of matching +// the transaction is 1/4096 assuming the portion being matched is random. +lazy_static! { + static ref TXN_MASK: u16 = rand::thread_rng().gen_range(0..4096); +} + +/// Check if a transaction given its signature matches the randomly selected mask. +/// The signaure should be from the reference of Signature +pub fn should_track_transaction(signature: &[u8; SIGNATURE_BYTES]) -> bool { + // We do not use the highest signature byte as it is not really random + let match_portion: u16 = u16::from_le_bytes([signature[61], signature[62]]) >> 4; + trace!("Matching txn: {match_portion:016b} {:016b}", *TXN_MASK); + *TXN_MASK == match_portion +} + +/// Check if a transaction packet's signature matches the mask. +/// This does a rudimentry verification to make sure the packet at least +/// contains the signature data and it returns the reference to the signature. +pub fn signature_if_should_track_packet( + packet: &Packet, +) -> Result, PacketError> { + let signature = get_signature_from_packet(packet)?; + Ok(should_track_transaction(signature).then_some(signature)) +} + +/// Get the signature of the transaction packet +/// This does a rudimentry verification to make sure the packet at least +/// contains the signature data and it returns the reference to the signature. +pub fn get_signature_from_packet(packet: &Packet) -> Result<&[u8; SIGNATURE_BYTES], PacketError> { + let (sig_len_untrusted, sig_start) = packet + .data(..) + .and_then(|bytes| decode_shortu16_len(bytes).ok()) + .ok_or(PacketError::InvalidShortVec)?; + + if sig_len_untrusted < 1 { + return Err(PacketError::InvalidSignatureLen); + } + + let signature = packet + .data(sig_start..sig_start.saturating_add(SIGNATURE_BYTES)) + .ok_or(PacketError::InvalidSignatureLen)?; + let signature = signature + .try_into() + .map_err(|_| PacketError::InvalidSignatureLen)?; + Ok(signature) +} + +#[cfg(test)] +mod tests { + use { + super::*, + solana_sdk::{ + hash::Hash, + signature::{Keypair, Signature}, + system_transaction, + }, + }; + + #[test] + fn test_get_signature_from_packet() { + // Default invalid txn packet + let packet = Packet::default(); + let sig = get_signature_from_packet(&packet); + assert_eq!(sig, Err(PacketError::InvalidShortVec)); + + // Use a valid transaction, it should succeed + let tx = system_transaction::transfer( + &Keypair::new(), + &solana_sdk::pubkey::new_rand(), + 1, + Hash::new_unique(), + ); + let mut packet = Packet::from_data(None, tx).unwrap(); + + let sig = get_signature_from_packet(&packet); + assert!(sig.is_ok()); + + // Invalid signature length + packet.buffer_mut()[0] = 0x0; + let sig = get_signature_from_packet(&packet); + assert_eq!(sig, Err(PacketError::InvalidSignatureLen)); + } + + #[test] + fn test_should_track_transaction() { + let mut sig = [0x0; SIGNATURE_BYTES]; + let track = should_track_transaction(&sig); + assert!(!track); + + // Intentionally matching the randomly generated mask + // The lower four bits are ignored as only 12 highest bits from + // signature's 61 and 62 u8 are used for matching. + // We generate a random one + let mut rng = rand::thread_rng(); + let random_number: u8 = rng.gen_range(0..=15); + sig[61] = ((*TXN_MASK & 0xf_u16) << 4) as u8 | random_number; + sig[62] = (*TXN_MASK >> 4) as u8; + + let track = should_track_transaction(&sig); + assert!(track); + } + + #[test] + fn test_signature_if_should_track_packet() { + // Default invalid txn packet + let packet = Packet::default(); + let sig = signature_if_should_track_packet(&packet); + assert_eq!(sig, Err(PacketError::InvalidShortVec)); + + // Use a valid transaction which is not matched + let tx = system_transaction::transfer( + &Keypair::new(), + &solana_sdk::pubkey::new_rand(), + 1, + Hash::new_unique(), + ); + let packet = Packet::from_data(None, tx).unwrap(); + let sig = signature_if_should_track_packet(&packet); + assert_eq!(Ok(None), sig); + + // Now simulate a txn matching the signature mask + let mut tx = system_transaction::transfer( + &Keypair::new(), + &solana_sdk::pubkey::new_rand(), + 1, + Hash::new_unique(), + ); + let mut sig = [0x0; SIGNATURE_BYTES]; + sig[61] = ((*TXN_MASK & 0xf_u16) << 4) as u8; + sig[62] = (*TXN_MASK >> 4) as u8; + + let sig = Signature::from(sig); + tx.signatures[0] = sig; + let mut packet = Packet::from_data(None, tx).unwrap(); + let sig2 = signature_if_should_track_packet(&packet); + + match sig2 { + Ok(sig) => { + assert!(sig.is_some()); + } + Err(_) => panic!("Expected to get a matching signature!"), + } + + // Invalid signature length + packet.buffer_mut()[0] = 0x0; + let sig = signature_if_should_track_packet(&packet); + assert_eq!(sig, Err(PacketError::InvalidSignatureLen)); + } +}