Skip to content

Commit

Permalink
Merge pull request #23 from wherobots/max/query-cancel
Browse files Browse the repository at this point in the history
Fix support for query cancellation
  • Loading branch information
mpetazzoni authored Nov 8, 2024
2 parents b4dc1b7 + 7d59bfe commit 06c0e87
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 58 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# Vim
*.swp
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "wherobots-python-dbapi"
version = "0.9.0"
version = "0.9.1"
description = "Python DB-API driver for Wherobots DB"
authors = ["Maxime Petazzoni <[email protected]>"]
license = "Apache 2.0"
Expand Down Expand Up @@ -29,6 +29,7 @@ pytest = "^8.0.2"
black = "^24.2.0"
pre-commit = "^3.6.2"
conventional-pre-commit = "^3.1.0"
types-requests = "^2.32.0.20241016"
rich = "^13.7.1"

[build-system]
Expand Down
15 changes: 11 additions & 4 deletions tests/smoke.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# A simple smoke test for the DB driver.

import argparse
import concurrent.futures
import functools
import logging
import sys
Expand All @@ -11,6 +12,7 @@

from wherobots.db import connect, connect_direct
from wherobots.db.constants import DEFAULT_ENDPOINT
from wherobots.db.connection import Connection
from wherobots.db.region import Region
from wherobots.db.runtime import Runtime

Expand Down Expand Up @@ -85,8 +87,13 @@ def render(results: pandas.DataFrame):
table.add_row(*r)
Console().print(table)

def execute(conn: Connection, sql: str) -> pandas.DataFrame:
with conn.cursor() as cursor:
cursor.execute(sql)
return cursor.fetchall()

with conn_func() as conn:
for sql in args.sql:
with conn.cursor() as cursor:
cursor.execute(sql)
render(cursor.fetchall())
with concurrent.futures.ThreadPoolExecutor() as pool:
futures = [pool.submit(execute, conn, s) for s in args.sql]
for future in concurrent.futures.as_completed(futures):
render(future.result())
88 changes: 49 additions & 39 deletions wherobots/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Callable, Union

import cbor2
import pandas
import pyarrow
import websockets.exceptions
import websockets.protocol
Expand Down Expand Up @@ -74,19 +75,19 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def close(self):
def close(self) -> None:
self.__ws.close()

def commit(self):
def commit(self) -> None:
raise NotSupportedError

def rollback(self):
def rollback(self) -> None:
raise NotSupportedError

def cursor(self) -> Cursor:
return Cursor(self.__execute_sql, self.__cancel_query)

def __main_loop(self):
def __main_loop(self) -> None:
"""Main background loop listening for messages from the SQL session."""
logging.info("Starting background connection handling loop...")
while self.__ws.protocol.state < websockets.protocol.State.CLOSING:
Expand All @@ -101,7 +102,7 @@ def __main_loop(self):
except Exception as e:
logging.exception("Error handling message from SQL session", exc_info=e)

