Skip to content

Commit

Permalink
Merge branch 'main' into trace-update-cases-from-review
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em authored Jan 9, 2025
2 parents 423e5bc + 0887eb4 commit d32aab2
Show file tree
Hide file tree
Showing 12 changed files with 247 additions and 72 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,7 @@ system_tests/local_test_setup
# Make sure a generated file isn't accidentally committed.
pylintrc
pylintrc.test


# Ignore coverage files
.coverage*
2 changes: 1 addition & 1 deletion google/cloud/spanner_dbapi/transaction_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode
from google.cloud.spanner_dbapi.exceptions import RetryAborted
from google.cloud.spanner_v1.session import _get_retry_delay
from google.cloud.spanner_v1._helpers import _get_retry_delay

if TYPE_CHECKING:
from google.cloud.spanner_dbapi import Connection, Cursor
Expand Down
75 changes: 75 additions & 0 deletions google/cloud/spanner_v1/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@
from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper

from google.api_core import datetime_helpers
from google.api_core.exceptions import Aborted
from google.cloud._helpers import _date_from_iso8601_date
from google.cloud.spanner_v1 import TypeCode
from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import JsonObject
from google.cloud.spanner_v1.request_id_header import with_request_id
from google.rpc.error_details_pb2 import RetryInfo

import random

