Skip to content

Commit

Permalink
fixed remaining failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
toluaina committed Dec 3, 2023
1 parent e1328a6 commit 5a84436
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 139 deletions.
34 changes: 19 additions & 15 deletions examples/book_view/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,32 +50,36 @@ def setup(config: str) -> None:
create_database(database)
create_schema(database, schema)
with pg_engine(database) as engine:
engine = engine.connect().execution_options(
schema_translate_map={None: schema}
)
Base.metadata.schema = schema
Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)

metadata = sa.MetaData(schema=schema)
metadata.reflect(engine, views=True)

book_model = metadata.tables[f"{schema}.book"]
engine.execute(
CreateView(
schema,
"book_view",
book_model.select(),

with engine.connect() as conn:
conn.execute(
CreateView(
schema,
"book_view",
book_model.select(),
)
)
)
conn.commit()

publisher_model = metadata.tables[f"{schema}.publisher"]
engine.execute(
CreateView(
schema,
"publisher_view",
publisher_model.select(),

with engine.connect() as conn:
conn.execute(
CreateView(
schema,
"publisher_view",
publisher_model.select(),
)
)
)
conn.commit()


@click.command()
Expand Down
14 changes: 5 additions & 9 deletions pgsync/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,17 +580,17 @@ def create_view(
def drop_view(self, schema: str) -> None:
    """Drop the materialized view ``{schema}.{MATERIALIZED_VIEW}``.

    SQLAlchemy 2.0 removed ``Engine.execute``, so the DDL is issued on an
    explicitly acquired connection; the ``with`` block guarantees the
    connection is returned to the pool.

    :param schema: name of the schema holding the view.
    """
    logger.debug(f"Dropping view: {schema}.{MATERIALIZED_VIEW}")
    # Execute exactly once on a short-lived connection (the scraped diff
    # had merged the removed engine.execute() call with this block).
    with self.engine.connect() as conn:
        conn.execute(DropView(schema, MATERIALIZED_VIEW))
    logger.debug(f"Dropped view: {schema}.{MATERIALIZED_VIEW}")

def refresh_view(
    self, name: str, schema: str, concurrently: bool = False
) -> None:
    """Refresh a materialized view.

    :param name: view name to refresh.
    :param schema: schema containing the view.
    :param concurrently: when True, issue REFRESH ... CONCURRENTLY so
        readers are not blocked (requires a unique index on the view).
    """
    logger.debug(f"Refreshing view: {schema}.{name}")
    # SQLAlchemy 2.0 style: acquire a connection rather than calling the
    # removed Engine.execute(); run the refresh exactly once.
    with self.engine.connect() as conn:
        conn.execute(RefreshView(schema, name, concurrently=concurrently))
    logger.debug(f"Refreshed view: {schema}.{name}")

# Triggers...
Expand Down Expand Up @@ -1034,14 +1034,12 @@ def pg_execute(
statement: sa.sql.Select,
values: Optional[list] = None,
options: Optional[dict] = None,
commit: bool = True,
) -> None:
with engine.connect() as conn:
if options:
conn = conn.execution_options(**options)
conn.execute(statement, values)
if commit:
conn.commit()
conn.commit()


def create_schema(database: str, schema: str, echo: bool = False) -> None:
Expand All @@ -1060,7 +1058,6 @@ def create_database(database: str, echo: bool = False) -> None:
engine,
sa.text(f'CREATE DATABASE "{database}"'),
options={"isolation_level": "AUTOCOMMIT"},
commit=False,
)
logger.debug(f"Created database: {database}")

Expand All @@ -1073,7 +1070,6 @@ def drop_database(database: str, echo: bool = False) -> None:
engine,
sa.text(f'DROP DATABASE IF EXISTS "{database}"'),
options={"isolation_level": "AUTOCOMMIT"},
commit=False,
)

logger.debug(f"Dropped database: {database}")
Expand Down
6 changes: 3 additions & 3 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
#
async-timeout==4.0.3
# via redis
boto3==1.29.6
boto3==1.33.6
# via -r requirements/base.in
botocore==1.32.6
botocore==1.33.6
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -69,7 +69,7 @@ requests==2.31.0
# requests-aws4auth
requests-aws4auth==1.2.3
# via -r requirements/base.in
s3transfer==0.7.0
s3transfer==0.8.2
# via boto3
six==1.16.0
# via
Expand Down
10 changes: 5 additions & 5 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ async-timeout==4.0.3
# via redis
black==23.11.0
# via -r requirements/dev.in
boto3==1.29.6
boto3==1.33.6
# via -r requirements/base.in
botocore==1.32.6
botocore==1.33.6
# via
# boto3
# s3transfer
Expand Down Expand Up @@ -51,7 +51,7 @@ filelock==3.13.1
# via virtualenv
flake8==6.1.0
# via -r requirements/dev.in
freezegun==1.2.2
freezegun==1.3.0
# via -r requirements/dev.in
greenlet==3.0.1
# via sqlalchemy
Expand Down Expand Up @@ -133,7 +133,7 @@ requests==2.31.0
# requests-aws4auth
requests-aws4auth==1.2.3
# via -r requirements/base.in
s3transfer==0.7.0
s3transfer==0.8.2
# via boto3
six==1.16.0
# via
Expand All @@ -160,7 +160,7 @@ urllib3==1.26.18
# elastic-transport
# opensearch-py
# requests
virtualenv==20.24.7
virtualenv==20.25.0
# via pre-commit

# The following packages are considered to be unsafe in a requirements file:
Expand Down
33 changes: 7 additions & 26 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,45 +433,26 @@ def model_mapping(
@pytest.fixture(scope="session")
def table_creator(base, connection, model_mapping):
    """Session-scoped fixture: build the schema, triggers and replication
    slot for the test database, then tear everything down afterwards.

    Setup: configure ORM mappers, create all tables (committed explicitly,
    as SQLAlchemy 2.0 connections no longer autocommit), install pgsync
    triggers, and (re)create the ``<db>_testdb`` replication slot.
    Teardown: drop the slot, drop the tables, and remove the checkpoint
    file if one was written.
    """
    sa.orm.configure_mappers()
    with connection.engine.connect() as conn:
        base.metadata.create_all(connection.engine)
        conn.commit()
    pg_base = Base(connection.engine.url.database)
    pg_base.create_triggers(
        connection.engine.url.database,
        DEFAULT_SCHEMA,
    )
    # Drop first so a leftover slot from an aborted run cannot break setup.
    pg_base.drop_replication_slot(f"{connection.engine.url.database}_testdb")
    pg_base.create_replication_slot(f"{connection.engine.url.database}_testdb")
    yield
    pg_base.drop_replication_slot(f"{connection.engine.url.database}_testdb")
    with connection.engine.connect() as conn:
        base.metadata.drop_all(connection.engine)
        conn.commit()
    try:
        os.unlink(f".{connection.engine.url.database}_testdb")
    except (OSError, FileNotFoundError):
        # Checkpoint file may never have been created; best-effort cleanup.
        pass

# sa.orm.configure_mappers()
# with connection.engine as engine:
# base.metadata.create_all(engine)
# pg_base = Base(connection.engine.url.database)
# pg_base.create_triggers(
# connection.engine.url.database,
# DEFAULT_SCHEMA,
# )
# pg_base.drop_replication_slot(f"{connection.engine.url.database}_testdb")
# pg_base.create_replication_slot(f"{connection.engine.url.database}_testdb")
# yield
# pg_base.drop_replication_slot(f"{connection.engine.url.database}_testdb")
# base.metadata.drop_all(connection)
# try:
# os.unlink(f".{connection.engine.url.database}_testdb")
# except (OSError, FileNotFoundError):
# pass


@pytest.fixture(scope="session")
def dataset(
Expand Down
36 changes: 17 additions & 19 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest
import sqlalchemy as sa
from mock import ANY, call, patch
from mock import call, patch

from pgsync.base import (
_pg_engine,
Expand Down Expand Up @@ -374,28 +374,26 @@ def test_drop_extension(
mock_pg_execute.call_args_list == calls

@patch("pgsync.base.logger")
def test_drop_view(self, mock_logger, connection):
    """drop_view logs the expected before/after debug messages.

    The engine is patched inside the test body (rather than via a
    decorator) so no real DDL is executed against the database.
    """
    pg_base = Base(connection.engine.url.database)
    with patch("pgsync.sync.Base.engine"):
        pg_base.drop_view("public")
        calls = [
            call("Dropping view: public._view"),
            call("Dropped view: public._view"),
        ]
        assert mock_logger.debug.call_args_list == calls

@patch("pgsync.base.logger")
def test_refresh_view(self, mock_logger, connection):
    """refresh_view logs the expected before/after debug messages.

    The engine is patched inside the test body (rather than via a
    decorator) so no real REFRESH statement reaches the database.
    """
    pg_base = Base(connection.engine.url.database)
    with patch("pgsync.sync.Base.engine"):
        pg_base.refresh_view("foo", "public", concurrently=True)
        calls = [
            call("Refreshing view: public.foo"),
            call("Refreshed view: public.foo"),
        ]
        assert mock_logger.debug.call_args_list == calls

def test_parse_value(self, connection):
pg_base = Base(connection.engine.url.database)
Expand Down
3 changes: 0 additions & 3 deletions tests/test_sync_nested_children.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,6 @@ def data(

with subtransactions(session):
conn = session.connection().engine.connect().connection
conn.set_isolation_level(
psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT
)
cursor = conn.cursor()
channel = sync.database
cursor.execute(f"UNLISTEN {channel}")
Expand Down
7 changes: 0 additions & 7 deletions tests/test_sync_root.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Tests for `pgsync` package."""
import mock
import psycopg2
import pytest

from pgsync.base import subtransactions
Expand Down Expand Up @@ -47,9 +46,6 @@ def data(self, sync, book_cls, publisher_cls):

with subtransactions(session):
conn = session.connection().engine.connect().connection
conn.set_isolation_level(
psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT
)
cursor = conn.cursor()
channel = sync.database
cursor.execute(f"UNLISTEN {channel}")
Expand All @@ -67,9 +63,6 @@ def data(self, sync, book_cls, publisher_cls):

with subtransactions(session):
conn = session.connection().engine.connect().connection
conn.set_isolation_level(
psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT
)
cursor = conn.cursor()
channel = session.connection().engine.url.database
cursor.execute(f"UNLISTEN {channel}")
Expand Down
5 changes: 3 additions & 2 deletions tests/test_trigger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Trigger tests."""
import pytest
import sqlalchemy as sa

from pgsync.base import Base
from pgsync.trigger import CREATE_TRIGGER_TEMPLATE
Expand Down Expand Up @@ -107,7 +108,7 @@ def test_trigger_primary_key_function(self, connection):
f"JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY(indkey) " # noqa E501
f"WHERE indrelid = '{table_name}'::regclass AND indisprimary"
)
rows = pg_base.fetchall(query)[0]
rows = pg_base.fetchall(sa.text(query))[0]
assert list(rows)[0] == primary_keys

def test_trigger_foreign_key_function(self, connection):
Expand All @@ -129,7 +130,7 @@ def test_trigger_foreign_key_function(self, connection):
f"WHERE constraint_catalog=current_catalog AND "
f"table_name='{table_name}' AND position_in_unique_constraint NOTNULL " # noqa E501
)
rows = pg_base.fetchall(query)[0]
rows = pg_base.fetchall(sa.text(query))[0]
if rows[0]:
assert sorted(rows[0]) == sorted(foreign_keys)
else:
Expand Down
7 changes: 0 additions & 7 deletions tests/test_unique_behaviour.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Tests for `pgsync` package."""

import psycopg2
import pytest

from pgsync.base import subtransactions
Expand Down Expand Up @@ -45,9 +44,6 @@ def data(
]
with subtransactions(session):
conn = session.connection().engine.connect().connection
conn.set_isolation_level(
psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT
)
cursor = conn.cursor()
channel = sync.database
cursor.execute(f"UNLISTEN {channel}")
Expand All @@ -72,9 +68,6 @@ def data(

with subtransactions(session):
conn = session.connection().engine.connect().connection
conn.set_isolation_level(
psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT
)
cursor = conn.cursor()
channel = session.connection().engine.url.database
cursor.execute(f"UNLISTEN {channel}")
Expand Down
Loading

0 comments on commit 5a84436

Please sign in to comment.