Skip to content

Commit

Permalink
Improve Cursor.copy() (#538)
Browse files Browse the repository at this point in the history
  • Loading branch information
sitingren authored Nov 21, 2023
1 parent 2dca694 commit 26e3d08
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
10 changes: 10 additions & 0 deletions vertica_python/tests/integration_tests/test_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,16 @@ def test_cmd_after_rejected_copy_data(self):

self.assertListOfListsEqual(res, [[1]])

def test_copy_multiple_statements(self):
with self._connect() as conn:
cur = conn.cursor()
cur.copy("COPY {0} (a, b) FROM STDIN DELIMITER ','; SELECT 5".format(self._table),
"1,foo\n2,bar")
self.assertListOfListsEqual(cur.fetchall(), [])
self.assertTrue(cur.nextset())
self.assertListOfListsEqual(cur.fetchall(), [[5]])
self.assertFalse(cur.nextset())

def test_with_conn(self):
with self._connect() as conn:
cur = conn.cursor()
Expand Down
20 changes: 14 additions & 6 deletions vertica_python/vertica/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,10 @@ def nextset(self):
# result of a DDL/transaction
self.rowcount = -1
return True
elif isinstance(self._message, messages.CopyInResponse):
raise errors.MessageError(
'Unexpected nextset() state after END_OF_RESULT_RESPONSES: {self._message}\n'
'HINT: Do you pass multiple COPY statements into Cursor.copy()?')
elif isinstance(self._message, messages.ErrorResponse):
raise errors.QueryError.from_error_response(self._message, self.operation)
else:
Expand Down Expand Up @@ -458,6 +462,7 @@ def copy(self, sql, data, **kwargs):
"""
sql = as_text(sql)
self.operation = sql

if self.closed():
raise errors.InterfaceError('Cursor is closed')
Expand All @@ -473,13 +478,11 @@ def copy(self, sql, data, **kwargs):
else:
raise TypeError("Not valid type of data {0}".format(type(data)))

# TODO: check sql is a valid `COPY FROM STDIN` SQL statement

self._logger.info(u'Execute COPY statement: [{}]'.format(sql))
# Execute a `COPY FROM STDIN` SQL statement
self.connection.write(messages.Query(sql))

buffer_size = kwargs['buffer_size'] if 'buffer_size' in kwargs else DEFAULT_BUFFER_SIZE
self.buffer_size = kwargs.get('buffer_size', DEFAULT_BUFFER_SIZE)

while True:
message = self.connection.read_message()
Expand All @@ -490,10 +493,10 @@ def copy(self, sql, data, **kwargs):
elif isinstance(message, messages.ReadyForQuery):
break
elif isinstance(message, messages.CommandComplete):
pass
break
elif isinstance(message, messages.CopyInResponse):
try:
self._send_copy_data(stream, buffer_size)
self._send_copy_data(stream, self.buffer_size)
except Exception as e:
# COPY termination: report the cause of failure to the backend
self.connection.write(messages.CopyFail(str(e)))
Expand All @@ -503,8 +506,13 @@ def copy(self, sql, data, **kwargs):

# Successful termination for COPY
self.connection.write(messages.CopyDone())
elif isinstance(message, messages.RowDescription):
raise errors.MessageError(f'Unexpected message: {message}\n'
f'HINT: Query for Cursor.copy() should be a `COPY FROM STDIN` SQL statement.'
' `COPY FROM LOCAL` should be executed with Cursor.execute().\n'
f'SQL: {sql}')
else:
raise errors.MessageError('Unexpected message: {0}'.format(message))
raise errors.MessageError(f'Unexpected message: {message}')

if self.error is not None:
raise self.error
Expand Down

0 comments on commit 26e3d08

Please sign in to comment.