def __listen(self):
def __listen(self) -> None:
"""Waits for the next message from the SQL session and processes it.
The code in this method is purposefully defensive to avoid unexpected situations killing the thread.
Expand All @@ -120,61 +121,70 @@ def __listen(self):
)
return

if kind == EventKind.STATE_UPDATED:
# Incoming state transitions are handled here.
if kind == EventKind.STATE_UPDATED or kind == EventKind.EXECUTION_RESULT:
try:
query.state = ExecutionState[message["state"].upper()]
logging.info("Query %s is now %s.", execution_id, query.state)
except KeyError:
logging.warning("Invalid state update message for %s", execution_id)
return

# Incoming state transitions are handled here.
if query.state == ExecutionState.SUCCEEDED:
self.__request_results(execution_id)
# On a state_updated event telling us the query succeeded,
# ask for results.
if kind == EventKind.STATE_UPDATED:
self.__request_results(execution_id)
return

# Otherwise, process the results from the execution_result event.
results = message.get("results")
if not results or not isinstance(results, dict):
logging.warning("Got no results back from %s.", execution_id)
return

query.state = ExecutionState.COMPLETED
query.handler(self._handle_results(execution_id, results))
elif query.state == ExecutionState.CANCELLED:
logging.info("Query %s has been cancelled.", execution_id)
logging.info(
"Query %s has been cancelled; returning empty results.",
execution_id,
)
query.handler(pandas.DataFrame())
self.__queries.pop(execution_id)
elif query.state == ExecutionState.FAILED:
# Don't do anything here; the ERROR event is coming with more
# details.
pass

elif kind == EventKind.EXECUTION_RESULT:
results = message.get("results")
if not results or not isinstance(results, dict):
logging.warning("Got no results back from %s.", execution_id)
return

result_bytes = results.get("result_bytes")
result_format = results.get("format")
result_compression = results.get("compression")
logging.info(
"Received %d bytes of %s-compressed %s results from %s.",
len(result_bytes),
result_compression,
result_format,
execution_id,
)

query.state = ExecutionState.COMPLETED
if result_format == ResultsFormat.JSON:
query.handler(json.loads(result_bytes.decode("utf-8")))
elif result_format == ResultsFormat.ARROW:
buffer = pyarrow.py_buffer(result_bytes)
stream = pyarrow.input_stream(buffer, result_compression)
with pyarrow.ipc.open_stream(stream) as reader:
query.handler(reader.read_pandas())
else:
query.handler(
OperationalError(f"Unsupported results format {result_format}")
)
elif kind == EventKind.ERROR:
query.state = ExecutionState.FAILED
error = message.get("message")
query.handler(OperationalError(error))
else:
logging.warning("Received unknown %s event!", kind)

def _handle_results(self, execution_id: str, results: dict[str, Any]) -> Any:
result_bytes = results.get("result_bytes")
result_format = results.get("format")
result_compression = results.get("compression")
logging.info(
"Received %d bytes of %s-compressed %s results from %s.",
len(result_bytes),
result_compression,
result_format,
execution_id,
)

if result_format == ResultsFormat.JSON:
return json.loads(result_bytes.decode("utf-8"))
elif result_format == ResultsFormat.ARROW:
buffer = pyarrow.py_buffer(result_bytes)
stream = pyarrow.input_stream(buffer, result_compression)
with pyarrow.ipc.open_stream(stream) as reader:
return reader.read_pandas()
else:
return OperationalError(f"Unsupported results format {result_format}")

def __send(self, message: dict[str, Any]) -> None:
request = json.dumps(message)
logging.debug("Request: %s", request)
Expand Down
6 changes: 3 additions & 3 deletions wherobots/db/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ExecutionState(LowercaseStrEnum):
COMPLETED = auto()
"The driver has completed processing the query results."

def is_terminal_state(self):
def is_terminal_state(self) -> bool:
return self in (
ExecutionState.COMPLETED,
ExecutionState.CANCELLED,
Expand Down Expand Up @@ -97,7 +97,7 @@ class AppStatus(StrEnum):
DESTROY_FAILED = auto()
DESTROYED = auto()

def is_starting(self):
def is_starting(self) -> bool:
return self in (
AppStatus.PENDING,
AppStatus.PREPARING,
Expand All @@ -107,7 +107,7 @@ def is_starting(self):
AppStatus.INITIALIZING,
)

def is_terminal_state(self):
def is_terminal_state(self) -> bool:
return self in (
AppStatus.PREPARE_FAILED,
AppStatus.DEPLOY_FAILED,
Expand Down
22 changes: 12 additions & 10 deletions wherobots/db/cursor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import queue
from typing import Any, Optional, List, Tuple

from .errors import ProgrammingError, DatabaseError
from .errors import DatabaseError, ProgrammingError

_TYPE_MAP = {
"object": "STRING",
Expand All @@ -16,7 +16,7 @@

class Cursor:

def __init__(self, exec_fn, cancel_fn):
def __init__(self, exec_fn, cancel_fn) -> None:
self.__exec_fn = exec_fn
self.__cancel_fn = cancel_fn

Expand Down Expand Up @@ -72,7 +72,7 @@ def __get_results(self) -> Optional[List[Tuple[Any, ...]]]:

return self.__results

def execute(self, operation: str, parameters: dict[str, Any] = None):
def execute(self, operation: str, parameters: dict[str, Any] = None) -> None:
if self.__current_execution_id:
self.__cancel_fn(self.__current_execution_id)

Expand All @@ -84,38 +84,40 @@ def execute(self, operation: str, parameters: dict[str, Any] = None):
sql = operation.format(**(parameters or {}))
self.__current_execution_id = self.__exec_fn(sql, self.__on_execution_result)

def executemany(self, operation: str, seq_of_parameters: list[dict[str, Any]]):
def executemany(
self, operation: str, seq_of_parameters: list[dict[str, Any]]
) -> None:
raise NotImplementedError

def fetchone(self):
def fetchone(self) -> Any:
results = self.__get_results()[self.__current_row :]
if len(results) == 0:
return None
self.__current_row += 1
return results[0]

def fetchmany(self, size: int = None):
def fetchmany(self, size: int = None) -> list[Any]:
size = size or self.arraysize
results = self.__get_results()[self.__current_row : self.__current_row + size]
self.__current_row += size
return results

def fetchall(self):
def fetchall(self) -> list[Any]:
return self.__get_results()[self.__current_row :]

def close(self):
def close(self) -> None:
"""Close the cursor."""
if self.__results is None and self.__current_execution_id:
self.__cancel_fn(self.__current_execution_id)

def __iter__(self):
return self

def __next__(self):
def __next__(self) -> None:
raise StopIteration

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
self.close()

0 comments on commit 06c0e87

Please sign in to comment.