Skip to content

Commit

Permalink
implement CalcCursor to support Spark calculations (fix #493)
Browse files (browse the repository at this point in the history)
  • Loading branch information
laughingman7743 committed Jan 6, 2024
1 parent 39ad5a1 commit 8fe1cde
Show file tree
Hide file tree
Showing 9 changed files with 479 additions and 261 deletions.
46 changes: 7 additions & 39 deletions pyathena/arrow/async_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from concurrent.futures import Future
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union, cast

from pyathena import ProgrammingError
from pyathena.arrow.converter import (
Expand All @@ -14,55 +14,23 @@
from pyathena.arrow.result_set import AthenaArrowResultSet
from pyathena.async_cursor import AsyncCursor
from pyathena.common import CursorIterator
from pyathena.converter import Converter
from pyathena.formatter import Formatter
from pyathena.model import AthenaCompression, AthenaFileFormat
from pyathena.util import RetryConfig

if TYPE_CHECKING:
from pyathena.connection import Connection
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution

_logger = logging.getLogger(__name__) # type: ignore


class AsyncArrowCursor(AsyncCursor):
def __init__(
self,
connection: "Connection",
converter: Converter,
formatter: Formatter,
retry_config: RetryConfig,
s3_staging_dir: Optional[str] = None,
schema_name: Optional[str] = None,
catalog_name: Optional[str] = None,
work_group: Optional[str] = None,
poll_interval: float = 1,
encryption_option: Optional[str] = None,
kms_key: Optional[str] = None,
kill_on_interrupt: bool = True,
max_workers: int = (cpu_count() or 1) * 5,
arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE,
unload: bool = False,
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
**kwargs,
) -> None:
super(AsyncArrowCursor, self).__init__(
connection=connection,
converter=converter,
formatter=formatter,
retry_config=retry_config,
s3_staging_dir=s3_staging_dir,
schema_name=schema_name,
catalog_name=catalog_name,
work_group=work_group,
poll_interval=poll_interval,
encryption_option=encryption_option,
kms_key=kms_key,
kill_on_interrupt=kill_on_interrupt,
max_workers=max_workers,
arraysize=arraysize,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
**kwargs,
)
self._unload = unload

Expand Down Expand Up @@ -93,7 +61,7 @@ def _collect_result_set(
) -> AthenaArrowResultSet:
if kwargs is None:
kwargs = dict()
query_execution = self._poll(query_id)
query_execution = cast(AthenaQueryExecution, self._poll(query_id))
return AthenaArrowResultSet(
connection=self._connection,
converter=self._converter,
Expand All @@ -111,8 +79,8 @@ def execute(
parameters: Optional[Dict[str, Any]] = None,
work_group: Optional[str] = None,
s3_staging_dir: Optional[str] = None,
cache_size: int = 0,
cache_expiration_time: int = 0,
cache_size: Optional[int] = 0,
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
**kwargs,
Expand Down
45 changes: 5 additions & 40 deletions pyathena/arrow/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,58 +10,23 @@
)
from pyathena.arrow.result_set import AthenaArrowResultSet
from pyathena.common import BaseCursor, CursorIterator
from pyathena.converter import Converter
from pyathena.error import OperationalError, ProgrammingError
from pyathena.formatter import Formatter
from pyathena.model import AthenaCompression, AthenaFileFormat, AthenaQueryExecution
from pyathena.result_set import WithResultSet
from pyathena.util import RetryConfig

if TYPE_CHECKING:
from pyarrow import Table

from pyathena.connection import Connection

_logger = logging.getLogger(__name__) # type: ignore


class ArrowCursor(BaseCursor, CursorIterator, WithResultSet):
def __init__(
self,
connection: "Connection",
converter: Converter,
formatter: Formatter,
retry_config: RetryConfig,
s3_staging_dir: Optional[str] = None,
schema_name: Optional[str] = None,
catalog_name: Optional[str] = None,
work_group: Optional[str] = None,
poll_interval: float = 1,
encryption_option: Optional[str] = None,
kms_key: Optional[str] = None,
kill_on_interrupt: bool = True,
unload: bool = False,
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
**kwargs,
) -> None:
super(ArrowCursor, self).__init__(
connection=connection,
converter=converter,
formatter=formatter,
retry_config=retry_config,
s3_staging_dir=s3_staging_dir,
schema_name=schema_name,
catalog_name=catalog_name,
work_group=work_group,
poll_interval=poll_interval,
encryption_option=encryption_option,
kms_key=kms_key,
kill_on_interrupt=kill_on_interrupt,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
**kwargs,
)
super(ArrowCursor, self).__init__(**kwargs)
self._unload = unload
self._query_id: Optional[str] = None
self._result_set: Optional[AthenaArrowResultSet] = None
Expand Down Expand Up @@ -115,8 +80,8 @@ def execute(
parameters: Optional[Dict[str, Any]] = None,
work_group: Optional[str] = None,
s3_staging_dir: Optional[str] = None,
cache_size: int = 0,
cache_expiration_time: int = 0,
cache_size: Optional[int] = 0,
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
**kwargs,
Expand All @@ -143,7 +108,7 @@ def execute(
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
)
query_execution = self._poll(self.query_id)
query_execution = cast(AthenaQueryExecution, self._poll(self.query_id))
if query_execution.state == AthenaQueryExecution.STATE_SUCCEEDED:
self.result_set = AthenaArrowResultSet(
connection=self._connection,
Expand All @@ -160,7 +125,7 @@ def execute(
return self

def executemany(
self, operation: str, seq_of_parameters: List[Optional[Dict[str, Any]]]
self, operation: str, seq_of_parameters: List[Optional[Dict[str, Any]]], **kwargs
) -> None:
for parameters in seq_of_parameters:
self.execute(operation, parameters)
Expand Down
51 changes: 9 additions & 42 deletions pyathena/async_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,59 +5,25 @@
from concurrent.futures import Future
from concurrent.futures.thread import ThreadPoolExecutor
from multiprocessing import cpu_count
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast

from pyathena.common import CursorIterator
from pyathena.converter import Converter
from pyathena.cursor import BaseCursor
from pyathena.error import NotSupportedError, ProgrammingError
from pyathena.formatter import Formatter
from pyathena.model import AthenaQueryExecution
from pyathena.result_set import AthenaDictResultSet, AthenaResultSet
from pyathena.util import RetryConfig

if TYPE_CHECKING:
from pyathena.connection import Connection

_logger = logging.getLogger(__name__) # type: ignore


class AsyncCursor(BaseCursor):
def __init__(
self,
connection: "Connection",
converter: Converter,
formatter: Formatter,
retry_config: RetryConfig,
s3_staging_dir: Optional[str] = None,
schema_name: Optional[str] = None,
catalog_name: Optional[str] = None,
work_group: Optional[str] = None,
poll_interval: float = 1,
encryption_option: Optional[str] = None,
kms_key: Optional[str] = None,
kill_on_interrupt: bool = True,
max_workers: int = (cpu_count() or 1) * 5,
arraysize: int = CursorIterator.DEFAULT_FETCH_SIZE,
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
**kwargs,
) -> None:
super(AsyncCursor, self).__init__(
connection=connection,
converter=converter,
formatter=formatter,
retry_config=retry_config,
s3_staging_dir=s3_staging_dir,
schema_name=schema_name,
catalog_name=catalog_name,
work_group=work_group,
poll_interval=poll_interval,
encryption_option=encryption_option,
kms_key=kms_key,
kill_on_interrupt=kill_on_interrupt,
result_reuse_enable=result_reuse_enable,
result_reuse_minutes=result_reuse_minutes,
)
super(AsyncCursor, self).__init__(**kwargs)
self._max_workers = max_workers
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._arraysize = arraysize
Expand Down Expand Up @@ -94,10 +60,10 @@ def query_execution(self, query_id: str) -> "Future[AthenaQueryExecution]":
return self._executor.submit(self._get_query_execution, query_id)

def poll(self, query_id: str) -> "Future[AthenaQueryExecution]":
return self._executor.submit(self._poll, query_id)
return cast(Future[AthenaQueryExecution], self._executor.submit(self._poll, query_id))

def _collect_result_set(self, query_id: str) -> AthenaResultSet:
query_execution = self._poll(query_id)
query_execution = cast(AthenaQueryExecution, self._poll(query_id))
return self._result_set_class(
connection=self._connection,
converter=self._converter,
Expand All @@ -112,10 +78,11 @@ def execute(
parameters: Optional[Dict[str, Any]] = None,
work_group: Optional[str] = None,
s3_staging_dir: Optional[str] = None,
cache_size: int = 0,
cache_expiration_time: int = 0,
cache_size: Optional[int] = 0,
cache_expiration_time: Optional[int] = 0,
result_reuse_enable: Optional[bool] = None,
result_reuse_minutes: Optional[int] = None,
**kwargs,
) -> Tuple[str, "Future[Union[AthenaResultSet, Any]]"]:
query_id = self._execute(
operation,
Expand All @@ -130,7 +97,7 @@ def execute(
return query_id, self._executor.submit(self._collect_result_set, query_id)

def executemany(
self, operation: str, seq_of_parameters: List[Optional[Dict[str, Any]]]
self, operation: str, seq_of_parameters: List[Optional[Dict[str, Any]]], **kwargs
) -> None:
raise NotSupportedError

Expand Down
Loading

0 comments on commit 8fe1cde

Please sign in to comment.