diff --git a/src/databricks/sql/auth/retry.py b/src/databricks/sql/auth/retry.py index 109b6949..56cd3412 100755 --- a/src/databricks/sql/auth/retry.py +++ b/src/databricks/sql/auth/retry.py @@ -32,7 +32,6 @@ class CommandType(Enum): CLOSE_SESSION = "CloseSession" CLOSE_OPERATION = "CloseOperation" GET_OPERATION_STATUS = "GetOperationStatus" - FETCH_RESULTS_INLINE_FETCH_NEXT = "FetchResultsInline_FETCH_NEXT" OTHER = "Other" @classmethod @@ -242,6 +241,14 @@ def command_type(self) -> Optional[CommandType]: def command_type(self, value: CommandType) -> None: self._command_type = value + @property + def is_retryable(self) -> bool: + return self._is_retryable + + @is_retryable.setter + def is_retryable(self, value: bool) -> None: + self._is_retryable = value + @property def delay_default(self) -> float: """Time in seconds the connector will wait between requests polling a GetOperationStatus Request @@ -363,11 +370,8 @@ def should_retry(self, method: str, status_code: int) -> Tuple[bool, str]: if status_code == 501: raise NonRecoverableNetworkError("Received code 501 from server.") - if self.command_type == CommandType.FETCH_RESULTS_INLINE_FETCH_NEXT: - return ( - False, - "FetchResults in INLINE mode with FETCH_NEXT orientation are not idempotent and is not retried", - ) + if self.is_retryable == False: + return False, "Request is not retryable" # Request failed and this method is not retryable. We only retry POST requests. if not self._is_method_retryable(method): diff --git a/src/databricks/sql/auth/thrift_http_client.py b/src/databricks/sql/auth/thrift_http_client.py index 6273ab28..6381ac61 100644 --- a/src/databricks/sql/auth/thrift_http_client.py +++ b/src/databricks/sql/auth/thrift_http_client.py @@ -216,3 +216,12 @@ def set_retry_command_type(self, value: CommandType): logger.warning( "DatabricksRetryPolicy is currently bypassed. The CommandType cannot be set." ) + + def set_is_retryable(self, retryable: bool): + """Pass the provided retryable flag to the retry policy""" + if isinstance(self.retry_policy, DatabricksRetryPolicy): + self.retry_policy.is_retryable = retryable + else: + logger.warning( + "DatabricksRetryPolicy is currently bypassed. The is_retryable flag cannot be set." + ) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 8ea81e12..341a6ebe 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -808,6 +808,7 @@ def execute( self.thrift_backend, self.buffer_size_bytes, self.arraysize, + self.connection.use_cloud_fetch, ) if execute_response.is_staging_operation: @@ -1202,6 +1203,7 @@ def __init__( thrift_backend: ThriftBackend, result_buffer_size_bytes: int = DEFAULT_RESULT_BUFFER_SIZE_BYTES, arraysize: int = 10000, + use_cloud_fetch: bool = True, ): """ A ResultSet manages the results of a single command. @@ -1223,6 +1225,7 @@ def __init__( self.description = execute_response.description self._arrow_schema_bytes = execute_response.arrow_schema_bytes self._next_row_index = 0 + self.use_cloud_fetch = use_cloud_fetch if execute_response.arrow_queue: # In this case the server has taken the fast path and returned an initial batch of @@ -1250,6 +1253,7 @@ def _fill_results_buffer(self): lz4_compressed=self.lz4_compressed, arrow_schema_bytes=self._arrow_schema_bytes, description=self.description, + use_cloud_fetch=self.use_cloud_fetch, ) self.results = results self.has_more_rows = has_more_rows diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 70b29d32..ab4da315 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -321,7 +321,7 @@ def _handle_request_error(self, error_info, attempt, elapsed): # FUTURE: Consider moving to https://github.com/litl/backoff or # https://github.com/jd/tenacity for retry logic. - def make_request(self, method, request): + def make_request(self, method, request, retryable=True): """Execute given request, attempting retries when 1. Receiving HTTP 429/503 from server 2. OSError is raised during a GetOperationStatus @@ -374,20 +374,9 @@ def attempt_request(attempt): # These three lines are no-ops if the v3 retry policy is not in use if self.enable_v3_retries: - # Not to retry when FetchResults in INLINE mode when it has orientation as FETCH_NEXT as it is not idempotent - if ( - this_method_name == "FetchResults" - and self._use_cloud_fetch == False - ): - this_method_name += ( - "Inline_" - + ttypes.TFetchOrientation._VALUES_TO_NAMES[ - request.orientation - ] - ) - this_command_type = CommandType.get(this_method_name) self._transport.set_retry_command_type(this_command_type) + self._transport.set_is_retryable(retryable) self._transport.startRetryTimer() response = method(request) @@ -898,8 +887,6 @@ def execute_command( ): assert session_handle is not None - self._use_cloud_fetch = use_cloud_fetch - spark_arrow_types = ttypes.TSparkArrowTypes( timestampAsArrow=self._use_arrow_native_timestamps, decimalAsArrow=self._use_arrow_native_decimals, @@ -1042,6 +1029,7 @@ def fetch_results( lz4_compressed, arrow_schema_bytes, description, + use_cloud_fetch=True, ): assert op_handle is not None @@ -1058,7 +1046,8 @@ def fetch_results( includeResultSetMetadata=True, ) - resp = self.make_request(self._client.FetchResults, req) + # Fetch results in Inline mode with FETCH_NEXT orientation are not idempotent and hence not retried + resp = self.make_request(self._client.FetchResults, req, use_cloud_fetch) if resp.results.startRowOffset > expected_row_start_offset: raise DataError( "fetch_results failed due to inconsistency in the state between the client and the server. Expected results to start from {} but they instead start at {}, some result batches must have been skipped".format( diff --git a/tests/unit/test_retry.py b/tests/unit/test_retry.py index 9e52df72..be8e1ecc 100644 --- a/tests/unit/test_retry.py +++ b/tests/unit/test_retry.py @@ -87,6 +87,6 @@ def test_sleep__retry_after_present(self, t_mock, retry_policy, error_history): def test_not_retryable__fetch_results_orientation_fetch_next(self, retry_policy): HTTP_STATUS_CODES = [200, 429, 503, 504] - retry_policy.command_type = CommandType.FETCH_RESULTS_INLINE_FETCH_NEXT + retry_policy.is_retryable = False for status_code in HTTP_STATUS_CODES: assert not retry_policy.is_retry("METHOD_NAME", status_code=status_code)