diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index fb78ccfa6..4ff768345 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -46,6 +46,7 @@ import software.amazon.jdbc.plugin.staledns.AuroraStaleDnsPlugin; import software.amazon.jdbc.plugin.strategy.fastestresponse.FastestResponseStrategyPlugin; import software.amazon.jdbc.profile.ConfigurationProfile; +import software.amazon.jdbc.util.AsynchronousMethodsHelper; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.SqlMethodAnalyzer; import software.amazon.jdbc.util.WrapperUtils; @@ -94,8 +95,6 @@ public class ConnectionPluginManager implements CanReleaseResources, Wrapper { private static final String NOTIFY_CONNECTION_CHANGED_METHOD = "notifyConnectionChanged"; private static final String NOTIFY_NODE_LIST_CHANGED_METHOD = "notifyNodeListChanged"; private static final SqlMethodAnalyzer sqlMethodAnalyzer = new SqlMethodAnalyzer(); - private final ReentrantLock lock = new ReentrantLock(); - protected Properties props = new Properties(); protected List plugins; protected final @NonNull ConnectionProvider defaultConnProvider; @@ -151,14 +150,6 @@ public ConnectionPluginManager( this.telemetryFactory = telemetryFactory; } - public void lock() { - lock.lock(); - } - - public void unlock() { - lock.unlock(); - } - /** * Initialize a chain of {@link ConnectionPlugin} using their corresponding {@link * ConnectionPluginFactory}. If {@code PropertyDefinition.PLUGINS} is provided by the user, @@ -308,7 +299,13 @@ public T execute( final Object[] jdbcMethodArgs) throws E { - final Connection conn = WrapperUtils.getConnectionFromSqlObject(methodInvokeOn); + final Connection conn; + if (AsynchronousMethodsHelper.ASYNCHRONOUS_METHODS.contains(methodName)) { + conn = this.pluginService.getCurrentConnection(); + } else { + conn = WrapperUtils.getConnectionFromSqlObject(methodInvokeOn); + } + if (conn != null && conn != this.pluginService.getCurrentConnection() && !sqlMethodAnalyzer.isMethodClosingSqlObject(methodName)) { final SQLException e = diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/AsynchronousMethodsHelper.java b/wrapper/src/main/java/software/amazon/jdbc/util/AsynchronousMethodsHelper.java new file mode 100644 index 000000000..ad1757d7c --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/util/AsynchronousMethodsHelper.java @@ -0,0 +1,26 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.util; + +import java.util.Collections; +import java.util.List; + +public class AsynchronousMethodsHelper { + public static final List ASYNCHRONOUS_METHODS = Collections.singletonList( + "Statement.cancel" + ); +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java index 2f5b717eb..0a95a0d15 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/WrapperUtils.java @@ -50,6 +50,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.locks.ReentrantLock; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.ConnectionPluginManager; import software.amazon.jdbc.JdbcCallable; @@ -84,6 +85,8 @@ public class WrapperUtils { private static final ConcurrentMap, Boolean> isJdbcInterfaceCache = new ConcurrentHashMap<>(); + private static final ReentrantLock lock = new ReentrantLock(); + private static final Map, Class> availableWrappers = new HashMap, Class>() { { @@ -182,7 +185,11 @@ public static T executeWithPlugins( final JdbcCallable jdbcMethodFunc, final Object... jdbcMethodArgs) { - pluginManager.lock(); + boolean locked = false; + if (!AsynchronousMethodsHelper.ASYNCHRONOUS_METHODS.contains(methodName)) { + lock.lock(); + locked = true; + } TelemetryFactory telemetryFactory = pluginManager.getTelemetryFactory(); TelemetryContext context = null; @@ -208,7 +215,9 @@ public static T executeWithPlugins( throw new RuntimeException(e); } } finally { - pluginManager.unlock(); + if (locked) { + lock.unlock(); + } if (context != null) { context.closeContext(); } @@ -225,7 +234,11 @@ public static T executeWithPlugins( final Object... jdbcMethodArgs) throws E { - pluginManager.lock(); + boolean locked = false; + if (!AsynchronousMethodsHelper.ASYNCHRONOUS_METHODS.contains(methodName)) { + lock.lock(); + locked = true; + } TelemetryFactory telemetryFactory = pluginManager.getTelemetryFactory(); TelemetryContext context = null; @@ -251,7 +264,9 @@ public static T executeWithPlugins( } } finally { - pluginManager.unlock(); + if (locked) { + lock.unlock(); + } if (context != null) { context.closeContext(); } diff --git a/wrapper/src/test/java/integration/container/tests/MysqlTests.java b/wrapper/src/test/java/integration/container/tests/MysqlTests.java new file mode 100644 index 000000000..7ccb9fe32 --- /dev/null +++ b/wrapper/src/test/java/integration/container/tests/MysqlTests.java @@ -0,0 +1,86 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package integration.container.tests; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import integration.DatabaseEngine; +import integration.container.ConnectionStringHelper; +import integration.container.TestDriver; +import integration.container.TestDriverProvider; +import integration.container.TestEnvironment; +import integration.container.condition.EnableOnDatabaseEngine; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.logging.Logger; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; + +@EnableOnDatabaseEngine({DatabaseEngine.MYSQL}) +@ExtendWith(TestDriverProvider.class) +public class MysqlTests { + + private static final Logger LOGGER = Logger.getLogger(MysqlTests.class.getName()); + + @Test + void testCancelStatement(TestDriver testDriver) { + String url = + ConnectionStringHelper.getWrapperUrl( + testDriver, + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getInstances() + .get(0) + .getHost(), + TestEnvironment.getCurrent() + .getInfo() + .getDatabaseInfo() + .getInstances() + .get(0) + .getPort(), + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getDefaultDbName()); + LOGGER.finest("Connecting to " + url); + try (final Connection conn = DriverManager.getConnection(url)) { + Statement stmt = conn.createStatement(); + Thread thread = new Thread(() -> { + try { + Thread.sleep(1000); + stmt.cancel(); + } catch (SQLException | InterruptedException e) { + fail(e); + } + }); + + final long startTime = System.currentTimeMillis(); + thread.start(); + stmt.execute("select sleep(1000)"); + + try { + thread.join(); + assertTrue(System.currentTimeMillis() - startTime < 10000); + } catch (InterruptedException e) { + fail(e); + } + } catch (Exception e) { + fail(e); + } + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/WrapperUtilsTest.java b/wrapper/src/test/java/software/amazon/jdbc/util/WrapperUtilsTest.java index 1c5ac8503..b90a74351 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/WrapperUtilsTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/WrapperUtilsTest.java @@ -58,15 +58,6 @@ void init() { final ReentrantLock testLock = new ReentrantLock(); closeable = MockitoAnnotations.openMocks(this); - doAnswer(invocation -> { - pluginManagerLock.lock(); - return null; - }).when(pluginManager).lock(); - doAnswer(invocation -> { - pluginManagerLock.unlock(); - return null; - }).when(pluginManager).unlock(); - doAnswer(invocation -> { boolean lockIsFree = testLock.tryLock(); if (!lockIsFree) {