Skip to content

Commit

Permalink
Fix ServerSide query regression introduced in 3.16.0
Browse files Browse the repository at this point in the history
Fixes #2899
  • Loading branch information
coleifer committed May 22, 2024
1 parent e77b994 commit c6f4c4d
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
44 changes: 32 additions & 12 deletions playhouse/postgres_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from peewee import Node
from peewee import NodeList
from peewee import __deprecated__
from peewee import __exception_wrapper__

try:
from psycopg2cffi import compat
Expand Down Expand Up @@ -394,6 +395,13 @@ def __init__(self, cursor, array_size=None):
self.exhausted = False
self.iterable = self.row_gen()

def __del__(self):
if self.cursor and not self.cursor.closed:
try:
self.cursor.close()
except Exception:
pass

@property
def description(self):
return self.cursor.description
Expand All @@ -402,12 +410,15 @@ def close(self):
self.cursor.close()

def row_gen(self):
while True:
rows = self.cursor.fetchmany(self.array_size)
if not rows:
return
for row in rows:
yield row
try:
while True:
rows = self.cursor.fetchmany(self.array_size)
if not rows:
return
for row in rows:
yield row
finally:
self.close()

def fetchone(self):
if self.exhausted:
Expand Down Expand Up @@ -443,10 +454,9 @@ def _execute(self, database):
def ServerSide(query, database=None, array_size=None):
if database is None:
database = query._database
with database.transaction():
server_side_query = ServerSideQuery(query, array_size=array_size)
for row in server_side_query:
yield row
server_side_query = ServerSideQuery(query, array_size=array_size)
for row in server_side_query:
yield row


class _empty_object(object):
Expand Down Expand Up @@ -477,7 +487,8 @@ def cursor(self, commit=None, named_cursor=None):
else:
raise InterfaceError('Error, database connection not opened.')
if named_cursor:
curs = self._state.conn.cursor(name=str(uuid.uuid1()))
curs = self._state.conn.cursor(name=str(uuid.uuid1()),
withhold=True)
return curs
return self._state.conn.cursor()

Expand All @@ -489,7 +500,16 @@ def execute(self, query, commit=None, named_cursor=False, array_size=None,
sql, params = ctx.sql(query).query()
named_cursor = named_cursor or (self._server_side_cursors and
sql[:6].lower() == 'select')
cursor = self.execute_sql(sql, params)
cursor = self.execute_sql(sql, params, named_cursor=named_cursor)
if named_cursor:
cursor = FetchManyCursor(cursor, array_size)
return cursor

def execute_sql(self, sql, params=None, commit=None, named_cursor=None):
if commit is not None:
__deprecated__('"commit" has been deprecated and is a no-op.')
logger.debug((sql, params))
with __exception_wrapper__:
cursor = self.cursor(named_cursor=named_cursor)
cursor.execute(sql, params or ())
return cursor
10 changes: 10 additions & 0 deletions tests/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,16 @@ def test_server_side_cursor(self):
ss_query = ServerSide(query.where(SQL('1 = 0')))
self.assertEqual(list(ss_query), [])

def test_lower_level_apis(self):
query = Register.select(Register.value).order_by(Register.value)
ssq = ServerSideQuery(query, array_size=10)
curs_wrapper = ssq._execute(self.database)
curs = curs_wrapper.cursor
self.assertTrue(isinstance(curs, FetchManyCursor))
self.assertEqual(curs.fetchone(), (0,))
self.assertEqual(curs.fetchone(), (1,))
curs.close()


class KX(TestModel):
key = CharField(unique=True)
Expand Down

0 comments on commit c6f4c4d

Please sign in to comment.