Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support connection parameters, execute non-prepared statements #26

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions jaydebeapi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
# it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
#
# JayDeBeApi is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
#
# You should have received a copy of the GNU Lesser General Public
# License along with JayDeBeApi. If not, see
# <http://www.gnu.org/licenses/>.
Expand Down Expand Up @@ -83,7 +83,7 @@ def _handle_sql_exception_jython():
exc_type = InterfaceError
reraise(exc_type, exc_info[1], exc_info[2])

def _jdbc_connect_jython(jclassname, jars, libs, *args):
def _jdbc_connect_jython(jclassname, jars, libs, props, *args):
if _jdbc_name_to_const is None:
from java.sql import Types
types = Types
Expand Down Expand Up @@ -146,8 +146,8 @@ def _handle_sql_exception_jpype():
else:
exc_type = InterfaceError
reraise(exc_type, exc_info[1], exc_info[2])
def _jdbc_connect_jpype(jclassname, jars, libs, *driver_args):

def _jdbc_connect_jpype(jclassname, jars, libs, props, *driver_args):
import jpype
if not jpype.isJVMStarted():
args = []
Expand Down Expand Up @@ -180,6 +180,13 @@ def _java_array_byte(data):
return jpype.JArray(jpype.JByte, 1)(data)
# register driver for DriverManager
jpype.JClass(jclassname)

if props is not None:
jprops = jpype.java.util.Properties()
for k, v in props.iteritems():
jprops.put(k, v)
return jpype.java.sql.DriverManager.getConnection(driver_args[0], jprops)

return jpype.java.sql.DriverManager.getConnection(*driver_args)

def _get_classpath():
Expand Down Expand Up @@ -330,7 +337,7 @@ def TimestampFromTicks(ticks):
return apply(Timestamp, time.localtime(ticks)[:6])

# DB-API 2.0 Module Interface connect constructor
def connect(jclassname, driver_args, jars=None, libs=None):
def connect(jclassname, driver_args, jars=None, libs=None, props=None):
"""Open a connection to a database using a JDBC driver and return
a Connection instance.

Expand All @@ -356,7 +363,7 @@ def connect(jclassname, driver_args, jars=None, libs=None):
libs = [ libs ]
else:
libs = []
jconn = _jdbc_connect(jclassname, jars, libs, *driver_args)
jconn = _jdbc_connect(jclassname, jars, libs, props, *driver_args)
return Connection(jconn, _converters)

# DB-API 2.0 Connection Object
Expand Down Expand Up @@ -470,15 +477,25 @@ def _set_stmt_parms(self, prep_stmt, parameters):
def execute(self, operation, parameters=None):
if self._connection._closed:
raise Error()
if not parameters:
parameters = ()

self._close_last()
self._prep = self._connection.jconn.prepareStatement(operation)
self._set_stmt_parms(self._prep, parameters)
try:
is_rs = self._prep.execute()
except:
_handle_sql_exception()

if parameters == None:
self._prep = self._connection.jconn.createStatement()

try:
is_rs = self._prep.execute(operation)
except:
_handle_sql_exception()
else:
self._prep = self._connection.jconn.prepareStatement(operation)
self._set_stmt_parms(self._prep, parameters)

try:
is_rs = self._prep.execute()
except:
_handle_sql_exception()

if is_rs:
self._rs = self._prep.getResultSet()
self._meta = self._rs.getMetaData()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.Statement;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
Expand All @@ -25,23 +26,30 @@ public final void mockExceptionOnRollback(String className, String exceptionMess

public final void mockExceptionOnExecute(String className, String exceptionMessage) throws SQLException {
PreparedStatement mockPreparedStatement = Mockito.mock(PreparedStatement.class);
Statement mockStatement = Mockito.mock(Statement.class);
Throwable exception = createException(className, exceptionMessage);
Mockito.when(mockPreparedStatement.execute()).thenThrow(exception);
Mockito.when(mockStatement.execute(Mockito.anyString())).thenThrow(exception);
Mockito.when(this.prepareStatement(Mockito.anyString())).thenReturn(mockPreparedStatement);
Mockito.when(this.createStatement()).thenReturn(mockStatement);
}

public final void mockType(String sqlTypesName) throws SQLException {
PreparedStatement mockPreparedStatement = Mockito.mock(PreparedStatement.class);
Statement mockStatement = Mockito.mock(Statement.class);
Mockito.when(mockPreparedStatement.execute()).thenReturn(true);
Mockito.when(mockStatement.execute(Mockito.anyString())).thenReturn(true);
mockResultSet = Mockito.mock(ResultSet.class, "ResultSet(for type " + sqlTypesName + ")");
Mockito.when(mockPreparedStatement.getResultSet()).thenReturn(mockResultSet);
Mockito.when(mockStatement.getResultSet()).thenReturn(mockResultSet);
Mockito.when(mockResultSet.next()).thenReturn(true);
ResultSetMetaData mockMetaData = Mockito.mock(ResultSetMetaData.class);
Mockito.when(mockResultSet.getMetaData()).thenReturn(mockMetaData);
Mockito.when(mockMetaData.getColumnCount()).thenReturn(1);
int sqlTypeCode = extractTypeCodeForName(sqlTypesName);
Mockito.when(mockMetaData.getColumnType(1)).thenReturn(sqlTypeCode);
Mockito.when(this.prepareStatement(Mockito.anyString())).thenReturn(mockPreparedStatement);
Mockito.when(this.createStatement()).thenReturn(mockStatement);
}

public final ResultSet verifyResultSet() {
Expand Down
10 changes: 10 additions & 0 deletions test/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# <http://www.gnu.org/licenses/>.

import jaydebeapi
from jaydebeapi import OperationalError

import os
import sys
Expand Down Expand Up @@ -207,6 +208,15 @@ def test_execute_different_rowcounts(self):
cursor.execute("select * from ACCOUNT")
self.assertEqual(cursor.rowcount, -1)

def test_sql_exception_on_execute(self):
cursor = self.conn.cursor()
try:
cursor.execute("dummy stmt")
except jaydebeapi.DatabaseError as e:
self.assertEquals(str(e).split(" ")[0], "java.sql.SQLException:")
except self.conn.OperationalError as e:
self.assertEquals("syntax" in str(e), True)

class SqliteTestBase(IntegrationTestBase):

def setUpSql(self):
Expand Down
9 changes: 9 additions & 0 deletions test/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ def test_sql_exception_on_execute(self):
except jaydebeapi.DatabaseError as e:
self.assertEquals(str(e), "java.sql.SQLException: expected")

def test_sql_exception_on_parameter_execute(self):
self.conn.jconn.mockExceptionOnExecute("java.sql.SQLException", "expected")
cursor = self.conn.cursor()
try:
cursor.execute("dummy stmt", (18,))
fail("expected exception")
except jaydebeapi.DatabaseError as e:
self.assertEquals(str(e), "java.sql.SQLException: expected")

def test_runtime_exception_on_execute(self):
self.conn.jconn.mockExceptionOnExecute("java.lang.RuntimeException", "expected")
cursor = self.conn.cursor()
Expand Down