diff --git a/examples/book_view/schema.py b/examples/book_view/schema.py index e311f441..bce00a8f 100644 --- a/examples/book_view/schema.py +++ b/examples/book_view/schema.py @@ -50,9 +50,7 @@ def setup(config: str) -> None: create_database(database) create_schema(database, schema) with pg_engine(database) as engine: - engine = engine.connect().execution_options( - schema_translate_map={None: schema} - ) + Base.metadata.schema = schema Base.metadata.drop_all(engine) Base.metadata.create_all(engine) @@ -60,22 +58,28 @@ def setup(config: str) -> None: metadata.reflect(engine, views=True) book_model = metadata.tables[f"{schema}.book"] - engine.execute( - CreateView( - schema, - "book_view", - book_model.select(), + + with engine.connect() as conn: + conn.execute( + CreateView( + schema, + "book_view", + book_model.select(), + ) ) - ) + conn.commit() publisher_model = metadata.tables[f"{schema}.publisher"] - engine.execute( - CreateView( - schema, - "publisher_view", - publisher_model.select(), + + with engine.connect() as conn: + conn.execute( + CreateView( + schema, + "publisher_view", + publisher_model.select(), + ) ) - ) + conn.commit() @click.command() diff --git a/pgsync/base.py b/pgsync/base.py index e231fdd4..456d3e60 100644 --- a/pgsync/base.py +++ b/pgsync/base.py @@ -580,7 +580,8 @@ def create_view( def drop_view(self, schema: str) -> None: """Drop a view.""" logger.debug(f"Dropping view: {schema}.{MATERIALIZED_VIEW}") - self.engine.execute(DropView(schema, MATERIALIZED_VIEW)) + with self.engine.connect() as conn: + conn.execute(DropView(schema, MATERIALIZED_VIEW)) logger.debug(f"Dropped view: {schema}.{MATERIALIZED_VIEW}") def refresh_view( @@ -588,9 +589,8 @@ def refresh_view( ) -> None: """Refresh a materialized view.""" logger.debug(f"Refreshing view: {schema}.{name}") - self.engine.execute( - RefreshView(schema, name, concurrently=concurrently) - ) + with self.engine.connect() as conn: + conn.execute(RefreshView(schema, name, concurrently=concurrently)) logger.debug(f"Refreshed view: {schema}.{name}") # Triggers... @@ -1034,14 +1034,12 @@ def pg_execute( statement: sa.sql.Select, values: Optional[list] = None, options: Optional[dict] = None, - commit: bool = True, ) -> None: with engine.connect() as conn: if options: conn = conn.execution_options(**options) conn.execute(statement, values) - if commit: - conn.commit() + conn.commit() def create_schema(database: str, schema: str, echo: bool = False) -> None: @@ -1060,7 +1058,6 @@ def create_database(database: str, echo: bool = False) -> None: engine, sa.text(f'CREATE DATABASE "{database}"'), options={"isolation_level": "AUTOCOMMIT"}, - commit=False, ) logger.debug(f"Created database: {database}") @@ -1073,7 +1070,6 @@ def drop_database(database: str, echo: bool = False) -> None: engine, sa.text(f'DROP DATABASE IF EXISTS "{database}"'), options={"isolation_level": "AUTOCOMMIT"}, - commit=False, ) logger.debug(f"Dropped database: {database}") diff --git a/requirements/base.txt b/requirements/base.txt index 88cda98d..e0386cc9 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -6,9 +6,9 @@ # async-timeout==4.0.3 # via redis -boto3==1.29.6 +boto3==1.33.6 # via -r requirements/base.in -botocore==1.32.6 +botocore==1.33.6 # via # boto3 # s3transfer @@ -69,7 +69,7 @@ requests==2.31.0 # requests-aws4auth requests-aws4auth==1.2.3 # via -r requirements/base.in -s3transfer==0.7.0 +s3transfer==0.8.2 # via boto3 six==1.16.0 # via diff --git a/requirements/dev.txt b/requirements/dev.txt index 3354834b..c7b8279b 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -8,9 +8,9 @@ async-timeout==4.0.3 # via redis black==23.11.0 # via -r requirements/dev.in -boto3==1.29.6 +boto3==1.33.6 # via -r requirements/base.in -botocore==1.32.6 +botocore==1.33.6 # via # boto3 # s3transfer @@ -51,7 +51,7 @@ filelock==3.13.1 # via virtualenv flake8==6.1.0 # via -r requirements/dev.in -freezegun==1.2.2 +freezegun==1.3.0 # via -r requirements/dev.in greenlet==3.0.1 # via sqlalchemy @@ -133,7 +133,7 @@ requests==2.31.0 # requests-aws4auth requests-aws4auth==1.2.3 # via -r requirements/base.in -s3transfer==0.7.0 +s3transfer==0.8.2 # via boto3 six==1.16.0 # via @@ -160,7 +160,7 @@ urllib3==1.26.18 # elastic-transport # opensearch-py # requests -virtualenv==20.24.7 +virtualenv==20.25.0 # via pre-commit # The following packages are considered to be unsafe in a requirements file: diff --git a/tests/conftest.py b/tests/conftest.py index 90ec349c..ea019f22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -433,7 +433,9 @@ def model_mapping( @pytest.fixture(scope="session") def table_creator(base, connection, model_mapping): sa.orm.configure_mappers() - base.metadata.create_all(connection.engine) + with connection.engine.connect() as conn: + base.metadata.create_all(connection.engine) + conn.commit() pg_base = Base(connection.engine.url.database) pg_base.create_triggers( connection.engine.url.database, @@ -441,37 +443,16 @@ def table_creator(base, connection, model_mapping): ) pg_base.drop_replication_slot(f"{connection.engine.url.database}_testdb") pg_base.create_replication_slot(f"{connection.engine.url.database}_testdb") - try: - yield - finally: - pg_base.drop_replication_slot( - f"{connection.engine.url.database}_testdb" - ) + yield + pg_base.drop_replication_slot(f"{connection.engine.url.database}_testdb") + with connection.engine.connect() as conn: base.metadata.drop_all(connection.engine) - + conn.commit() try: os.unlink(f".{connection.engine.url.database}_testdb") except (OSError, FileNotFoundError): pass - # sa.orm.configure_mappers() - # with connection.engine as engine: - # base.metadata.create_all(engine) - # pg_base = Base(connection.engine.url.database) - # pg_base.create_triggers( - # connection.engine.url.database, - # DEFAULT_SCHEMA, - # ) - # pg_base.drop_replication_slot(f"{connection.engine.url.database}_testdb") - # pg_base.create_replication_slot(f"{connection.engine.url.database}_testdb") - # yield - # pg_base.drop_replication_slot(f"{connection.engine.url.database}_testdb") - # base.metadata.drop_all(connection) - # try: - # os.unlink(f".{connection.engine.url.database}_testdb") - # except (OSError, FileNotFoundError): - # pass - @pytest.fixture(scope="session") def dataset( diff --git a/tests/test_base.py b/tests/test_base.py index b5fad2e8..48955e18 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -2,7 +2,7 @@ import pytest import sqlalchemy as sa -from mock import ANY, call, patch +from mock import call, patch from pgsync.base import ( _pg_engine, @@ -374,28 +374,26 @@ def test_drop_extension( mock_pg_execute.call_args_list == calls @patch("pgsync.base.logger") - @patch("pgsync.sync.Base.engine") - def test_drop_view(self, mock_engine, mock_logger, connection): + def test_drop_view(self, mock_logger, connection): pg_base = Base(connection.engine.url.database) - pg_base.drop_view("public") - calls = [ - call("Dropping view: public._view"), - call("Dropped view: public._view"), - ] - assert mock_logger.debug.call_args_list == calls - mock_engine.execute.assert_called_once_with(ANY) + with patch("pgsync.sync.Base.engine"): + pg_base.drop_view("public") + calls = [ + call("Dropping view: public._view"), + call("Dropped view: public._view"), + ] + assert mock_logger.debug.call_args_list == calls @patch("pgsync.base.logger") - @patch("pgsync.sync.Base.engine") - def test_refresh_view(self, mock_engine, mock_logger, connection): + def test_refresh_view(self, mock_logger, connection): pg_base = Base(connection.engine.url.database) - pg_base.refresh_view("foo", "public", concurrently=True) - calls = [ - call("Refreshing view: public.foo"), - call("Refreshed view: public.foo"), - ] - assert mock_logger.debug.call_args_list == calls - mock_engine.execute.assert_called_once_with(ANY) + with patch("pgsync.sync.Base.engine"): + pg_base.refresh_view("foo", "public", concurrently=True) + calls = [ + call("Refreshing view: public.foo"), + call("Refreshed view: public.foo"), + ] + assert mock_logger.debug.call_args_list == calls def test_parse_value(self, connection): pg_base = Base(connection.engine.url.database) diff --git a/tests/test_sync_nested_children.py b/tests/test_sync_nested_children.py index 8d9a81dc..bece03fc 100644 --- a/tests/test_sync_nested_children.py +++ b/tests/test_sync_nested_children.py @@ -213,9 +213,6 @@ def data( with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = sync.database cursor.execute(f"UNLISTEN {channel}") diff --git a/tests/test_sync_root.py b/tests/test_sync_root.py index 1e51d124..51d986b4 100644 --- a/tests/test_sync_root.py +++ b/tests/test_sync_root.py @@ -1,6 +1,5 @@ """Tests for `pgsync` package.""" import mock -import psycopg2 import pytest from pgsync.base import subtransactions @@ -47,9 +46,6 @@ def data(self, sync, book_cls, publisher_cls): with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = sync.database cursor.execute(f"UNLISTEN {channel}") @@ -67,9 +63,6 @@ def data(self, sync, book_cls, publisher_cls): with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = session.connection().engine.url.database cursor.execute(f"UNLISTEN {channel}") diff --git a/tests/test_trigger.py b/tests/test_trigger.py index ab9256ab..f86492b4 100644 --- a/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -1,5 +1,6 @@ """Trigger tests.""" import pytest +import sqlalchemy as sa from pgsync.base import Base from pgsync.trigger import CREATE_TRIGGER_TEMPLATE @@ -107,7 +108,7 @@ def test_trigger_primary_key_function(self, connection): f"JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY(indkey) " # noqa E501 f"WHERE indrelid = '{table_name}'::regclass AND indisprimary" ) - rows = pg_base.fetchall(query)[0] + rows = pg_base.fetchall(sa.text(query))[0] assert list(rows)[0] == primary_keys def test_trigger_foreign_key_function(self, connection): @@ -129,7 +130,7 @@ def test_trigger_foreign_key_function(self, connection): f"WHERE constraint_catalog=current_catalog AND " f"table_name='{table_name}' AND position_in_unique_constraint NOTNULL " # noqa E501 ) - rows = pg_base.fetchall(query)[0] + rows = pg_base.fetchall(sa.text(query))[0] if rows[0]: assert sorted(rows[0]) == sorted(foreign_keys) else: diff --git a/tests/test_unique_behaviour.py b/tests/test_unique_behaviour.py index 84bb6c65..5e3f6d21 100644 --- a/tests/test_unique_behaviour.py +++ b/tests/test_unique_behaviour.py @@ -1,6 +1,5 @@ """Tests for `pgsync` package.""" -import psycopg2 import pytest from pgsync.base import subtransactions @@ -45,9 +44,6 @@ def data( ] with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = sync.database cursor.execute(f"UNLISTEN {channel}") @@ -72,9 +68,6 @@ def data( with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = session.connection().engine.url.database cursor.execute(f"UNLISTEN {channel}") diff --git a/tests/test_view.py b/tests/test_view.py index 86d91914..36cc7fe6 100644 --- a/tests/test_view.py +++ b/tests/test_view.py @@ -35,6 +35,10 @@ def data(self, sync, book_cls): with subtransactions(session): session.add_all(books) yield books + + with subtransactions(session): + session.query(book_cls).delete() + session.connection().engine.connect().close() session.connection().engine.dispose() sync.search_client.close() @@ -149,54 +153,51 @@ def test_drop_view(self, connection): def test_refresh_view(self, connection, sync, book_cls, data): """Test refresh materialized view.""" view = "test_view_refresh" - # pg_base = Base(connection.engine.url.database) + pg_base = Base(connection.engine.url.database) - # model = pg_base.models("book", "public") - # statement = sa.select([model.c.isbn]).select_from(model) - # with connection.engine.connect() as conn: - # conn.execute( - # CreateView(DEFAULT_SCHEMA, view, statement, materialized=True) - # ) - # conn.commit() - # assert [ - # result.isbn - # for result in connection.engine.execute( - # sa.text(f"SELECT * FROM {view}") - # ) - # ][0] == "abc" + model = pg_base.models("book", "public") + statement = sa.select(*[model.c.isbn]).select_from(model) + with connection.engine.connect() as conn: + conn.execute( + CreateView(DEFAULT_SCHEMA, view, statement, materialized=True) + ) + conn.commit() - # session = sync.session - # with subtransactions(session): - # session.execute( - # book_cls.__table__.update() - # .where(book_cls.__table__.c.isbn == "abc") - # .values(isbn="xyz") - # ) + with connection.engine.connect() as conn: + assert [ + result.isbn + for result in conn.execute(sa.text(f"SELECT * FROM {view}")) + ][0] == "abc" - # # the value should still be abc - # assert [ - # result.isbn - # for result in connection.engine.execute( - # sa.text(f"SELECT * FROM {view}") - # ) - # ][0] == "abc" + session = sync.session + with subtransactions(session): + session.execute( + book_cls.__table__.update() + .where(book_cls.__table__.c.isbn == "abc") + .values(isbn="xyz") + ) + + with connection.engine.connect() as conn: + # the value should still be abc + assert [ + result.isbn + for result in conn.execute(sa.text(f"SELECT * FROM {view}")) + ][0] == "abc" + + with connection.engine.connect() as conn: + conn.execute(RefreshView(DEFAULT_SCHEMA, view)) + conn.commit() - # with connection.engine.connect() as conn: - # conn.execute(RefreshView(DEFAULT_SCHEMA, view)) - # conn.commit() + with connection.engine.connect() as conn: + # the value should now be xyz + assert [ + result.isbn + for result in conn.execute(sa.text(f"SELECT * FROM {view}")) + ][0] == "xyz" - # # the value should now be xyz - # assert [ - # result.isbn - # for result in connection.engine.execute( - # sa.text(f"SELECT * FROM {view}") - # ) - # ][0] == "xyz" - # with connection.engine.connect() as conn: - # conn.execute( - # DropView(DEFAULT_SCHEMA, view, materialized=True) - # ) - # conn.commit() + with connection.engine.connect() as conn: + conn.execute(DropView(DEFAULT_SCHEMA, view, materialized=True)) + conn.commit() @pytest.mark.usefixtures("table_creator") def test_index(self, connection, sync, book_cls, data):