# Validation error messages
NUMERIC_MAX_SCALE_ERR_MSG = (
Expand Down Expand Up @@ -460,6 +464,23 @@ def _metadata_with_prefix(prefix, **kw):
return [("google-cloud-resource-prefix", prefix)]


def _retry_on_aborted_exception(
func,
deadline,
):
"""
Handles retry logic for Aborted exceptions, considering the deadline.
"""
attempts = 0
while True:
try:
attempts += 1
return func()
except Aborted as exc:
_delay_until_retry(exc, deadline=deadline, attempts=attempts)
continue


def _retry(
func,
retry_count=5,
Expand Down Expand Up @@ -529,6 +550,60 @@ def _metadata_with_leader_aware_routing(value, **kw):
return ("x-goog-spanner-route-to-leader", str(value).lower())


def _delay_until_retry(exc, deadline, attempts):
"""Helper for :meth:`Session.run_in_transaction`.
Detect retryable abort, and impose server-supplied delay.
:type exc: :class:`google.api_core.exceptions.Aborted`
:param exc: exception for aborted transaction
:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.
:type attempts: int
:param attempts: number of call retries
"""

cause = exc.errors[0]
now = time.time()
if now >= deadline:
raise

delay = _get_retry_delay(cause, attempts)
if delay is not None:
if now + delay > deadline:
raise

time.sleep(delay)


def _get_retry_delay(cause, attempts):
"""Helper for :func:`_delay_until_retry`.
:type exc: :class:`grpc.Call`
:param exc: exception for aborted transaction
:rtype: float
:returns: seconds to wait before retrying the transaction.
:type attempts: int
:param attempts: number of call retries
"""
if hasattr(cause, "trailing_metadata"):
metadata = dict(cause.trailing_metadata())
else:
metadata = {}
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
if retry_info_pb is not None:
retry_info = RetryInfo()
retry_info.ParseFromString(retry_info_pb)
nanos = retry_info.retry_delay.nanos
return retry_info.retry_delay.seconds + nanos / 1.0e9

return 2**attempts + random.random()


class AtomicCounter:
def __init__(self, start_value=0):
self.__lock = threading.Lock()
Expand Down
16 changes: 13 additions & 3 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,12 @@
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
from google.cloud.spanner_v1._helpers import _retry
from google.cloud.spanner_v1._helpers import _retry_on_aborted_exception
from google.cloud.spanner_v1._helpers import _check_rst_stream_error
from google.api_core.exceptions import InternalServerError
import time

DEFAULT_RETRY_TIMEOUT_SECS = 30


class _BatchBase(_SessionWrapper):
Expand Down Expand Up @@ -162,6 +166,7 @@ def commit(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kwargs,
):
"""Commit mutations to the database.
Expand Down Expand Up @@ -227,9 +232,12 @@ def commit(
request=request,
metadata=metadata,
)
response = _retry(
deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
)
response = _retry_on_aborted_exception(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
deadline=deadline,
)
self.committed = response.commit_timestamp
self.commit_stats = response.commit_stats
Expand Down Expand Up @@ -348,7 +356,9 @@ def batch_write(self, request_options=None, exclude_txn_from_change_streams=Fals
)
response = _retry(
method,
allowed_exceptions={InternalServerError: _check_rst_stream_error},
allowed_exceptions={
InternalServerError: _check_rst_stream_error,
},
)
self.committed = True
return response
Expand Down
10 changes: 9 additions & 1 deletion google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,7 @@ def batch(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kw,
):
"""Return an object which wraps a batch.
Expand Down Expand Up @@ -805,7 +806,11 @@ def batch(
:returns: new wrapper
"""
return BatchCheckout(
self, request_options, max_commit_delay, exclude_txn_from_change_streams
self,
request_options,
max_commit_delay,
exclude_txn_from_change_streams,
**kw,
)

def mutation_groups(self):
Expand Down Expand Up @@ -1166,6 +1171,7 @@ def __init__(
request_options=None,
max_commit_delay=None,
exclude_txn_from_change_streams=False,
**kw,
):
self._database = database
self._session = self._batch = None
Expand All @@ -1177,6 +1183,7 @@ def __init__(
self._request_options = request_options
self._max_commit_delay = max_commit_delay
self._exclude_txn_from_change_streams = exclude_txn_from_change_streams
self._kw = kw

def __enter__(self):
"""Begin ``with`` block."""
Expand All @@ -1197,6 +1204,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
request_options=self._request_options,
max_commit_delay=self._max_commit_delay,
exclude_txn_from_change_streams=self._exclude_txn_from_change_streams,
**self._kw,
)
finally:
if self._database.log_commit_stats and self._batch.commit_stats:
Expand Down
58 changes: 2 additions & 56 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
"""Wrapper for Cloud Spanner Session objects."""

from functools import total_ordering
import random
import time
from datetime import datetime

from google.api_core.exceptions import Aborted
from google.api_core.exceptions import GoogleAPICallError
from google.api_core.exceptions import NotFound
from google.api_core.gapic_v1 import method
from google.rpc.error_details_pb2 import RetryInfo
from google.cloud.spanner_v1._helpers import _delay_until_retry
from google.cloud.spanner_v1._helpers import _get_retry_delay

from google.cloud.spanner_v1 import ExecuteSqlRequest
from google.cloud.spanner_v1 import CreateSessionRequest
Expand Down Expand Up @@ -554,57 +554,3 @@ def run_in_transaction(self, func, *args, **kw):
extra={"commit_stats": txn.commit_stats},
)
return return_value


# Rational: this function factors out complex shared deadline / retry
# handling from two `except:` clauses.
def _delay_until_retry(exc, deadline, attempts):
"""Helper for :meth:`Session.run_in_transaction`.
Detect retryable abort, and impose server-supplied delay.
:type exc: :class:`google.api_core.exceptions.Aborted`
:param exc: exception for aborted transaction
:type deadline: float
:param deadline: maximum timestamp to continue retrying the transaction.
:type attempts: int
:param attempts: number of call retries
"""
cause = exc.errors[0]

now = time.time()

if now >= deadline:
raise

delay = _get_retry_delay(cause, attempts)
if delay is not None:
if now + delay > deadline:
raise

time.sleep(delay)


def _get_retry_delay(cause, attempts):
"""Helper for :func:`_delay_until_retry`.
:type exc: :class:`grpc.Call`
:param exc: exception for aborted transaction
:rtype: float
:returns: seconds to wait before retrying the transaction.
:type attempts: int
:param attempts: number of call retries
"""
metadata = dict(cause.trailing_metadata())
retry_info_pb = metadata.get("google.rpc.retryinfo-bin")
if retry_info_pb is not None:
retry_info = RetryInfo()
retry_info.ParseFromString(retry_info_pb)
nanos = retry_info.retry_delay.nanos
return retry_info.retry_delay.seconds + nanos / 1.0e9

return 2**attempts + random.random()
17 changes: 13 additions & 4 deletions google/cloud/spanner_v1/testing/mock_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,19 @@ def __create_transaction(
def Commit(self, request, context):
self._requests.append(request)
self.mock_spanner.pop_error(context)
tx = self.transactions[request.transaction_id]
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
del self.transactions[request.transaction_id]
if not request.transaction_id == b"":
tx = self.transactions[request.transaction_id]
if tx is None:
raise ValueError(f"Transaction not found: {request.transaction_id}")
tx_id = request.transaction_id
elif not request.single_use_transaction == TransactionOptions():
tx = self.__create_transaction(
request.session, request.single_use_transaction
)
tx_id = tx.id
else:
raise ValueError("Unsupported transaction type")
del self.transactions[tx_id]
return commit.CommitResponse()

def Rollback(self, request, context):
Expand Down
24 changes: 24 additions & 0 deletions tests/mockserver_tests/test_aborted_transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,30 @@ def test_run_in_transaction_batch_dml_aborted(self):
self.assertTrue(isinstance(requests[2], ExecuteBatchDmlRequest))
self.assertTrue(isinstance(requests[3], CommitRequest))

def test_batch_commit_aborted(self):
# Add an Aborted error for the Commit method on the mock server.
add_error(SpannerServicer.Commit.__name__, aborted_status())
with self.database.batch() as batch:
batch.insert(
table="Singers",
columns=("SingerId", "FirstName", "LastName"),
values=[
(1, "Marc", "Richards"),
(2, "Catalina", "Smith"),
(3, "Alice", "Trentor"),
(4, "Lea", "Martin"),
(5, "David", "Lomond"),
],
)

# Verify that the transaction was retried.
requests = self.spanner_service.requests
self.assertEqual(3, len(requests), msg=requests)
self.assertTrue(isinstance(requests[0], BatchCreateSessionsRequest))
self.assertTrue(isinstance(requests[1], CommitRequest))
# The transaction is aborted and retried.
self.assertTrue(isinstance(requests[2], CommitRequest))


def _insert_mutations(transaction: Transaction):
transaction.insert("my_table", ["col1", "col2"], ["value1", "value2"])
Expand Down
Loading

0 comments on commit d32aab2

Please sign in to comment.