Skip to content

Commit

Permalink
Updates from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
odeke-em committed Jan 9, 2025
1 parent 8581115 commit ea2551a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 48 deletions.
69 changes: 33 additions & 36 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,46 +699,43 @@ def execute_partitioned_dml(
)

def execute_pdml():
def do_execute_pdml(session, span):
add_span_event(span, "Starting BeginTransaction")
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
)

txn_selector = TransactionSelector(id=txn.id)

request = ExecuteSqlRequest(
session=session.name,
sql=dml,
params=params_pb,
param_types=param_types,
query_options=query_options,
request_options=request_options,
)
method = functools.partial(
api.execute_streaming_sql,
metadata=metadata,
)

iterator = _restart_on_unavailable(
method=method,
trace_name="CloudSpanner.ExecuteStreamingSql",
request=request,
transaction_selector=txn_selector,
observability_options=self.observability_options,
)

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials

return result_set.stats.row_count_lower_bound

with trace_call(
"CloudSpanner.Database.execute_partitioned_pdml",
observability_options=self.observability_options,
) as span:
with SessionCheckout(self._pool) as session:
return do_execute_pdml(session, span)
add_span_event(span, "Starting BeginTransaction")
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
)

txn_selector = TransactionSelector(id=txn.id)

request = ExecuteSqlRequest(
session=session.name,
sql=dml,
params=params_pb,
param_types=param_types,
query_options=query_options,
request_options=request_options,
)
method = functools.partial(
api.execute_streaming_sql,
metadata=metadata,
)

iterator = _restart_on_unavailable(
method=method,
trace_name="CloudSpanner.ExecuteStreamingSql",
request=request,
transaction_selector=txn_selector,
observability_options=self.observability_options,
)

result_set = StreamedResultSet(iterator)
list(result_set) # consume all partials

return result_set.stats.row_count_lower_bound

return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()

Expand Down Expand Up @@ -1531,7 +1528,7 @@ def process_read_batch(
:rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet`
:returns: a result set instance which can be used to consume rows.
"""
observability_options = self.observability_options or {}
observability_options = self.observability_options
with trace_call(
f"CloudSpanner.{type(self).__name__}.process_read_batch",
observability_options=observability_options,
Expand Down
1 change: 0 additions & 1 deletion google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,6 @@ def run_in_transaction(self, func, *args, **kw):
) as span:
while True:
if self._transaction is None:
add_span_event(span, "Creating Transaction")
txn = self.transaction()
txn.transaction_tag = transaction_tag
txn.exclude_txn_from_change_streams = (
Expand Down
1 change: 1 addition & 0 deletions google/cloud/spanner_v1/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,7 @@ def partition_read(

with trace_call(
f"CloudSpanner.{type(self).__name__}.partition_read",
self._session,
extra_attributes=trace_attributes,
observability_options=getattr(database, "observability_options", None),
):
Expand Down
4 changes: 1 addition & 3 deletions tests/system/test_observability_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,10 @@ def select_in_txn(txn):
("Waiting for a session to become available", {"kind": "BurstyPool"}),
("No sessions available in pool. Creating session", {"kind": "BurstyPool"}),
("Creating Session", {}),
("Creating Transaction", {}),
(
"Transaction was aborted in user operation, retrying",
{"delay_seconds": "EPHEMERAL", "cause": "EPHEMERAL", "attempt": 1},
),
("Creating Transaction", {}),
("Starting Commit", {}),
("Commit Done", {}),
]
Expand Down Expand Up @@ -283,7 +281,7 @@ def finished_spans_statuses(trace_exporter):
not HAS_OTEL_INSTALLED,
reason="Tracing requires OpenTelemetry",
)
def test_database_partitioned():
def test_database_partitioned_error():
from opentelemetry.trace.status import StatusCode

db, trace_exporter = create_db_trace_exporter()
Expand Down
12 changes: 4 additions & 8 deletions tests/unit/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,15 +1233,12 @@ def test_partition_read_other_error(self):
if not HAS_OPENTELEMETRY_INSTALLED:
return

want_span_attributes = dict(
BASE_ATTRIBUTES,
table_id=TABLE_NAME,
columns=tuple(COLUMNS),
)
self.assertSpanAttributes(
"CloudSpanner._Derived.partition_read",
status=StatusCode.ERROR,
attributes=want_span_attributes,
attributes=dict(
BASE_ATTRIBUTES, table_id=TABLE_NAME, columns=tuple(COLUMNS)
),
)

def test_partition_read_w_retry(self):
Expand Down Expand Up @@ -1379,11 +1376,10 @@ def _partition_query_helper(
timeout=timeout,
)

attributes = dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM})
self.assertSpanAttributes(
"CloudSpanner._Derived.partition_query",
status=StatusCode.OK,
attributes=attributes,
attributes=dict(BASE_ATTRIBUTES, **{"db.statement": SQL_QUERY_WITH_PARAM}),
)

def test_partition_query_other_error(self):
Expand Down

0 comments on commit ea2551a

Please sign in to comment.