Skip to content

Commit

Permalink
Merge pull request #201 from weiznich/fix/198
Browse files Browse the repository at this point in the history
Fix #198
  • Loading branch information
weiznich authored Nov 26, 2024
2 parents 35cb1ad + e857edf commit e3beac6
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 177 deletions.
28 changes: 14 additions & 14 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
matrix:
rust: ["stable", "beta", "nightly"]
backend: ["postgres", "mysql", "sqlite"]
os: [ubuntu-latest, macos-13, macos-14, windows-2019]
os: [ubuntu-latest, macos-13, macos-15, windows-2019]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout sources
Expand Down Expand Up @@ -121,7 +121,7 @@ jobs:
echo "DATABASE_URL=postgres://postgres@localhost/" >> $GITHUB_ENV
- name: Install postgres (MacOS M1)
if: matrix.os == 'macos-14' && matrix.backend == 'postgres'
if: matrix.os == 'macos-15' && matrix.backend == 'postgres'
run: |
brew install postgresql@14
brew services start postgresql@14
Expand All @@ -138,24 +138,24 @@ jobs:
- name: Install mysql (MacOS)
if: matrix.os == 'macos-13' && matrix.backend == 'mysql'
run: |
brew install mariadb@11.2
/usr/local/opt/mariadb@11.2/bin/mysql_install_db
/usr/local/opt/mariadb@11.2/bin/mysql.server start
brew install mariadb@11.4
/usr/local/opt/mariadb@11.4/bin/mysql_install_db
/usr/local/opt/mariadb@11.4/bin/mysql.server start
sleep 3
/usr/local/opt/mariadb@11.2/bin/mysqladmin -u runner password diesel
/usr/local/opt/mariadb@11.2/bin/mysql -e "create database diesel_test; create database diesel_unit_test; grant all on \`diesel_%\`.* to 'runner'@'localhost';" -urunner
/usr/local/opt/mariadb@11.4/bin/mysqladmin -u runner password diesel
/usr/local/opt/mariadb@11.4/bin/mysql -e "create database diesel_test; create database diesel_unit_test; grant all on \`diesel_%\`.* to 'runner'@'localhost';" -urunner
echo "DATABASE_URL=mysql://runner:diesel@localhost/diesel_test" >> $GITHUB_ENV
- name: Install mysql (MacOS M1)
if: matrix.os == 'macos-14' && matrix.backend == 'mysql'
if: matrix.os == 'macos-15' && matrix.backend == 'mysql'
run: |
brew install mariadb@11.2
ls /opt/homebrew/opt/mariadb@11.2
/opt/homebrew/opt/mariadb@11.2/bin/mysql_install_db
/opt/homebrew/opt/mariadb@11.2/bin/mysql.server start
brew install mariadb@11.4
ls /opt/homebrew/opt/mariadb@11.4
/opt/homebrew/opt/mariadb@11.4/bin/mysql_install_db
/opt/homebrew/opt/mariadb@11.4/bin/mysql.server start
sleep 3
/opt/homebrew/opt/mariadb@11.2/bin/mysqladmin -u runner password diesel
/opt/homebrew/opt/mariadb@11.2/bin/mysql -e "create database diesel_test; create database diesel_unit_test; grant all on \`diesel_%\`.* to 'runner'@'localhost';" -urunner
/opt/homebrew/opt/mariadb@11.4/bin/mysqladmin -u runner password diesel
/opt/homebrew/opt/mariadb@11.4/bin/mysql -e "create database diesel_test; create database diesel_unit_test; grant all on \`diesel_%\`.* to 'runner'@'localhost';" -urunner
echo "DATABASE_URL=mysql://runner:diesel@localhost/diesel_test" >> $GITHUB_ENV
- name: Install postgres (Windows)
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/

## [Unreleased]

## [0.5.2] - 2024-11-26

* Fixed an issue around transaction cancellation that could lead to connection pools containing connections with dangling transactions

## [0.5.1] - 2024-11-01

* Add crate feature `pool` for extending connection pool implements through external crate
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "diesel-async"
version = "0.5.1"
version = "0.5.2"
authors = ["Georg Semmler <[email protected]>"]
edition = "2021"
autotests = false
Expand Down
249 changes: 87 additions & 162 deletions src/transaction_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ use diesel::QueryResult;
use scoped_futures::ScopedBoxFuture;
use std::borrow::Cow;
use std::num::NonZeroU32;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use crate::AsyncConnection;
// TODO: refactor this to share more code with diesel
Expand Down Expand Up @@ -88,24 +90,31 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
/// in an error state.
#[doc(hidden)]
fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
match Self::transaction_manager_status_mut(conn).transaction_state() {
// all transactions are closed
// so we don't consider this connection broken
Ok(ValidTransactionManagerStatus {
in_transaction: None,
..
}) => false,
// The transaction manager is in an error state
// Therefore we consider this connection broken
Err(_) => true,
// The transaction manager contains a open transaction
// we do consider this connection broken
// if that transaction was not opened by `begin_test_transaction`
Ok(ValidTransactionManagerStatus {
in_transaction: Some(s),
..
}) => !s.test_transaction,
}
check_broken_transaction_state(conn)
}
}

fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
where
Conn: AsyncConnection,
{
match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
// all transactions are closed
// so we don't consider this connection broken
Ok(ValidTransactionManagerStatus {
in_transaction: None,
..
}) => false,
// The transaction manager is in an error state
// Therefore we consider this connection broken
Err(_) => true,
// The transaction manager contains a open transaction
// we do consider this connection broken
// if that transaction was not opened by `begin_test_transaction`
Ok(ValidTransactionManagerStatus {
in_transaction: Some(s),
..
}) => !s.test_transaction,
}
}

Expand All @@ -114,147 +123,23 @@ pub trait TransactionManager<Conn: AsyncConnection>: Send {
#[derive(Default, Debug)]
pub struct AnsiTransactionManager {
pub(crate) status: TransactionManagerStatus,
// this boolean flag tracks whether we are currently in the process
// of executing any transaction releated SQL (BEGIN, COMMIT, ROLLBACK)
// if we ever encounter a situation where this flag is set
// while the connection is returned to a pool
// that means the connection is broken as someone dropped the
// transaction future while these commands where executed
// and we cannot know the connection state anymore
//
// We ensure this by wrapping all calls to `.await`
// into `AnsiTransactionManager::critical_transaction_block`
// below
//
// See https://github.com/weiznich/diesel_async/issues/198 for
// details
pub(crate) is_broken: Arc<AtomicBool>,
}

// /// Status of the transaction manager
// #[derive(Debug)]
// pub enum TransactionManagerStatus {
// /// Valid status, the manager can run operations
// Valid(ValidTransactionManagerStatus),
// /// Error status, probably following a broken connection. The manager will no longer run operations
// InError,
// }

// impl Default for TransactionManagerStatus {
// fn default() -> Self {
// TransactionManagerStatus::Valid(ValidTransactionManagerStatus::default())
// }
// }

// impl TransactionManagerStatus {
// /// Returns the transaction depth if the transaction manager's status is valid, or returns
// /// [`Error::BrokenTransactionManager`] if the transaction manager is in error.
// pub fn transaction_depth(&self) -> QueryResult<Option<NonZeroU32>> {
// match self {
// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status.transaction_depth()),
// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
// }
// }

// /// If in transaction and transaction manager is not broken, registers that the
// /// connection can not be used anymore until top-level transaction is rolled back
// pub(crate) fn set_top_level_transaction_requires_rollback(&mut self) {
// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
// in_transaction:
// Some(InTransactionStatus {
// top_level_transaction_requires_rollback,
// ..
// }),
// }) = self
// {
// *top_level_transaction_requires_rollback = true;
// }
// }

// /// Sets the transaction manager status to InError
// ///
// /// Subsequent attempts to use transaction-related features will result in a
// /// [`Error::BrokenTransactionManager`] error
// pub fn set_in_error(&mut self) {
// *self = TransactionManagerStatus::InError
// }

// fn transaction_state(&mut self) -> QueryResult<&mut ValidTransactionManagerStatus> {
// match self {
// TransactionManagerStatus::Valid(valid_status) => Ok(valid_status),
// TransactionManagerStatus::InError => Err(Error::BrokenTransactionManager),
// }
// }

// pub(crate) fn set_test_transaction_flag(&mut self) {
// if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
// in_transaction: Some(s),
// }) = self
// {
// s.test_transaction = true;
// }
// }
// }

// /// Valid transaction status for the manager. Can return the current transaction depth
// #[allow(missing_copy_implementations)]
// #[derive(Debug, Default)]
// pub struct ValidTransactionManagerStatus {
// in_transaction: Option<InTransactionStatus>,
// }

// #[allow(missing_copy_implementations)]
// #[derive(Debug)]
// struct InTransactionStatus {
// transaction_depth: NonZeroU32,
// top_level_transaction_requires_rollback: bool,
// test_transaction: bool,
// }

// impl ValidTransactionManagerStatus {
// /// Return the current transaction depth
// ///
// /// This value is `None` if no current transaction is running
// /// otherwise the number of nested transactions is returned.
// pub fn transaction_depth(&self) -> Option<NonZeroU32> {
// self.in_transaction.as_ref().map(|it| it.transaction_depth)
// }

// /// Update the transaction depth by adding the value of the `transaction_depth_change` parameter if the `query` is
// /// `Ok(())`
// pub fn change_transaction_depth(
// &mut self,
// transaction_depth_change: TransactionDepthChange,
// ) -> QueryResult<()> {
// match (&mut self.in_transaction, transaction_depth_change) {
// (Some(in_transaction), TransactionDepthChange::IncreaseDepth) => {
// // Can be replaced with saturating_add directly on NonZeroU32 once
// // <https://github.com/rust-lang/rust/issues/84186> is stable
// in_transaction.transaction_depth =
// NonZeroU32::new(in_transaction.transaction_depth.get().saturating_add(1))
// .expect("nz + nz is always non-zero");
// Ok(())
// }
// (Some(in_transaction), TransactionDepthChange::DecreaseDepth) => {
// // This sets `transaction_depth` to `None` as soon as we reach zero
// match NonZeroU32::new(in_transaction.transaction_depth.get() - 1) {
// Some(depth) => in_transaction.transaction_depth = depth,
// None => self.in_transaction = None,
// }
// Ok(())
// }
// (None, TransactionDepthChange::IncreaseDepth) => {
// self.in_transaction = Some(InTransactionStatus {
// transaction_depth: NonZeroU32::new(1).expect("1 is non-zero"),
// top_level_transaction_requires_rollback: false,
// test_transaction: false,
// });
// Ok(())
// }
// (None, TransactionDepthChange::DecreaseDepth) => {
// // We screwed up something somewhere
// // we cannot decrease the transaction count if
// // we are not inside a transaction
// Err(Error::NotInTransaction)
// }
// }
// }
// }

// /// Represents a change to apply to the depth of a transaction
// #[derive(Debug, Clone, Copy)]
// pub enum TransactionDepthChange {
// /// Increase the depth of the transaction (corresponds to `BEGIN` or `SAVEPOINT`)
// IncreaseDepth,
// /// Decreases the depth of the transaction (corresponds to `COMMIT`/`RELEASE SAVEPOINT` or `ROLLBACK`)
// DecreaseDepth,
// }

impl AnsiTransactionManager {
fn get_transaction_state<Conn>(
conn: &mut Conn,
Expand All @@ -274,17 +159,38 @@ impl AnsiTransactionManager {
where
Conn: AsyncConnection<TransactionManager = Self>,
{
let is_broken = conn.transaction_state().is_broken.clone();
let state = Self::get_transaction_state(conn)?;
match state.transaction_depth() {
None => {
conn.batch_execute(sql).await?;
Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
Self::get_transaction_state(conn)?
.change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
Ok(())
}
Some(_depth) => Err(Error::AlreadyInTransaction),
}
}

// This function should be used to await any connection
// related future in our transaction manager implementation
//
// It takes care of tracking entering and exiting executing the future
// which in turn is used to determine if it's safe to still use
// the connection in the event of a canceled transaction execution
async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
where
F: std::future::Future,
{
let was_broken = is_broken.swap(true, Ordering::Relaxed);
debug_assert!(
!was_broken,
"Tried to execute a transaction SQL on transaction manager that was previously cancled"
);
let res = f.await;
is_broken.store(false, Ordering::Relaxed);
res
}
}

#[async_trait::async_trait]
Expand All @@ -308,7 +214,11 @@ where
.unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
conn.instrumentation()
.on_connection_event(InstrumentationEvent::begin_transaction(depth));
conn.batch_execute(&start_transaction_sql).await?;
Self::critical_transaction_block(
&conn.transaction_state().is_broken.clone(),
conn.batch_execute(&start_transaction_sql),
)
.await?;
Self::get_transaction_state(conn)?
.change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;

Expand Down Expand Up @@ -344,7 +254,10 @@ where
conn.instrumentation()
.on_connection_event(InstrumentationEvent::rollback_transaction(depth));

match conn.batch_execute(&rollback_sql).await {
let is_broken = conn.transaction_state().is_broken.clone();

match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
{
Ok(()) => {
match Self::get_transaction_state(conn)?
.change_transaction_depth(TransactionDepthChange::DecreaseDepth)
Expand Down Expand Up @@ -429,7 +342,9 @@ where
conn.instrumentation()
.on_connection_event(InstrumentationEvent::commit_transaction(depth));

match conn.batch_execute(&commit_sql).await {
let is_broken = conn.transaction_state().is_broken.clone();

match Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await {
Ok(()) => {
match Self::get_transaction_state(conn)?
.change_transaction_depth(TransactionDepthChange::DecreaseDepth)
Expand All @@ -453,7 +368,12 @@ where
..
}) = conn.transaction_state().status
{
match Self::rollback_transaction(conn).await {
match Self::critical_transaction_block(
&is_broken,
Self::rollback_transaction(conn),
)
.await
{
Ok(()) => {}
Err(rollback_error) => {
conn.transaction_state().status.set_in_error();
Expand All @@ -472,4 +392,9 @@ where
fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
&mut conn.transaction_state().status
}

fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
conn.transaction_state().is_broken.load(Ordering::Relaxed)
|| check_broken_transaction_state(conn)
}
}

0 comments on commit e3beac6

Please sign in to comment.