Skip to content

Commit

Permalink
Avoid an extra transaction, in some cases
Browse files Browse the repository at this point in the history
Signed-off-by: Rick Elrod <[email protected]>
  • Loading branch information
relrod committed Aug 6, 2024
1 parent 40bc80d commit d0aef26
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 15 deletions.
14 changes: 14 additions & 0 deletions ansible_base/lib/utils/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from contextlib import contextmanager

from django.db import transaction


@contextmanager
def ensure_transaction():
needs_new_transaction = not transaction.get_connection().in_atomic_block

if needs_new_transaction:
with transaction.atomic():
yield
else:
yield
6 changes: 3 additions & 3 deletions ansible_base/resource_registry/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from django.apps import AppConfig
from django.conf import settings
from django.db import transaction
from django.db.models import TextField, signals
from django.db.models.functions import Cast
from django.db.utils import IntegrityError

import ansible_base.lib.checks # noqa: F401 - register checks
from ansible_base.lib.utils.db import ensure_transaction

logger = logging.getLogger('ansible_base.resource_registry.apps')

Expand Down Expand Up @@ -120,7 +120,7 @@ def connect_resource_signals(sender, **kwargs):

# Avoid late binding issues
def save(self, *args, _original_save=cls._original_save, **kwargs):
with transaction.atomic():
with ensure_transaction():
# We need to know if this is a new object before we save it
action = "create" if self._state.adding else "update"
# Save so we get an ansible_id if it's a new object
Expand All @@ -135,7 +135,7 @@ def save(self, *args, _original_save=cls._original_save, **kwargs):

# Avoid late binding issues
def delete(self, *args, _original_delete=cls._original_delete, **kwargs):
with transaction.atomic():
with ensure_transaction():
_original_delete(self, *args, **kwargs)
handlers.sync_to_resource_server(self, "delete")

Expand Down
19 changes: 7 additions & 12 deletions test_app/tests/resource_registry/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def setup(self, request, settings):
settings.DISABLE_RESOURCE_SERVER_SYNC = is_disabled
apps.connect_resource_signals(sender=None)

@pytest.mark.django_db
@pytest.mark.django_db(transaction=True)
@pytest.mark.parametrize('action', ['create', 'update', 'delete'])
def test_sync_to_resource_server_happy_path(self, user, action):
"""
Expand Down Expand Up @@ -100,9 +100,8 @@ def test_sync_to_resource_server_happy_path(self, user, action):
client_method = getattr(get_resource_server_client.return_value, f'{action}_resource')
client_method.assert_called()

# The whole thing is wrapped in a transaction/savepoint
assert queries.captured_queries[0]['sql'].startswith('SAVEPOINT'), queries.captured_queries[0]['sql']
assert queries.captured_queries[-1]['sql'].startswith('RELEASE SAVEPOINT'), queries.captured_queries[-1]['sql']
# The whole thing is wrapped in a transaction
assert queries.captured_queries[-1]['sql'] == 'COMMIT'

@pytest.mark.django_db
@pytest.mark.parametrize('anon', [AnonymousUser(), None])
Expand Down Expand Up @@ -140,7 +139,7 @@ def test_sync_to_resource_server_no_resource(self, user, nullify_resource):
# We bail out if we don't have a resource
get_resource_server_client.assert_not_called()

@pytest.mark.django_db
@pytest.mark.django_db(transaction=True)
def test_sync_to_resource_server_exception_during_sync(self, user):
"""
We get an exception when trying to sync (e.g. the server gives us a 500, or
Expand All @@ -153,11 +152,9 @@ def test_sync_to_resource_server_exception_during_sync(self, user):
with pytest.raises(ValidationError, match="Failed to sync resource"):
Organization.objects.create(name='Hello')

# The last two queries should be a rollback
assert queries.captured_queries[-2]['sql'].startswith('ROLLBACK'), queries.captured_queries[-2]['sql']
assert queries.captured_queries[-1]['sql'].startswith('RELEASE SAVEPOINT'), queries.captured_queries[-1]['sql']
assert queries.captured_queries[-1]['sql'] == 'ROLLBACK'

@pytest.mark.django_db
@pytest.mark.django_db(transaction=True)
def test_sync_to_resource_server_exception_during_save(self, user, organization):
"""
If we get an exception during .save(), the transaction should still roll back
Expand All @@ -170,9 +167,7 @@ def test_sync_to_resource_server_exception_during_save(self, user, organization)
org = Organization(name=organization.name)
org.save()

# The last two queries should be a rollback
assert queries.captured_queries[-2]['sql'].startswith('ROLLBACK'), queries.captured_queries[-2]['sql']
assert queries.captured_queries[-1]['sql'].startswith('RELEASE SAVEPOINT'), queries.captured_queries[-1]['sql']
assert queries.captured_queries[-1]['sql'] == 'ROLLBACK'

@pytest.mark.parametrize(
'new_settings,should_sync',
Expand Down

0 comments on commit d0aef26

Please sign in to comment.