From 7bf9be39ef21c35c64d170efe4d75e3201f65a3c Mon Sep 17 00:00:00 2001 From: Ammar Arif Date: Tue, 5 Nov 2024 17:57:15 -0500 Subject: [PATCH] feat(katana): retain transactions in pool until mined (#2630) Every block interval, the node would take transactions from the pool - removing it directly from the pool. this creates a small window (depending on the machine) that the transaction appears nonexistent. this is due to how the tx flows from the pool and executor. this applies for both instant and interval block production mode. For instant mining, the window is between tx being picked up from the pool and the tx being committed to db. while for interval, tx being picked up from the pool and the tx being [inserted into the pending block](https://github.com/dojoengine/dojo/blob/d09cbcffd8c8f2745770888f9d3f30d07b8555ae/crates/katana/executor/src/implementation/blockifier/mod.rs#L208). When a tx is being queried thru the rpc, the node will first check if the it exist in the db, else find in the pending block (if interval mode). this pr adds a new (last) step, which is to try finding the tx in the pool if it doesn't exist anywhere else. --- Cargo.lock | 1 + crates/katana/core/src/service/mod.rs | 75 ++++++---- .../katana/pipeline/src/stage/sequencing.rs | 7 +- crates/katana/pool/Cargo.toml | 3 +- crates/katana/pool/src/lib.rs | 14 +- crates/katana/pool/src/ordering.rs | 66 +++++---- crates/katana/pool/src/pending.rs | 139 ++++++++++++++++++ crates/katana/pool/src/pool.rs | 139 +++++++++++++----- crates/katana/pool/src/subscription.rs | 67 +++++++++ crates/katana/rpc/rpc/src/starknet/mod.rs | 11 +- 10 files changed, 412 insertions(+), 110 deletions(-) create mode 100644 crates/katana/pool/src/pending.rs create mode 100644 crates/katana/pool/src/subscription.rs diff --git a/Cargo.lock b/Cargo.lock index e85734ffff..294f2dd455 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8273,6 +8273,7 @@ name = "katana-pool" version = "1.0.0-rc.1" dependencies = [ "futures", + "futures-util", "katana-executor", "katana-primitives", "katana-provider", diff --git a/crates/katana/core/src/service/mod.rs b/crates/katana/core/src/service/mod.rs index 0d368ff9de..8bdae8bb7f 100644 --- a/crates/katana/core/src/service/mod.rs +++ b/crates/katana/core/src/service/mod.rs @@ -1,17 +1,14 @@ -// TODO: remove the messaging feature flag -// TODO: move the tasks to a separate module - use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use block_producer::BlockProductionError; -use futures::channel::mpsc::Receiver; -use futures::stream::{Fuse, Stream, StreamExt}; +use futures::stream::StreamExt; use katana_executor::ExecutorFactory; +use katana_pool::ordering::PoolOrd; +use katana_pool::pending::PendingTransactions; use katana_pool::{TransactionPool, TxPool}; use katana_primitives::transaction::ExecutableTxWithHash; -use katana_primitives::Felt; use tracing::{error, info}; use self::block_producer::BlockProducer; @@ -30,24 +27,40 @@ pub(crate) const LOG_TARGET: &str = "node"; /// to construct a new block. #[must_use = "BlockProductionTask does nothing unless polled"] #[allow(missing_debug_implementations)] -pub struct BlockProductionTask { +pub struct BlockProductionTask +where + EF: ExecutorFactory, + O: PoolOrd, +{ /// creates new blocks pub(crate) block_producer: BlockProducer, /// the miner responsible to select transactions from the `pool´ - pub(crate) miner: TransactionMiner, + pub(crate) miner: TransactionMiner, /// the pool that holds all transactions pub(crate) pool: TxPool, /// Metrics for recording the service operations metrics: BlockProducerMetrics, } -impl BlockProductionTask { - pub fn new(pool: TxPool, miner: TransactionMiner, block_producer: BlockProducer) -> Self { +impl BlockProductionTask +where + EF: ExecutorFactory, + O: PoolOrd, +{ + pub fn new( + pool: TxPool, + miner: TransactionMiner, + block_producer: BlockProducer, + ) -> Self { Self { block_producer, miner, pool, metrics: BlockProducerMetrics::default() } } } -impl Future for BlockProductionTask { +impl Future for BlockProductionTask +where + EF: ExecutorFactory, + O: PoolOrd, +{ type Output = Result<(), BlockProductionError>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -65,6 +78,9 @@ impl Future for BlockProductionTask { let steps_used = outcome.stats.cairo_steps_used; this.metrics.l1_gas_processed_total.increment(gas_used as u64); this.metrics.cairo_steps_processed_total.increment(steps_used as u64); + + // remove mined transactions from the pool + this.pool.remove_transactions(&outcome.txs); } Err(error) => { @@ -74,7 +90,7 @@ impl Future for BlockProductionTask { } } - if let Poll::Ready(pool_txs) = this.miner.poll(&this.pool, cx) { + if let Poll::Ready(pool_txs) = this.miner.poll(cx) { // miner returned a set of transaction that we feed to the producer this.block_producer.queue(pool_txs); } else { @@ -89,37 +105,32 @@ impl Future for BlockProductionTask { /// The type which takes the transaction from the pool and feeds them to the block producer. #[derive(Debug)] -pub struct TransactionMiner { - /// stores whether there are pending transacions (if known) - has_pending_txs: Option, - /// Receives hashes of transactions that are ready from the pool - rx: Fuse>, +pub struct TransactionMiner +where + O: PoolOrd, +{ + pending_txs: PendingTransactions, } -impl TransactionMiner { - pub fn new(rx: Receiver) -> Self { - Self { rx: rx.fuse(), has_pending_txs: None } +impl TransactionMiner +where + O: PoolOrd, +{ + pub fn new(pending_txs: PendingTransactions) -> Self { + Self { pending_txs } } - fn poll(&mut self, pool: &TxPool, cx: &mut Context<'_>) -> Poll> { - // drain the notification stream - while let Poll::Ready(Some(_)) = Pin::new(&mut self.rx).poll_next(cx) { - self.has_pending_txs = Some(true); - } + fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { + let mut transactions = Vec::new(); - if self.has_pending_txs == Some(false) { - return Poll::Pending; + while let Poll::Ready(Some(tx)) = self.pending_txs.poll_next_unpin(cx) { + transactions.push(tx.tx.as_ref().clone()); } - // take all the transactions from the pool - let transactions = - pool.take_transactions().map(|tx| tx.tx.as_ref().clone()).collect::>(); - if transactions.is_empty() { return Poll::Pending; } - self.has_pending_txs = Some(false); Poll::Ready(transactions) } } diff --git a/crates/katana/pipeline/src/stage/sequencing.rs b/crates/katana/pipeline/src/stage/sequencing.rs index ad351b3344..c988ecab57 100644 --- a/crates/katana/pipeline/src/stage/sequencing.rs +++ b/crates/katana/pipeline/src/stage/sequencing.rs @@ -53,11 +53,10 @@ impl Sequencing { } fn run_block_production(&self) -> TaskHandle> { - let pool = self.pool.clone(); - let miner = TransactionMiner::new(pool.add_listener()); + // Create a new transaction miner with a subscription to the pool's pending transactions. + let miner = TransactionMiner::new(self.pool.pending_transactions()); let block_producer = self.block_producer.clone(); - - let service = BlockProductionTask::new(pool, miner, block_producer); + let service = BlockProductionTask::new(self.pool.clone(), miner, block_producer); self.task_spawner.build_task().name("Block production").spawn(service) } } diff --git a/crates/katana/pool/Cargo.toml b/crates/katana/pool/Cargo.toml index 256f6ce928..207b18cbc6 100644 --- a/crates/katana/pool/Cargo.toml +++ b/crates/katana/pool/Cargo.toml @@ -13,8 +13,9 @@ katana-primitives.workspace = true katana-provider.workspace = true parking_lot.workspace = true thiserror.workspace = true +tokio = { workspace = true, features = [ "sync" ] } tracing.workspace = true [dev-dependencies] +futures-util.workspace = true rand.workspace = true -tokio.workspace = true diff --git a/crates/katana/pool/src/lib.rs b/crates/katana/pool/src/lib.rs index 5a607907b1..f61d458104 100644 --- a/crates/katana/pool/src/lib.rs +++ b/crates/katana/pool/src/lib.rs @@ -1,7 +1,9 @@ #![cfg_attr(not(test), warn(unused_crate_dependencies))] pub mod ordering; +pub mod pending; pub mod pool; +pub mod subscription; pub mod tx; pub mod validation; @@ -10,8 +12,9 @@ use std::sync::Arc; use futures::channel::mpsc::Receiver; use katana_primitives::transaction::{ExecutableTxWithHash, TxHash}; use ordering::{FiFo, PoolOrd}; +use pending::PendingTransactions; use pool::Pool; -use tx::{PendingTx, PoolTransaction}; +use tx::PoolTransaction; use validation::error::InvalidTransactionError; use validation::stateful::TxValidator; use validation::Validator; @@ -44,9 +47,9 @@ pub trait TransactionPool { /// Add a new transaction to the pool. fn add_transaction(&self, tx: Self::Transaction) -> PoolResult; - fn take_transactions( - &self, - ) -> impl Iterator>; + /// Returns a [`Stream`](futures::Stream) which yields pending transactions - transactions that + /// can be executed - from the pool. + fn pending_transactions(&self) -> PendingTransactions; /// Check if the pool contains a transaction with the given hash. fn contains(&self, hash: TxHash) -> bool; @@ -56,6 +59,9 @@ pub trait TransactionPool { fn add_listener(&self) -> Receiver; + /// Removes a list of transactions from the pool according to their hashes. + fn remove_transactions(&self, hashes: &[TxHash]); + /// Get the total number of transactions in the pool. fn size(&self) -> usize; diff --git a/crates/katana/pool/src/ordering.rs b/crates/katana/pool/src/ordering.rs index 25a7794082..e0bad9ebec 100644 --- a/crates/katana/pool/src/ordering.rs +++ b/crates/katana/pool/src/ordering.rs @@ -125,14 +125,16 @@ impl Default for TipOrdering { #[cfg(test)] mod tests { + use futures::StreamExt; + use crate::ordering::{self, FiFo}; use crate::pool::test_utils::*; use crate::tx::PoolTransaction; use crate::validation::NoopValidator; use crate::{Pool, TransactionPool}; - #[test] - fn fifo_ordering() { + #[tokio::test] + async fn fifo_ordering() { // Create mock transactions let txs = [PoolTx::new(), PoolTx::new(), PoolTx::new(), PoolTx::new(), PoolTx::new()]; @@ -145,16 +147,17 @@ mod tests { }); // Get pending transactions - let pendings = pool.take_transactions().collect::>(); + let mut pendings = pool.pending_transactions(); // Assert that the transactions are in the order they were added (first to last) - pendings.iter().zip(txs).for_each(|(pending, tx)| { + for tx in txs { + let pending = pendings.next().await.unwrap(); assert_eq!(pending.tx.as_ref(), &tx); - }); + } } - #[test] - fn tip_based_ordering() { + #[tokio::test] + async fn tip_based_ordering() { // Create mock transactions with different tips and in random order let txs = [ PoolTx::new().with_tip(2), @@ -176,36 +179,43 @@ mod tests { let _ = pool.add_transaction(tx.clone()); }); - // Get pending transactions - let pending = pool.take_transactions().collect::>(); - assert_eq!(pending.len(), txs.len()); + let mut pending = pool.pending_transactions(); // Assert that the transactions are ordered by tip (highest to lowest) - assert_eq!(pending[0].tx.tip(), 7); - assert_eq!(pending[0].tx.hash(), txs[8].hash()); + let tx = pending.next().await.unwrap(); + assert_eq!(tx.tx.tip(), 7); + assert_eq!(tx.tx.hash(), txs[8].hash()); - assert_eq!(pending[1].tx.tip(), 6); - assert_eq!(pending[1].tx.hash(), txs[2].hash()); + let tx = pending.next().await.unwrap(); + assert_eq!(tx.tx.tip(), 6); + assert_eq!(tx.tx.hash(), txs[2].hash()); - assert_eq!(pending[2].tx.tip(), 5); - assert_eq!(pending[2].tx.hash(), txs[6].hash()); + let tx = pending.next().await.unwrap(); + assert_eq!(tx.tx.tip(), 5); + assert_eq!(tx.tx.hash(), txs[6].hash()); - assert_eq!(pending[3].tx.tip(), 4); - assert_eq!(pending[3].tx.hash(), txs[7].hash()); + let tx = pending.next().await.unwrap(); + assert_eq!(tx.tx.tip(), 4); + assert_eq!(tx.tx.hash(), txs[7].hash()); - assert_eq!(pending[4].tx.tip(), 3); - assert_eq!(pending[4].tx.hash(), txs[3].hash()); + let tx = pending.next().await.unwrap(); + assert_eq!(tx.tx.tip(), 3); + assert_eq!(tx.tx.hash(), txs[3].hash()); - assert_eq!(pending[5].tx.tip(), 2); - assert_eq!(pending[5].tx.hash(), txs[0].hash()); + let tx = pending.next().await.unwrap(); + assert_eq!(tx.tx.tip(), 2); + assert_eq!(tx.tx.hash(), txs[0].hash()); - assert_eq!(pending[6].tx.tip(), 2); - assert_eq!(pending[6].tx.hash(), txs[4].hash()); + let tx = pending.next().await.unwrap(); + assert_eq!(tx.tx.tip(), 2); + assert_eq!(tx.tx.hash(), txs[4].hash()); - assert_eq!(pending[7].tx.tip(), 2); - assert_eq!(pending[7].tx.hash(), txs[5].hash()); + let tx = pending.next().await.unwrap(); + assert_eq!(tx.tx.tip(), 2); + assert_eq!(tx.tx.hash(), txs[5].hash()); - assert_eq!(pending[8].tx.tip(), 1); - assert_eq!(pending[8].tx.hash(), txs[1].hash()); + let tx = pending.next().await.unwrap(); + assert_eq!(tx.tx.tip(), 1); + assert_eq!(tx.tx.hash(), txs[1].hash()); } } diff --git a/crates/katana/pool/src/pending.rs b/crates/katana/pool/src/pending.rs new file mode 100644 index 0000000000..82b6c76c68 --- /dev/null +++ b/crates/katana/pool/src/pending.rs @@ -0,0 +1,139 @@ +use std::collections::btree_set::IntoIter; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{Stream, StreamExt}; + +use crate::ordering::PoolOrd; +use crate::subscription::Subscription; +use crate::tx::{PendingTx, PoolTransaction}; + +/// An iterator that yields transactions from the pool that can be included in a block, sorted by +/// by its priority. +#[derive(Debug)] +pub struct PendingTransactions { + /// Iterator over all the pending transactions at the time of the creation of this struct. + pub(crate) all: IntoIter>, + /// Subscription to the pool to get notified when new transactions are added. This is used to + /// wait on the new transactions after exhausting the `all` iterator. + pub(crate) subscription: Subscription, +} + +impl Stream for PendingTransactions +where + T: PoolTransaction, + O: PoolOrd, +{ + type Item = PendingTx; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + if let Some(tx) = this.all.next() { + Poll::Ready(Some(tx)) + } else { + this.subscription.poll_next_unpin(cx) + } + } +} + +#[cfg(test)] +mod tests { + + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + + use futures::StreamExt; + use tokio::task::yield_now; + + use crate::pool::test_utils::PoolTx; + use crate::pool::Pool; + use crate::validation::NoopValidator; + use crate::{ordering, PoolTransaction, TransactionPool}; + + #[tokio::test] + async fn pending_transactions() { + let pool = Pool::new(NoopValidator::::new(), ordering::FiFo::new()); + + let first_batch = [ + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + ]; + + for tx in &first_batch { + pool.add_transaction(tx.clone()).expect("failed to add tx"); + } + + let mut pendings = pool.pending_transactions(); + + // exhaust all the first batch transactions + for expected in &first_batch { + let actual = pendings.next().await.map(|t| t.tx).unwrap(); + assert_eq!(expected, actual.as_ref()); + } + + let second_batch = [ + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + ]; + + for tx in &second_batch { + pool.add_transaction(tx.clone()).expect("failed to add tx"); + } + + // exhaust all the first batch transactions + for expected in &second_batch { + let actual = pendings.next().await.map(|t| t.tx).unwrap(); + assert_eq!(expected, actual.as_ref()); + } + + // Check that all the added transaction is still in the pool because we haven't removed it + // yet. + let all = [first_batch, second_batch].concat(); + for tx in all { + assert!(pool.contains(tx.hash())); + } + } + + #[tokio::test(flavor = "multi_thread")] + async fn subscription_stream_wakeup() { + let pool = Pool::new(NoopValidator::::new(), ordering::FiFo::new()); + let mut pending = pool.pending_transactions(); + + // Spawn a task that will add a transaction after a delay + let pool_clone = pool.clone(); + + let txs = [PoolTx::new(), PoolTx::new(), PoolTx::new()]; + let txs_clone = txs.clone(); + + let has_polled_once = Arc::new(AtomicBool::new(false)); + let has_polled_once_clone = has_polled_once.clone(); + + tokio::spawn(async move { + while !has_polled_once_clone.load(Ordering::SeqCst) { + yield_now().await; + } + + for tx in txs_clone { + pool_clone.add_transaction(tx).expect("failed to add tx"); + } + }); + + // Check that first poll_next returns Pending because no pending transaction has been added + // to the pool yet + assert!(futures_util::poll!(pending.next()).is_pending()); + has_polled_once.store(true, Ordering::SeqCst); + + for tx in txs { + let received = pending.next().await.unwrap(); + assert_eq!(&tx, received.tx.as_ref()); + } + } +} diff --git a/crates/katana/pool/src/pool.rs b/crates/katana/pool/src/pool.rs index 1031f17a0f..f047e94467 100644 --- a/crates/katana/pool/src/pool.rs +++ b/crates/katana/pool/src/pool.rs @@ -1,14 +1,16 @@ use core::fmt; -use std::collections::btree_set::IntoIter; use std::collections::BTreeSet; use std::sync::Arc; use futures::channel::mpsc::{channel, Receiver, Sender}; use katana_primitives::transaction::TxHash; use parking_lot::RwLock; +use tokio::sync::mpsc; use tracing::{error, info, warn}; use crate::ordering::PoolOrd; +use crate::pending::PendingTransactions; +use crate::subscription::Subscription; use crate::tx::{PendingTx, PoolTransaction, TxId}; use crate::validation::error::InvalidTransactionError; use crate::validation::{ValidationOutcome, Validator}; @@ -32,6 +34,9 @@ struct Inner { /// listeners for incoming txs listeners: RwLock>>, + /// subscribers for incoming txs + subscribers: RwLock>>>, + /// the tx validator validator: V, @@ -52,6 +57,7 @@ where ordering, validator, transactions: Default::default(), + subscribers: Default::default(), listeners: Default::default(), }), } @@ -83,6 +89,38 @@ where } } } + + fn notify_subscribers(&self, tx: PendingTx) { + let mut subscribers = self.inner.subscribers.write(); + // this is basically a retain but with mut reference + for n in (0..subscribers.len()).rev() { + let sender = subscribers.swap_remove(n); + let retain = match sender.send(tx.clone()) { + Ok(()) => true, + Err(error) => { + warn!(%error, "Subscription channel closed"); + false + } + }; + + if retain { + subscribers.push(sender) + } + } + } + + // notify both listener and subscribers + fn notify(&self, tx: PendingTx) { + self.notify_listener(tx.tx.hash()); + self.notify_subscribers(tx); + } + + fn subscribe(&self) -> Subscription { + let (tx, rx) = mpsc::unbounded_channel(); + let subscriber = Subscription::new(rx); + self.inner.subscribers.write().push(tx); + subscriber + } } impl TransactionPool for Pool @@ -110,12 +148,14 @@ where let tx = PendingTx::new(id, tx, priority); // insert the tx in the pool - self.inner.transactions.write().insert(tx); - self.notify_listener(hash); + self.inner.transactions.write().insert(tx.clone()); + self.notify(tx); Ok(hash) } + // TODO: create a small cache for rejected transactions to respect the rpc spec + // `getTransactionStatus` ValidationOutcome::Invalid { error, .. } => { warn!(hash = format!("{hash:#x}"), "Invalid transaction."); Err(PoolError::InvalidTransaction(Box::new(error))) @@ -142,10 +182,11 @@ where } } - fn take_transactions(&self) -> impl Iterator> { + fn pending_transactions(&self) -> PendingTransactions { // take all the transactions PendingTransactions { - all: std::mem::take(&mut *self.inner.transactions.write()).into_iter(), + subscription: self.subscribe(), + all: self.inner.transactions.read().clone().into_iter(), } } @@ -170,6 +211,12 @@ where rx } + fn remove_transactions(&self, hashes: &[TxHash]) { + // retain only transactions that aren't included in the list + let mut txs = self.inner.transactions.write(); + txs.retain(|t| !hashes.contains(&t.tx.hash())) + } + fn size(&self) -> usize { self.inner.transactions.read().len() } @@ -190,24 +237,6 @@ where } } -/// an iterator that yields transactions from the pool that can be included in a block, sorted by -/// by its priority. -struct PendingTransactions { - all: IntoIter>, -} - -impl Iterator for PendingTransactions -where - T: PoolTransaction, - O: PoolOrd, -{ - type Item = PendingTx; - - fn next(&mut self) -> Option { - self.all.next() - } -} - #[cfg(test)] pub(crate) mod test_utils { @@ -290,7 +319,9 @@ pub(crate) mod test_utils { #[cfg(test)] mod tests { + use futures::StreamExt; use katana_primitives::contract::{ContractAddress, Nonce}; + use katana_primitives::transaction::TxHash; use katana_primitives::Felt; use super::test_utils::*; @@ -309,8 +340,8 @@ mod tests { } } - #[test] - fn pool_operations() { + #[tokio::test] + async fn pool_operations() { let txs = [ PoolTx::new(), PoolTx::new(), @@ -339,12 +370,12 @@ mod tests { assert!(txs.iter().all(|tx| pool.get(tx.hash()).is_some())); // noop validator should consider all txs as valid - let pendings = pool.take_transactions().collect::>(); - assert_eq!(pendings.len(), txs.len()); + let mut pendings = pool.pending_transactions(); // bcs we're using fcfs, the order should be the same as the order of the txs submission // (position in the array) - for (actual, expected) in pendings.iter().zip(txs.iter()) { + for expected in &txs { + let actual = pendings.next().await.unwrap(); assert_eq!(actual.tx.tip(), expected.tip()); assert_eq!(actual.tx.hash(), expected.hash()); assert_eq!(actual.tx.nonce(), expected.nonce()); @@ -352,8 +383,9 @@ mod tests { assert_eq!(actual.tx.max_fee(), expected.max_fee()); } - // take all transactions - let _ = pool.take_transactions(); + // remove all transactions + let hashes = txs.iter().map(|t| t.hash()).collect::>(); + pool.remove_transactions(&hashes); // all txs should've been removed assert!(pool.size() == 0); @@ -395,8 +427,43 @@ mod tests { } #[test] + fn remove_transactions() { + let pool = TestPool::test(); + + let txs = [ + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + ]; + + // start adding txs to the pool + txs.iter().for_each(|tx| { + let _ = pool.add_transaction(tx.clone()); + }); + + // first check that the transaction are indeed in the pool + txs.iter().for_each(|tx| { + assert!(pool.contains(tx.hash())); + }); + + // remove the transactions + let hashes = txs.iter().map(|t| t.hash()).collect::>(); + pool.remove_transactions(&hashes); + + // check that the transaction are no longer in the pool + txs.iter().for_each(|tx| { + assert!(!pool.contains(tx.hash())); + }); + } + + #[tokio::test] #[ignore = "Txs dependency management not fully implemented yet"] - fn dependent_txs_linear_insertion() { + async fn dependent_txs_linear_insertion() { let pool = TestPool::test(); // Create 100 transactions with the same sender but increasing nonce @@ -412,14 +479,12 @@ mod tests { }); // Get pending transactions - let pending = pool.take_transactions().collect::>(); - - // Check that the number of pending transactions matches the number of added transactions - assert_eq!(pending.len(), total as usize); + let mut pendings = pool.pending_transactions(); // Check that the pending transactions are in the same order as they were added - for (i, pending_tx) in pending.iter().enumerate() { - assert_eq!(pending_tx.tx.nonce(), Nonce::from(i as u128)); + for i in 0..total { + let pending_tx = pendings.next().await.unwrap(); + assert_eq!(pending_tx.tx.nonce(), Nonce::from(i)); assert_eq!(pending_tx.tx.sender(), sender); } } diff --git a/crates/katana/pool/src/subscription.rs b/crates/katana/pool/src/subscription.rs new file mode 100644 index 0000000000..465cffba97 --- /dev/null +++ b/crates/katana/pool/src/subscription.rs @@ -0,0 +1,67 @@ +use std::collections::BTreeSet; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::Stream; +use parking_lot::Mutex; +use tokio::sync::mpsc; + +use crate::ordering::PoolOrd; +use crate::tx::{PendingTx, PoolTransaction}; + +#[derive(Debug)] +pub struct Subscription { + txs: Mutex>>, + receiver: mpsc::UnboundedReceiver>, +} + +impl Subscription +where + T: PoolTransaction, + O: PoolOrd, +{ + pub(crate) fn new(receiver: mpsc::UnboundedReceiver>) -> Self { + Self { txs: Default::default(), receiver } + } +} + +impl Stream for Subscription +where + T: PoolTransaction, + O: PoolOrd, +{ + type Item = PendingTx; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut txs = this.txs.lock(); + + // In the event where a lot of transactions have been sent to the receiver channel and this + // stream hasn't been iterated since, the next call to `.next()` of this Stream will + // require to drain the channel and insert all the transactions into the btree set. If there + // are a lot of transactions to insert, it would take a while and might block the + // runtime. + loop { + if let Some(tx) = txs.pop_first() { + return Poll::Ready(Some(tx)); + } + + // Check the channel if there are new transactions available. + match this.receiver.poll_recv(cx) { + // insert the new transactions into the btree set to make sure they are ordered + // according to the pool's ordering. + Poll::Ready(Some(tx)) => { + txs.insert(tx); + + // Check if there are more transactions available in the channel. + while let Poll::Ready(Some(tx)) = this.receiver.poll_recv(cx) { + txs.insert(tx); + } + } + + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, + } + } + } +} diff --git a/crates/katana/rpc/rpc/src/starknet/mod.rs b/crates/katana/rpc/rpc/src/starknet/mod.rs index be7f65200c..c06c778d50 100644 --- a/crates/katana/rpc/rpc/src/starknet/mod.rs +++ b/crates/katana/rpc/rpc/src/starknet/mod.rs @@ -12,7 +12,7 @@ use katana_core::backend::Backend; use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, PendingExecutor}; use katana_executor::{ExecutionResult, ExecutorFactory}; use katana_pool::validation::stateful::TxValidator; -use katana_pool::TxPool; +use katana_pool::{TransactionPool, TxPool}; use katana_primitives::block::{ BlockHash, BlockHashOrNumber, BlockIdOrTag, BlockNumber, BlockTag, FinalityStatus, PartialHeader, @@ -23,7 +23,7 @@ use katana_primitives::conversion::rpc::legacy_inner_to_rpc_class; use katana_primitives::da::L1DataAvailabilityMode; use katana_primitives::env::BlockEnv; use katana_primitives::event::MaybeForkedContinuationToken; -use katana_primitives::transaction::{ExecutableTxWithHash, TxHash}; +use katana_primitives::transaction::{ExecutableTxWithHash, TxHash, TxWithHash}; use katana_primitives::Felt; use katana_provider::traits::block::{BlockHashProvider, BlockIdReader, BlockNumberProvider}; use katana_provider::traits::contract::ContractClassProvider; @@ -431,7 +431,9 @@ impl StarknetApi { } else if let Some(client) = &self.inner.forked_client { Ok(client.get_transaction_by_hash(hash).await?) } else { - Err(StarknetApiError::TxnHashNotFound) + let tx = self.inner.pool.get(hash).ok_or(StarknetApiError::TxnHashNotFound)?; + let tx = TxWithHash::from(tx.as_ref()); + Ok(Tx::from(tx)) } } @@ -563,7 +565,8 @@ impl StarknetApi { } else if let Some(client) = &self.inner.forked_client { Ok(client.get_transaction_status(hash).await?) } else { - Err(StarknetApiError::TxnHashNotFound) + let _ = self.inner.pool.get(hash).ok_or(StarknetApiError::TxnHashNotFound)?; + Ok(TransactionStatus::Received) } }