From f9a37eb777745a7607aeb8326967ad09a4cd4d6e Mon Sep 17 00:00:00 2001 From: Tolu Aina <7848930+toluaina@users.noreply.github.com> Date: Tue, 5 Dec 2023 21:09:46 +0000 Subject: [PATCH] Sqlalchemy 2.x (#508) * SQLAlchemy 2.0 migration WIP * fixed remaining failing tests * more cleanup and documentatiion --- bin/parallel_sync | 2 +- examples/airbnb/schema.py | 78 ++++++++------ examples/ancestry/schema.py | 50 ++++++--- examples/book/schema.py | 166 ++++++++++++++++++----------- examples/book_view/schema.py | 64 ++++++----- examples/node/README | 2 +- examples/node/schema.py | 18 ++-- examples/quiz/schema.py | 48 +++++---- examples/schemas/schema.py | 30 +++--- examples/social/schema.py | 64 ++++++----- examples/starcraft/schema.py | 27 ++--- examples/through/schema.py | 34 +++--- pgsync/base.py | 153 +++++++++++++------------- pgsync/querybuilder.py | 10 +- pgsync/view.py | 33 +++--- plugins/sample.py | 1 + requirements/base.in | 7 +- requirements/base.txt | 33 ++---- requirements/dev.in | 2 + requirements/dev.txt | 24 +++-- tests/conftest.py | 16 ++- tests/test_base.py | 143 +++++++++++++------------ tests/test_settings.py | 2 +- tests/test_sync_nested_children.py | 3 - tests/test_sync_root.py | 7 -- tests/test_trigger.py | 5 +- tests/test_unique_behaviour.py | 7 -- tests/test_utils.py | 4 +- tests/test_view.py | 138 ++++++++++++++---------- 29 files changed, 650 insertions(+), 521 deletions(-) diff --git a/bin/parallel_sync b/bin/parallel_sync index 150e9c31..cedca84d 100755 --- a/bin/parallel_sync +++ b/bin/parallel_sync @@ -146,7 +146,7 @@ def fetch_tasks( ) page, row = read_ctid(name=name) statement: sa.sql.Select = sa.select( - [ + *[ sa.literal_column("1").label("x"), sa.literal_column("1").label("y"), sa.column("ctid"), diff --git a/examples/airbnb/schema.py b/examples/airbnb/schema.py index 55c2e557..bb320e8d 100644 --- a/examples/airbnb/schema.py +++ b/examples/airbnb/schema.py @@ -2,45 +2,49 @@ import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.schema import UniqueConstraint from pgsync.base import create_database, pg_engine from pgsync.helper import teardown from pgsync.utils import config_loader, get_config -Base = declarative_base() + +class Base(DeclarativeBase): + pass class User(Base): __tablename__ = "user" __table_args__ = (UniqueConstraint("email"),) - id = sa.Column(sa.Integer, primary_key=True) - email = sa.Column(sa.String, unique=True, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + email: Mapped[str] = mapped_column(sa.String, unique=True, nullable=False) class Host(Base): __tablename__ = "host" __table_args__ = (UniqueConstraint("email"),) - id = sa.Column(sa.Integer, primary_key=True) - email = sa.Column(sa.String, unique=True, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + email: Mapped[str] = mapped_column(sa.String, unique=True, nullable=False) class Country(Base): __tablename__ = "country" __table_args__ = (UniqueConstraint("name", "country_code"),) - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String, nullable=False) - country_code = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + country_code: Mapped[str] = mapped_column(sa.String, nullable=False) class City(Base): __tablename__ = "city" __table_args__ = (UniqueConstraint("name", "country_id"),) - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String, nullable=False) - country_id = sa.Column(sa.Integer, sa.ForeignKey(Country.id)) - country = sa.orm.relationship( + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + country_id: Mapped[int] = mapped_column( + sa.Integer, sa.ForeignKey(Country.id) + ) + country: Mapped[Country] = sa.orm.relationship( Country, backref=sa.orm.backref("country"), ) @@ -49,15 +53,15 @@ class City(Base): class Place(Base): __tablename__ = "place" __table_args__ = (UniqueConstraint("host_id", "address", "city_id"),) - id = sa.Column(sa.Integer, primary_key=True) - host_id = sa.Column(sa.Integer, sa.ForeignKey(Host.id)) - address = sa.Column(sa.String, nullable=False) - city_id = sa.Column(sa.Integer, sa.ForeignKey(City.id)) - host = sa.orm.relationship( + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + host_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(Host.id)) + address: Mapped[str] = mapped_column(sa.String, nullable=False) + city_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(City.id)) + host: Mapped[Host] = sa.orm.relationship( Host, backref=sa.orm.backref("host"), ) - city = sa.orm.relationship( + city: Mapped[City] = sa.orm.relationship( City, backref=sa.orm.backref("city"), ) @@ -66,25 +70,33 @@ class Place(Base): class Booking(Base): __tablename__ = "booking" __table_args__ = (UniqueConstraint("user_id", "place_id", "start_date"),) - id = sa.Column(sa.Integer, primary_key=True) - user_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) - place_id = sa.Column(sa.Integer, sa.ForeignKey(Place.id)) - start_date = sa.Column(sa.DateTime, default=datetime.now()) - end_date = sa.Column(sa.DateTime, default=datetime.now()) - price_per_night = sa.Column(sa.Float, default=0) - num_nights = sa.Column(sa.Integer, nullable=False, default=1) - user = sa.orm.relationship(User) - place = sa.orm.relationship(Place) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + user_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(User.id)) + place_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(Place.id)) + start_date: Mapped[datetime] = mapped_column( + sa.DateTime, default=datetime.now() + ) + end_date: Mapped[datetime] = mapped_column( + sa.DateTime, default=datetime.now() + ) + price_per_night: Mapped[float] = mapped_column(sa.Float, default=0) + num_nights: Mapped[int] = mapped_column( + sa.Integer, nullable=False, default=1 + ) + user: Mapped[User] = sa.orm.relationship(User) + place: Mapped[Place] = sa.orm.relationship(Place) class Review(Base): __tablename__ = "review" __table_args__ = (UniqueConstraint("booking_id"),) - id = sa.Column(sa.Integer, primary_key=True) - booking_id = sa.Column(sa.Integer, sa.ForeignKey(Booking.id)) - rating = sa.Column(sa.SmallInteger, nullable=True) - review_body = sa.Column(sa.Text, nullable=True) - booking = sa.orm.relationship( + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + booking_id: Mapped[int] = mapped_column( + sa.Integer, sa.ForeignKey(Booking.id) + ) + rating: Mapped[int] = mapped_column(sa.SmallInteger, nullable=True) + review_body: Mapped[str] = mapped_column(sa.Text, nullable=True) + booking: Mapped[Booking] = sa.orm.relationship( Booking, backref=sa.orm.backref("booking"), ) diff --git a/examples/ancestry/schema.py b/examples/ancestry/schema.py index fd7a4d11..4d25f8fe 100644 --- a/examples/ancestry/schema.py +++ b/examples/ancestry/schema.py @@ -1,51 +1,69 @@ import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from pgsync.base import create_database, pg_engine from pgsync.helper import teardown from pgsync.utils import config_loader, get_config -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Parent(Base): __tablename__ = "parent" __table_args__ = () - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String) class Surrogate(Base): __tablename__ = "surrogate" __table_args__ = () - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String) - parent_id = sa.Column(sa.Integer, sa.ForeignKey(Parent.id)) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String) + parent_id: Mapped[int] = mapped_column( + sa.Integer, sa.ForeignKey(Parent.id) + ) class Child(Base): __tablename__ = "child" __table_args__ = () - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String) - parent_id = sa.Column(sa.Integer, sa.ForeignKey(Surrogate.id)) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String) + parent_id: Mapped[int] = mapped_column( + sa.Integer, sa.ForeignKey(Surrogate.id) + ) class GrandChild(Base): __tablename__ = "grand_child" __table_args__ = () - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String) - parent_id = sa.Column(sa.Integer, sa.ForeignKey(Child.id)) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String) + parent_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(Child.id)) class GreatGrandChild(Base): __tablename__ = "great_grand_child" __table_args__ = () - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String) - parent_id = sa.Column(sa.Integer, sa.ForeignKey(GrandChild.id)) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String) + parent_id: Mapped[int] = mapped_column( + sa.Integer, sa.ForeignKey(GrandChild.id) + ) def setup(config: str) -> None: diff --git a/examples/book/schema.py b/examples/book/schema.py index 31612b56..3b49c0d3 100644 --- a/examples/book/schema.py +++ b/examples/book/schema.py @@ -1,6 +1,8 @@ +from datetime import datetime + import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.schema import UniqueConstraint from pgsync.base import create_database, create_schema, pg_engine @@ -8,25 +10,31 @@ from pgsync.helper import teardown from pgsync.utils import config_loader, get_config -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Continent(Base): __tablename__ = "continent" __table_args__ = (UniqueConstraint("name"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) class Country(Base): __tablename__ = "country" __table_args__ = (UniqueConstraint("name", "continent_id"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String, nullable=False) - continent_id = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + continent_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Continent.id, ondelete="CASCADE") ) - continent = sa.orm.relationship( + continent: Mapped[Continent] = sa.orm.relationship( Continent, backref=sa.orm.backref("continents") ) @@ -34,13 +42,15 @@ class Country(Base): class City(Base): __tablename__ = "city" __table_args__ = (UniqueConstraint("name", "country_id"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String, nullable=False) - country_id = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + country_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Country.id, ondelete="CASCADE"), ) - country = sa.orm.relationship( + country: Mapped[Country] = sa.orm.relationship( Country, backref=sa.orm.backref("countries"), ) @@ -49,19 +59,25 @@ class City(Base): class Publisher(Base): __tablename__ = "publisher" __table_args__ = (UniqueConstraint("name"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String, nullable=False) - is_active = sa.Column(sa.Boolean, default=False) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + is_active: Mapped[bool] = mapped_column(sa.Boolean, default=False) class Author(Base): __tablename__ = "author" __table_args__ = (UniqueConstraint("name"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String, nullable=False) - date_of_birth = sa.Column(sa.DateTime, nullable=True) - city_id = sa.Column(sa.Integer, sa.ForeignKey(City.id, ondelete="CASCADE")) - city = sa.orm.relationship( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + date_of_birth: Mapped[datetime] = mapped_column(sa.DateTime, nullable=True) + city_id: Mapped[int] = mapped_column( + sa.Integer, sa.ForeignKey(City.id, ondelete="CASCADE") + ) + city: Mapped[City] = sa.orm.relationship( City, backref=sa.orm.backref("city"), ) @@ -70,70 +86,88 @@ class Author(Base): class Shelf(Base): __tablename__ = "shelf" __table_args__ = (UniqueConstraint("shelf"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - shelf = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + shelf: Mapped[str] = mapped_column(sa.String, nullable=False) class Subject(Base): __tablename__ = "subject" __table_args__ = (UniqueConstraint("name"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) class Language(Base): __tablename__ = "language" __table_args__ = (UniqueConstraint("code"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - code = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + code: Mapped[str] = mapped_column(sa.String, nullable=False) class Book(Base): __tablename__ = "book" __table_args__ = (UniqueConstraint("isbn"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - isbn = sa.Column(sa.String, nullable=False) - title = sa.Column(sa.String, nullable=False) - description = sa.Column(sa.String, nullable=True) - copyright = sa.Column(sa.String, nullable=True) - tags = sa.Column(sa.dialects.postgresql.JSONB, nullable=True) - doc = sa.Column(sa.dialects.postgresql.JSONB, nullable=True) - publisher_id = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + isbn: Mapped[str] = mapped_column(sa.String, nullable=False) + title: Mapped[str] = mapped_column(sa.String, nullable=False) + description: Mapped[str] = mapped_column(sa.String, nullable=True) + copyright: Mapped[str] = mapped_column(sa.String, nullable=True) + tags: Mapped[dict] = mapped_column( + sa.dialects.postgresql.JSONB, nullable=True + ) + doc: Mapped[dict] = mapped_column( + sa.dialects.postgresql.JSONB, nullable=True + ) + publisher_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Publisher.id, ondelete="CASCADE") ) - publisher = sa.orm.relationship( + publisher: Mapped[Publisher] = sa.orm.relationship( Publisher, backref=sa.orm.backref("publishers"), ) - publish_date = sa.Column(sa.DateTime, nullable=True) + publish_date: Mapped[datetime] = mapped_column(sa.DateTime, nullable=True) class Rating(Base): __tablename__ = "rating" __table_args__ = (UniqueConstraint("book_isbn"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - book_isbn = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + book_isbn: Mapped[str] = mapped_column( sa.String, sa.ForeignKey(Book.isbn, ondelete="CASCADE") ) - book = sa.orm.relationship(Book, backref=sa.orm.backref("ratings")) - value = sa.Column(sa.Float, nullable=True) + book: Mapped[Book] = sa.orm.relationship( + Book, backref=sa.orm.backref("ratings") + ) + value: Mapped[float] = mapped_column(sa.Float, nullable=True) class BookAuthor(Base): __tablename__ = "book_author" __table_args__ = (UniqueConstraint("book_isbn", "author_id"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - book_isbn = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + book_isbn: Mapped[str] = mapped_column( sa.String, sa.ForeignKey(Book.isbn, ondelete="CASCADE") ) - book = sa.orm.relationship( + book: Mapped[Book] = sa.orm.relationship( Book, backref=sa.orm.backref("book_author_books"), ) - author_id = sa.Column( + author_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Author.id, ondelete="CASCADE") ) - author = sa.orm.relationship( + author: Mapped[Author] = sa.orm.relationship( Author, backref=sa.orm.backref("authors"), ) @@ -142,18 +176,20 @@ class BookAuthor(Base): class BookSubject(Base): __tablename__ = "book_subject" __table_args__ = (UniqueConstraint("book_isbn", "subject_id"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - book_isbn = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + book_isbn: Mapped[str] = mapped_column( sa.String, sa.ForeignKey(Book.isbn, ondelete="CASCADE") ) - book = sa.orm.relationship( + book: Mapped[Book] = sa.orm.relationship( Book, backref=sa.orm.backref("book_subject_books"), ) - subject_id = sa.Column( + subject_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Subject.id, ondelete="CASCADE") ) - subject = sa.orm.relationship( + subject: Mapped[Subject] = sa.orm.relationship( Subject, backref=sa.orm.backref("subjects"), ) @@ -163,18 +199,20 @@ class BookLanguage(Base): __tablename__ = "book_language" __table_args__ = (UniqueConstraint("book_isbn", "language_id"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - book_isbn = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + book_isbn: Mapped[str] = mapped_column( sa.String, sa.ForeignKey(Book.isbn, ondelete="CASCADE") ) - book = sa.orm.relationship( + book: Mapped[Book] = sa.orm.relationship( Book, backref=sa.orm.backref("book_language_books"), ) - language_id = sa.Column( + language_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Language.id, ondelete="CASCADE") ) - language = sa.orm.relationship( + language: Mapped[Language] = sa.orm.relationship( Language, backref=sa.orm.backref("languages"), ) @@ -184,18 +222,22 @@ class BookShelf(Base): __tablename__ = "bookshelf" __table_args__ = (UniqueConstraint("book_isbn", "shelf_id"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - book_isbn = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + book_isbn: Mapped[str] = mapped_column( sa.String, sa.ForeignKey(Book.isbn, ondelete="CASCADE") ) - book = sa.orm.relationship( + book: Mapped[Book] = sa.orm.relationship( Book, backref=sa.orm.backref("book_bookshelf_books"), ) - shelf_id = sa.Column( + shelf_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Shelf.id, ondelete="CASCADE") ) - shelf = sa.orm.relationship(Shelf, backref=sa.orm.backref("shelves")) + shelf: Mapped[Shelf] = sa.orm.relationship( + Shelf, backref=sa.orm.backref("shelves") + ) def setup(config: str) -> None: @@ -205,9 +247,7 @@ def setup(config: str) -> None: create_database(database) create_schema(database, schema) with pg_engine(database) as engine: - engine = engine.connect().execution_options( - schema_translate_map={None: schema} - ) + Base.metadata.schema = schema Base.metadata.drop_all(engine) Base.metadata.create_all(engine) diff --git a/examples/book_view/schema.py b/examples/book_view/schema.py index ce68cead..bce00a8f 100644 --- a/examples/book_view/schema.py +++ b/examples/book_view/schema.py @@ -1,6 +1,6 @@ import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.schema import UniqueConstraint from pgsync.base import create_database, create_schema, pg_engine @@ -9,29 +9,35 @@ from pgsync.utils import config_loader, get_config from pgsync.view import CreateView -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Publisher(Base): __tablename__ = "publisher" __table_args__ = (UniqueConstraint("name"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String, nullable=False) - is_active = sa.Column(sa.Boolean, default=False) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + is_active: Mapped[bool] = mapped_column(sa.Boolean, default=False) class Book(Base): __tablename__ = "book" __table_args__ = (UniqueConstraint("isbn"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - isbn = sa.Column(sa.String, nullable=False) - title = sa.Column(sa.String, nullable=False) - description = sa.Column(sa.String, nullable=True) - copyright = sa.Column(sa.String, nullable=True) - publisher_id = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + isbn: Mapped[str] = mapped_column(sa.String, nullable=False) + title: Mapped[str] = mapped_column(sa.String, nullable=False) + description: Mapped[str] = mapped_column(sa.String, nullable=True) + copyright: Mapped[str] = mapped_column(sa.String, nullable=True) + publisher_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Publisher.id, ondelete="CASCADE") ) - publisher = sa.orm.relationship( + publisher: Mapped[Publisher] = sa.orm.relationship( Publisher, backref=sa.orm.backref("publishers"), ) @@ -44,9 +50,7 @@ def setup(config: str) -> None: create_database(database) create_schema(database, schema) with pg_engine(database) as engine: - engine = engine.connect().execution_options( - schema_translate_map={None: schema} - ) + Base.metadata.schema = schema Base.metadata.drop_all(engine) Base.metadata.create_all(engine) @@ -54,22 +58,28 @@ def setup(config: str) -> None: metadata.reflect(engine, views=True) book_model = metadata.tables[f"{schema}.book"] - engine.execute( - CreateView( - schema, - "book_view", - book_model.select(), + + with engine.connect() as conn: + conn.execute( + CreateView( + schema, + "book_view", + book_model.select(), + ) ) - ) + conn.commit() publisher_model = metadata.tables[f"{schema}.publisher"] - engine.execute( - CreateView( - schema, - "publisher_view", - publisher_model.select(), + + with engine.connect() as conn: + conn.execute( + CreateView( + schema, + "publisher_view", + publisher_model.select(), + ) ) - ) + conn.commit() @click.command() diff --git a/examples/node/README b/examples/node/README index 3d5ca7ed..77fff1a7 100644 --- a/examples/node/README +++ b/examples/node/README @@ -1,3 +1,3 @@ Demonstrates Adjacency List Relationships -- https://docs.sqlalchemy.org/en/14/orm/self_referential.html \ No newline at end of file +- https://docs.sqlalchemy.org/en/20/orm/self_referential.html \ No newline at end of file diff --git a/examples/node/schema.py b/examples/node/schema.py index 46e57f4a..d6662fb4 100644 --- a/examples/node/schema.py +++ b/examples/node/schema.py @@ -1,20 +1,26 @@ import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from pgsync.base import create_database, pg_engine from pgsync.helper import teardown from pgsync.utils import config_loader, get_config -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Node(Base): __tablename__ = "node" - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String) - node_id = sa.Column(sa.Integer, sa.ForeignKey("node.id")) - children = sa.orm.relationship("Node", lazy="joined", join_depth=2) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String) + node_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey("node.id")) + children: Mapped["Node"] = sa.orm.relationship( + "Node", lazy="joined", join_depth=2 + ) def setup(config: str) -> None: diff --git a/examples/quiz/schema.py b/examples/quiz/schema.py index 0a6b4f05..efdf43ee 100644 --- a/examples/quiz/schema.py +++ b/examples/quiz/schema.py @@ -1,21 +1,23 @@ import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.schema import ForeignKeyConstraint, UniqueConstraint from pgsync.base import create_database, pg_engine from pgsync.helper import teardown from pgsync.utils import config_loader, get_config -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Category(Base): __tablename__ = "category" __table_args__ = (UniqueConstraint("text"),) - id = sa.Column(sa.Integer, primary_key=True) - uid = sa.Column(sa.String, primary_key=True) - text = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + uid: Mapped[str] = mapped_column(sa.String, primary_key=True) + text: Mapped[str] = mapped_column(sa.String, nullable=False) class Question(Base): @@ -26,19 +28,19 @@ class Question(Base): ["category_id", "category_uid"], ["category.id", "category.uid"] ), ) - id = sa.Column(sa.Integer, primary_key=True) - uid = sa.Column(sa.String, primary_key=True) - category_id = sa.Column(sa.Integer) - category_uid = sa.Column(sa.String) - text = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + uid: Mapped[str] = mapped_column(sa.String, primary_key=True) + category_id: Mapped[int] = mapped_column(sa.Integer) + category_uid: Mapped[str] = mapped_column(sa.String) + text: Mapped[str] = mapped_column(sa.String, nullable=False) class Answer(Base): __tablename__ = "answer" __table_args__ = (UniqueConstraint("text"),) - id = sa.Column(sa.Integer, primary_key=True) - uid = sa.Column(sa.String, primary_key=True) - text = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + uid: Mapped[str] = mapped_column(sa.String, primary_key=True) + text: Mapped[str] = mapped_column(sa.String, nullable=False) class PossibleAnswer(Base): @@ -59,11 +61,13 @@ class PossibleAnswer(Base): ["question.id", "question.uid"], ), ) - question_id = sa.Column(sa.Integer, primary_key=True) - question_uid = sa.Column(sa.String, primary_key=True) - answer_id = sa.Column(sa.Integer, primary_key=True) - answer_uid = sa.Column(sa.String, primary_key=True) - answer = sa.orm.relationship(Answer, backref=sa.orm.backref("answer")) + question_id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + question_uid: Mapped[str] = mapped_column(sa.String, primary_key=True) + answer_id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + answer_uid: Mapped[str] = mapped_column(sa.String, primary_key=True) + answer: Mapped[Answer] = sa.orm.relationship( + Answer, backref=sa.orm.backref("answer") + ) class RealAnswer(Base): @@ -84,19 +88,19 @@ class RealAnswer(Base): ["question.id", "question.uid"], ), ) - question_id = sa.Column( + question_id: Mapped[int] = mapped_column( sa.Integer, primary_key=True, ) - question_uid = sa.Column( + question_uid: Mapped[str] = mapped_column( sa.String, primary_key=True, ) - answer_id = sa.Column( + answer_id: Mapped[int] = mapped_column( sa.Integer, primary_key=True, ) - answer_uid = sa.Column( + answer_uid: Mapped[str] = mapped_column( sa.String, primary_key=True, ) diff --git a/examples/schemas/schema.py b/examples/schemas/schema.py index 636e84c5..9731fdf3 100644 --- a/examples/schemas/schema.py +++ b/examples/schemas/schema.py @@ -1,27 +1,35 @@ import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from pgsync.base import create_database, create_schema, pg_engine from pgsync.helper import teardown from pgsync.utils import config_loader, get_config -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Parent(Base): __tablename__ = "parent" __table_args__ = {"schema": "parent"} - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String) class Child(Base): __tablename__ = "child" __table_args__ = {"schema": "child"} - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String) - parent_id = sa.Column(sa.Integer, sa.ForeignKey(Parent.id)) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String) + parent_id: Mapped[int] = mapped_column( + sa.Integer, sa.ForeignKey(Parent.id) + ) def setup(config: str) -> None: @@ -32,16 +40,12 @@ def setup(config: str) -> None: create_schema(database, schema) with pg_engine(database) as engine: - engine: sa.engine.Engine = engine.connect().execution_options( - schema_translate_map={None: "parent"} - ) + Base.metadata.schema = "parent" Base.metadata.drop_all(engine) Base.metadata.create_all(engine) with pg_engine(database) as engine: - engine: sa.engine.Engine = engine.connect().execution_options( - schema_translate_map={None: "child"} - ) + Base.metadata.schema = "child" Base.metadata.drop_all(engine) Base.metadata.create_all(engine) diff --git a/examples/social/schema.py b/examples/social/schema.py index 89c2118f..ca26db59 100644 --- a/examples/social/schema.py +++ b/examples/social/schema.py @@ -1,58 +1,60 @@ import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.schema import UniqueConstraint from pgsync.base import create_database, pg_engine from pgsync.helper import teardown from pgsync.utils import config_loader, get_config -Base = declarative_base() + +class Base(DeclarativeBase): + pass class User(Base): __tablename__ = "user" __table_args__ = (UniqueConstraint("name"),) - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String, nullable=False) - age = sa.Column(sa.Integer, nullable=True) - gender = sa.Column(sa.String, nullable=True) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + age: Mapped[int] = mapped_column(sa.Integer, nullable=True) + gender: Mapped[str] = mapped_column(sa.String, nullable=True) class Post(Base): __tablename__ = "post" __table_args__ = () - id = sa.Column(sa.Integer, primary_key=True) - title = sa.Column(sa.String, nullable=False) - slug = sa.Column(sa.String, nullable=True) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + title: Mapped[str] = mapped_column(sa.String, nullable=False) + slug: Mapped[str] = mapped_column(sa.String, nullable=True) class Comment(Base): __tablename__ = "comment" __table_args__ = () - id = sa.Column(sa.Integer, primary_key=True) - title = sa.Column(sa.String, nullable=True) - content = sa.Column(sa.String, nullable=True) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + title: Mapped[str] = mapped_column(sa.String, nullable=True) + content: Mapped[str] = mapped_column(sa.String, nullable=True) class Tag(Base): __tablename__ = "tag" __table_args__ = (UniqueConstraint("name"),) - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + name: Mapped[str] = mapped_column(sa.String, nullable=False) class UserPost(Base): __tablename__ = "user_post" __table_args__ = () - id = sa.Column(sa.Integer, primary_key=True) - user_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) - user = sa.orm.relationship( + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + user_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(User.id)) + user: Mapped[User] = sa.orm.relationship( User, backref=sa.orm.backref("users"), ) - post_id = sa.Column(sa.Integer, sa.ForeignKey(Post.id)) - post = sa.orm.relationship( + post_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(Post.id)) + post: Mapped[Post] = sa.orm.relationship( Post, backref=sa.orm.backref("posts"), ) @@ -61,27 +63,31 @@ class UserPost(Base): class PostComment(Base): __tablename__ = "post_comment" __table_args__ = (UniqueConstraint("post_id", "comment_id"),) - id = sa.Column(sa.Integer, primary_key=True) - post_id = sa.Column(sa.Integer, sa.ForeignKey(Post.id)) - post = sa.orm.relationship( + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + post_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(Post.id)) + post: Mapped[Post] = sa.orm.relationship( Post, backref=sa.orm.backref("post"), ) - comment_id = sa.Column(sa.Integer, sa.ForeignKey(Comment.id)) - comment = sa.orm.relationship(Comment, backref=sa.orm.backref("comments")) + comment_id: Mapped[int] = mapped_column( + sa.Integer, sa.ForeignKey(Comment.id) + ) + comment: Mapped[Comment] = sa.orm.relationship( + Comment, backref=sa.orm.backref("comments") + ) class UserTag(Base): __tablename__ = "user_tag" __table_args__ = (UniqueConstraint("user_id", "tag_id"),) - id = sa.Column(sa.Integer, primary_key=True) - user_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) - user = sa.orm.relationship( + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + user_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(User.id)) + user: Mapped[User] = sa.orm.relationship( User, backref=sa.orm.backref("user"), ) - tag_id = sa.Column(sa.Integer, sa.ForeignKey(Tag.id)) - tag = sa.orm.relationship( + tag_id: Mapped[int] = mapped_column(sa.Integer, sa.ForeignKey(Tag.id)) + tag: Mapped[Tag] = sa.orm.relationship( Tag, backref=sa.orm.backref("tags"), ) diff --git a/examples/starcraft/schema.py b/examples/starcraft/schema.py index ed855f89..3f9972af 100644 --- a/examples/starcraft/schema.py +++ b/examples/starcraft/schema.py @@ -1,13 +1,16 @@ import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.schema import UniqueConstraint from pgsync.base import create_database, pg_engine from pgsync.helper import teardown from pgsync.utils import config_loader, get_config -Base = declarative_base() + +class Base(DeclarativeBase): + pass + # sourced from https://starcraft.fandom.com/wiki/List_of_StarCraft_II_units @@ -15,8 +18,8 @@ class Specie(Base): __tablename__ = "specie" __table_args__ = (UniqueConstraint("name"),) - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + name: Mapped[str] = mapped_column(sa.String, nullable=False) class Unit(Base): @@ -26,10 +29,10 @@ class Unit(Base): "name", ), ) - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String, nullable=False) - details = sa.Column(sa.String, nullable=True) - specie_id = sa.Column(sa.Integer, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + details: Mapped[str] = mapped_column(sa.String, nullable=True) + specie_id: Mapped[int] = mapped_column(sa.Integer, nullable=False) class Structure(Base): @@ -39,10 +42,10 @@ class Structure(Base): "name", ), ) - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.String, nullable=False) - details = sa.Column(sa.String, nullable=True) - specie_id = sa.Column(sa.Integer, nullable=False) + id: Mapped[int] = mapped_column(sa.Integer, primary_key=True) + name: Mapped[str] = mapped_column(sa.String, nullable=False) + details: Mapped[str] = mapped_column(sa.String, nullable=True) + specie_id: Mapped[int] = mapped_column(sa.Integer, nullable=False) def setup(config: str) -> None: diff --git a/examples/through/schema.py b/examples/through/schema.py index 7338555e..098e21c4 100644 --- a/examples/through/schema.py +++ b/examples/through/schema.py @@ -1,6 +1,6 @@ import click import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column from sqlalchemy.schema import UniqueConstraint from pgsync.base import create_database, create_schema, pg_engine @@ -8,14 +8,18 @@ from pgsync.helper import teardown from pgsync.utils import config_loader, get_config -Base = declarative_base() + +class Base(DeclarativeBase): + pass class Customer(Base): __tablename__ = "customer" __table_args__ = (UniqueConstraint("name"),) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + name: Mapped[str] = mapped_column(sa.String, nullable=False) class Group(Base): @@ -25,8 +29,10 @@ class Group(Base): "group_name", ), ) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - group_name = sa.Column(sa.String, nullable=False) + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + group_name: Mapped[str] = mapped_column(sa.String, nullable=False) class CustomerGroup(Base): @@ -37,20 +43,22 @@ class CustomerGroup(Base): "group_id", ), ) - id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - customer_id = sa.Column( + id: Mapped[int] = mapped_column( + sa.Integer, primary_key=True, autoincrement=True + ) + customer_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Customer.id, ondelete="CASCADE"), ) - customer = sa.orm.relationship( + customer: Mapped[Customer] = sa.orm.relationship( Customer, backref=sa.orm.backref("customers"), ) - group_id = sa.Column( + group_id: Mapped[int] = mapped_column( sa.Integer, sa.ForeignKey(Group.id, ondelete="CASCADE"), ) - group = sa.orm.relationship( + group: Mapped[Group] = sa.orm.relationship( Group, backref=sa.orm.backref("groups"), ) @@ -63,9 +71,7 @@ def setup(config: str) -> None: create_database(database) create_schema(database, schema) with pg_engine(database) as engine: - engine = engine.connect().execution_options( - schema_translate_map={None: schema} - ) + Base.metadata.schema = schema Base.metadata.drop_all(engine) Base.metadata.create_all(engine) diff --git a/pgsync/base.py b/pgsync/base.py index e51f92e1..f65ab9ed 100644 --- a/pgsync/base.py +++ b/pgsync/base.py @@ -172,7 +172,9 @@ def connect(self) -> None: def pg_settings(self, column: str) -> Optional[str]: try: return self.fetchone( - sa.select([sa.column("setting")]) + sa.select( + sa.text("setting"), + ) .select_from(sa.text("pg_settings")) .where(sa.column("name") == column), label="pg_settings", @@ -188,12 +190,13 @@ def _can_create_replication_slot(self, slot_name: str) -> None: try: self.create_replication_slot(slot_name) + except Exception as e: logger.exception(f"{e}") raise ReplicationSlotError( f'PG_USER "{self.engine.url.username}" needs to be ' f"superuser or have permission to read, create and destroy " - f"replication slots to perform this action." + f"replication slots to perform this action.\n{e}" ) else: self.drop_replication_slot(slot_name) @@ -278,6 +281,8 @@ def _views(self, schema: str) -> list: if schema not in self.__views: self.__views[schema] = [] for table in sa.inspect(self.engine).get_view_names(schema): + # TODO: figure out why we need is_view here when sqlalchemy + # already reflects views if is_view(self.engine, schema, table, materialized=False): self.__views[schema].append(table) return self.__views[schema] @@ -286,7 +291,11 @@ def _materialized_views(self, schema: str) -> list: """Get all materialized views.""" if schema not in self.__materialized_views: self.__materialized_views[schema] = [] - for table in sa.inspect(self.engine).get_view_names(schema): + for table in sa.inspect(self.engine).get_materialized_view_names( + schema + ): + # TODO: figure out why we need is_view here when sqlalchemy + # already reflects views if is_view(self.engine, schema, table, materialized=True): self.__materialized_views[schema].append(table) return self.__materialized_views[schema] @@ -330,7 +339,7 @@ def truncate_table(self, table: str, schema: str = DEFAULT_SCHEMA) -> None: """ logger.debug(f"Truncating table: {schema}.{table}") - self.execute(sa.DDL(f'TRUNCATE TABLE "{schema}"."{table}" CASCADE')) + self.execute(sa.text(f'TRUNCATE TABLE "{schema}"."{table}" CASCADE')) def truncate_tables( self, tables: List[str], schema: str = DEFAULT_SCHEMA @@ -362,7 +371,7 @@ def replication_slots( SELECT * FROM PG_REPLICATION_SLOTS """ return self.fetchall( - sa.select(["*"]) + sa.select("*") .select_from(sa.text("PG_REPLICATION_SLOTS")) .where( sa.and_( @@ -386,26 +395,28 @@ def create_replication_slot(self, slot_name: str) -> None: SELECT * FROM PG_REPLICATION_SLOTS """ logger.debug(f"Creating replication slot: {slot_name}") - return self.fetchone( - sa.select(["*"]).select_from( - sa.func.PG_CREATE_LOGICAL_REPLICATION_SLOT( - slot_name, - PLUGIN, + try: + self.execute( + sa.select("*").select_from( + sa.func.PG_CREATE_LOGICAL_REPLICATION_SLOT( + slot_name, + PLUGIN, + ) ) - ), - label="create_replication_slot", - ) + ) + except Exception as e: + logger.exception(f"{e}") + raise def drop_replication_slot(self, slot_name: str) -> None: """Drop a replication slot.""" logger.debug(f"Dropping replication slot: {slot_name}") if self.replication_slots(slot_name): try: - return self.fetchone( - sa.select(["*"]).select_from( + self.execute( + sa.select("*").select_from( sa.func.PG_DROP_REPLICATION_SLOT(slot_name), - ), - label="drop_replication_slot", + ) ) except Exception as e: logger.exception(f"{e}") @@ -440,7 +451,8 @@ def _logical_slot_changes( """ filters: list = [] statement: sa.sql.Select = sa.select( - [sa.column("xid"), sa.column("data")] + sa.text("xid"), + sa.text("data"), ).select_from( func( slot_name, @@ -546,7 +558,7 @@ def logical_slot_count_changes( ) with self.engine.connect() as conn: return conn.execute( - statement.with_only_columns([sa.func.COUNT()]) + statement.with_only_columns(*[sa.func.COUNT()]) ).scalar() # Views... @@ -571,7 +583,8 @@ def create_view( def drop_view(self, schema: str) -> None: """Drop a view.""" logger.debug(f"Dropping view: {schema}.{MATERIALIZED_VIEW}") - self.engine.execute(DropView(schema, MATERIALIZED_VIEW)) + with self.engine.connect() as conn: + conn.execute(DropView(schema, MATERIALIZED_VIEW)) logger.debug(f"Dropped view: {schema}.{MATERIALIZED_VIEW}") def refresh_view( @@ -579,9 +592,8 @@ def refresh_view( ) -> None: """Refresh a materialized view.""" logger.debug(f"Refreshing view: {schema}.{name}") - self.engine.execute( - RefreshView(schema, name, concurrently=concurrently) - ) + with self.engine.connect() as conn: + conn.execute(RefreshView(schema, name, concurrently=concurrently)) logger.debug(f"Refreshed view: {schema}.{name}") # Triggers... @@ -612,10 +624,10 @@ def create_triggers( ) if join_queries: if queries: - self.execute(sa.DDL("; ".join(queries))) + self.execute(sa.text("; ".join(queries))) else: for query in queries: - self.execute(sa.DDL(query)) + self.execute(sa.text(query)) def drop_triggers( self, @@ -636,25 +648,27 @@ def drop_triggers( ) if join_queries: if queries: - self.execute(sa.DDL("; ".join(queries))) + self.execute(sa.text("; ".join(queries))) else: for query in queries: - self.execute(sa.DDL(query)) + self.execute(sa.text(query)) def create_function(self, schema: str) -> None: self.execute( - CREATE_TRIGGER_TEMPLATE.replace( - MATERIALIZED_VIEW, - f"{schema}.{MATERIALIZED_VIEW}", - ).replace( - TRIGGER_FUNC, - f"{schema}.{TRIGGER_FUNC}", + sa.text( + CREATE_TRIGGER_TEMPLATE.replace( + MATERIALIZED_VIEW, + f"{schema}.{MATERIALIZED_VIEW}", + ).replace( + TRIGGER_FUNC, + f"{schema}.{TRIGGER_FUNC}", + ) ) ) def drop_function(self, schema: str) -> None: self.execute( - sa.DDL( + sa.text( f'DROP FUNCTION IF EXISTS "{schema}".{TRIGGER_FUNC}() CASCADE' ) ) @@ -665,7 +679,7 @@ def disable_triggers(self, schema: str) -> None: logger.debug(f"Disabling trigger on table: {schema}.{table}") for name in ("notify", "truncate"): self.execute( - sa.DDL( + sa.text( f'ALTER TABLE "{schema}"."{table}" ' f"DISABLE TRIGGER {table}_{name}" ) @@ -677,7 +691,7 @@ def enable_triggers(self, schema: str) -> None: logger.debug(f"Enabling trigger on table: {schema}.{table}") for name in ("notify", "truncate"): self.execute( - sa.DDL( + sa.text( f'ALTER TABLE "{schema}"."{table}" ' f"ENABLE TRIGGER {table}_{name}" ) @@ -691,7 +705,7 @@ def txid_current(self) -> int: SELECT txid_current() """ return self.fetchone( - sa.select(["*"]).select_from(sa.func.TXID_CURRENT()), + sa.select("*").select_from(sa.func.TXID_CURRENT()), label="txid_current", )[0] @@ -827,14 +841,8 @@ def fetchone( if self.verbose: compiled_query(statement, label=label, literal_binds=literal_binds) - conn = self.engine.connect() - try: - row = conn.execute(statement).fetchone() - conn.close() - except Exception as e: - logger.exception(f"Exception {e}") - raise - return row + with self.engine.connect() as conn: + return conn.execute(statement).fetchone() def fetchall( self, @@ -846,14 +854,8 @@ def fetchall( if self.verbose: compiled_query(statement, label=label, literal_binds=literal_binds) - conn = self.engine.connect() - try: - rows = conn.execute(statement).fetchall() - conn.close() - except Exception as e: - logger.exception(f"Exception {e}") - raise - return rows + with self.engine.connect() as conn: + return conn.execute(statement).fetchall() def fetchmany( self, @@ -877,7 +879,7 @@ def fetchcount(self, statement: sa.sql.Subquery) -> int: with self.engine.connect() as conn: return conn.execute( statement.original.with_only_columns( - [sa.func.COUNT()] + *[sa.func.COUNT()] ).order_by(None) ).scalar() @@ -1018,23 +1020,18 @@ def pg_execute( values: Optional[list] = None, options: Optional[dict] = None, ) -> None: - options = options or {"isolation_level": "AUTOCOMMIT"} - conn = engine.connect() - try: + with engine.connect() as conn: if options: conn = conn.execution_options(**options) conn.execute(statement, values) - conn.close() - except Exception as e: - logger.exception(f"Exception {e}") - raise + conn.commit() def create_schema(database: str, schema: str, echo: bool = False) -> None: """Create database schema.""" logger.debug(f"Creating schema: {schema}") with pg_engine(database, echo=echo) as engine: - pg_execute(engine, sa.DDL(f"CREATE SCHEMA IF NOT EXISTS {schema}")) + pg_execute(engine, sa.text(f"CREATE SCHEMA IF NOT EXISTS {schema}")) logger.debug(f"Created schema: {schema}") @@ -1042,7 +1039,11 @@ def create_database(database: str, echo: bool = False) -> None: """Create a database.""" logger.debug(f"Creating database: {database}") with pg_engine("postgres", echo=echo) as engine: - pg_execute(engine, sa.DDL(f'CREATE DATABASE "{database}"')) + pg_execute( + engine, + sa.text(f'CREATE DATABASE "{database}"'), + options={"isolation_level": "AUTOCOMMIT"}, + ) logger.debug(f"Created database: {database}") @@ -1050,24 +1051,26 @@ def drop_database(database: str, echo: bool = False) -> None: """Drop a database.""" logger.debug(f"Dropping database: {database}") with pg_engine("postgres", echo=echo) as engine: - pg_execute(engine, sa.DDL(f'DROP DATABASE IF EXISTS "{database}"')) + pg_execute( + engine, + sa.text(f'DROP DATABASE IF EXISTS "{database}"'), + options={"isolation_level": "AUTOCOMMIT"}, + ) + logger.debug(f"Dropped database: {database}") def database_exists(database: str, echo: bool = False) -> bool: """Check if database is present.""" with pg_engine("postgres", echo=echo) as engine: - conn = engine.connect() - try: + with engine.connect() as conn: row = conn.execute( - sa.DDL( - f"SELECT 1 FROM pg_database WHERE datname = '{database}'" + sa.select( + sa.text("1"), ) - ).first() - conn.close() - except Exception as e: - logger.exception(f"Exception {e}") - raise + .select_from(sa.text("pg_database")) + .where(sa.column("datname") == database), + ).fetchone() return row is not None @@ -1079,7 +1082,7 @@ def create_extension( with pg_engine(database, echo=echo) as engine: pg_execute( engine, - sa.DDL(f'CREATE EXTENSION IF NOT EXISTS "{extension}"'), + sa.text(f'CREATE EXTENSION IF NOT EXISTS "{extension}"'), ) logger.debug(f"Created extension: {extension}") @@ -1088,5 +1091,5 @@ def drop_extension(database: str, extension: str, echo: bool = False) -> None: """Drop a database extension.""" logger.debug(f"Dropping extension: {extension}") with pg_engine(database, echo=echo) as engine: - pg_execute(engine, sa.DDL(f'DROP EXTENSION IF EXISTS "{extension}"')) + pg_execute(engine, sa.text(f'DROP EXTENSION IF EXISTS "{extension}"')) logger.debug(f"Dropped extension: {extension}") diff --git a/pgsync/querybuilder.py b/pgsync/querybuilder.py index 07d920d3..3df032ee 100644 --- a/pgsync/querybuilder.py +++ b/pgsync/querybuilder.py @@ -263,7 +263,7 @@ def _root( self._json_build_object(node.columns), *node.primary_keys, ] - node._subquery = sa.select(columns) + node._subquery = sa.select(*columns) if self.from_obj is not None: node._subquery = node._subquery.select_from(self.from_obj) @@ -273,7 +273,7 @@ def _root( for page, rows in ctid.items(): subquery.append( sa.select( - [ + *[ sa.cast( sa.literal_column(f"'({page},'") .concat(sa.column("s")) @@ -585,7 +585,7 @@ def _through(self, node: Node) -> None: # noqa: C901 isouter=self.isouter, ) - outer_subquery = sa.select(columns) + outer_subquery = sa.select(*columns) parent_foreign_key_columns: list = self._get_column_foreign_keys( through.columns, @@ -654,7 +654,7 @@ def _through(self, node: Node) -> None: # noqa: C901 for column in foreign_keys[through.name]: columns.append(through.model.c[str(column)]) - inner_subquery = sa.select(columns) + inner_subquery = sa.select(*columns) if self.verbose: compiled_query(inner_subquery, "Inner subquery") @@ -823,7 +823,7 @@ def _non_through(self, node: Node) -> None: # noqa: C901 for column in foreign_key_columns: columns.append(node.model.c[column]) - node._subquery = sa.select(columns) + node._subquery = sa.select(*columns) if from_obj is not None: node._subquery = node._subquery.select_from(from_obj) diff --git a/pgsync/view.py b/pgsync/view.py index 3b2a9655..bc5c552a 100644 --- a/pgsync/view.py +++ b/pgsync/view.py @@ -254,7 +254,7 @@ def _get_constraints( key_column_usage = models("key_column_usage", "information_schema") return ( sa.select( - [ + *[ table_constraints.c.table_name, sa.func.ARRAY_AGG( sa.cast( @@ -383,7 +383,7 @@ def create_view( rows: dict = {} if MATERIALIZED_VIEW in views: for table_name, primary_keys, foreign_keys, indices in fetchall( - sa.select(["*"]).select_from( + sa.select("*").select_from( sa.text(f"{schema}.{MATERIALIZED_VIEW}") ) ): @@ -401,8 +401,9 @@ def create_view( rows[table_name]["foreign_keys"] = set(foreign_keys) if indices: rows[table_name]["indices"] = set(indices) - - engine.execute(DropView(schema, MATERIALIZED_VIEW)) + with engine.connect() as conn: + conn.execute(DropView(schema, MATERIALIZED_VIEW)) + conn.commit() if schema != DEFAULT_SCHEMA: for table in set(tables): @@ -473,16 +474,18 @@ def create_view( .alias("t") ) logger.debug(f"Creating view: {schema}.{MATERIALIZED_VIEW}") - engine.execute(CreateView(schema, MATERIALIZED_VIEW, statement)) - engine.execute(DropIndex("_idx")) - engine.execute( - CreateIndex( - "_idx", - schema, - MATERIALIZED_VIEW, - ["table_name"], + with engine.connect() as conn: + conn.execute(CreateView(schema, MATERIALIZED_VIEW, statement)) + conn.execute(DropIndex("_idx")) + conn.execute( + CreateIndex( + "_idx", + schema, + MATERIALIZED_VIEW, + ["table_name"], + ) ) - ) + conn.commit() logger.debug(f"Created view: {schema}.{MATERIALIZED_VIEW}") @@ -509,7 +512,7 @@ def is_view( with engine.connect() as conn: return ( conn.execute( - sa.select([sa.column(column)]) + sa.select(*[sa.column(column)]) .select_from(sa.text(pg_table)) .where( sa.and_( @@ -519,7 +522,7 @@ def is_view( ] ) ) - .with_only_columns([sa.func.COUNT()]) + .with_only_columns(*[sa.func.COUNT()]) .order_by(None) ).scalar() > 0 diff --git a/plugins/sample.py b/plugins/sample.py index 311e920c..b839e587 100644 --- a/plugins/sample.py +++ b/plugins/sample.py @@ -55,6 +55,7 @@ def transform(self, doc: dict, **kwargs) -> dict: if doc_id == "x": # do something... pass + if doc_index == "myindex": # do another thing... pass diff --git a/requirements/base.in b/requirements/base.in index e440725b..4fef4e57 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -1,18 +1,15 @@ # base libraries -black + boto3 click elasticsearch elasticsearch-dsl environs -faker isort opensearch-dsl psycopg2-binary python-dotenv redis requests-aws4auth +sqlalchemy sqlparse -# Pinned dependencies -# pin sqlalchemy to latest 1.4.* until 2.0 support -sqlalchemy==1.4.* diff --git a/requirements/base.txt b/requirements/base.txt index ff4a197d..6ddbf9fd 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -6,11 +6,9 @@ # async-timeout==4.0.3 # via redis -black==23.11.0 +boto3==1.33.8 # via -r requirements/base.in -boto3==1.29.6 - # via -r requirements/base.in -botocore==1.32.6 +botocore==1.33.8 # via # boto3 # s3transfer @@ -22,9 +20,7 @@ certifi==2023.11.17 charset-normalizer==3.3.2 # via requests click==8.1.7 - # via - # -r requirements/base.in - # black + # via -r requirements/base.in elastic-transport==8.10.0 # via elasticsearch elasticsearch==8.11.0 @@ -35,11 +31,9 @@ elasticsearch-dsl==8.11.0 # via -r requirements/base.in environs==9.5.0 # via -r requirements/base.in -faker==20.1.0 - # via -r requirements/base.in greenlet==3.0.1 # via sqlalchemy -idna==3.5 +idna==3.6 # via requests isort==5.12.0 # via -r requirements/base.in @@ -49,27 +43,18 @@ jmespath==1.0.1 # botocore marshmallow==3.20.1 # via environs -mypy-extensions==1.0.0 - # via black opensearch-dsl==2.1.0 # via -r requirements/base.in opensearch-py==2.4.2 # via opensearch-dsl packaging==23.2 - # via - # black - # marshmallow -pathspec==0.11.2 - # via black -platformdirs==4.0.0 - # via black + # via marshmallow psycopg2-binary==2.9.9 # via -r requirements/base.in python-dateutil==2.8.2 # via # botocore # elasticsearch-dsl - # faker # opensearch-dsl # opensearch-py python-dotenv==1.0.0 @@ -84,7 +69,7 @@ requests==2.31.0 # requests-aws4auth requests-aws4auth==1.2.3 # via -r requirements/base.in -s3transfer==0.7.0 +s3transfer==0.8.2 # via boto3 six==1.16.0 # via @@ -92,14 +77,12 @@ six==1.16.0 # opensearch-py # python-dateutil # requests-aws4auth -sqlalchemy==1.4.50 +sqlalchemy==2.0.23 # via -r requirements/base.in sqlparse==0.4.4 # via -r requirements/base.in -tomli==2.0.1 - # via black typing-extensions==4.8.0 - # via black + # via sqlalchemy urllib3==1.26.18 # via # botocore diff --git a/requirements/dev.in b/requirements/dev.in index aa8d8d10..a057e72b 100644 --- a/requirements/dev.in +++ b/requirements/dev.in @@ -1,6 +1,8 @@ -r base.in +black coverage +faker flake8 freezegun mock diff --git a/requirements/dev.txt b/requirements/dev.txt index c34f6bff..48a76647 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -7,10 +7,10 @@ async-timeout==4.0.3 # via redis black==23.11.0 + # via -r requirements/dev.in +boto3==1.33.8 # via -r requirements/base.in -boto3==1.29.6 - # via -r requirements/base.in -botocore==1.32.6 +botocore==1.33.8 # via # boto3 # s3transfer @@ -46,18 +46,18 @@ environs==9.5.0 exceptiongroup==1.2.0 # via pytest faker==20.1.0 - # via -r requirements/base.in + # via -r requirements/dev.in filelock==3.13.1 # via virtualenv flake8==6.1.0 # via -r requirements/dev.in -freezegun==1.2.2 +freezegun==1.3.1 # via -r requirements/dev.in greenlet==3.0.1 # via sqlalchemy identify==2.5.32 # via pre-commit -idna==3.5 +idna==3.6 # via requests iniconfig==2.0.0 # via pytest @@ -88,7 +88,7 @@ packaging==23.2 # pytest pathspec==0.11.2 # via black -platformdirs==4.0.0 +platformdirs==4.1.0 # via # black # virtualenv @@ -133,7 +133,7 @@ requests==2.31.0 # requests-aws4auth requests-aws4auth==1.2.3 # via -r requirements/base.in -s3transfer==0.7.0 +s3transfer==0.8.2 # via boto3 six==1.16.0 # via @@ -141,7 +141,7 @@ six==1.16.0 # opensearch-py # python-dateutil # requests-aws4auth -sqlalchemy==1.4.50 +sqlalchemy==2.0.23 # via -r requirements/base.in sqlparse==0.4.4 # via -r requirements/base.in @@ -151,14 +151,16 @@ tomli==2.0.1 # coverage # pytest typing-extensions==4.8.0 - # via black + # via + # black + # sqlalchemy urllib3==1.26.18 # via # botocore # elastic-transport # opensearch-py # requests -virtualenv==20.24.7 +virtualenv==20.25.0 # via pre-commit # The following packages are considered to be unsafe in a requirements file: diff --git a/tests/conftest.py b/tests/conftest.py index 3e96b268..ea019f22 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,8 +4,7 @@ import pytest import sqlalchemy as sa -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import DeclarativeBase, sessionmaker from sqlalchemy.schema import UniqueConstraint from pgsync.base import Base, create_database, drop_database @@ -19,7 +18,10 @@ @pytest.fixture(scope="session") def base(): - return declarative_base() + class Base(DeclarativeBase): + pass + + return Base @pytest.fixture(scope="session") @@ -431,7 +433,9 @@ def model_mapping( @pytest.fixture(scope="session") def table_creator(base, connection, model_mapping): sa.orm.configure_mappers() - base.metadata.create_all(connection) + with connection.engine.connect() as conn: + base.metadata.create_all(connection.engine) + conn.commit() pg_base = Base(connection.engine.url.database) pg_base.create_triggers( connection.engine.url.database, @@ -441,7 +445,9 @@ def table_creator(base, connection, model_mapping): pg_base.create_replication_slot(f"{connection.engine.url.database}_testdb") yield pg_base.drop_replication_slot(f"{connection.engine.url.database}_testdb") - base.metadata.drop_all(connection) + with connection.engine.connect() as conn: + base.metadata.drop_all(connection.engine) + conn.commit() try: os.unlink(f".{connection.engine.url.database}_testdb") except (OSError, FileNotFoundError): diff --git a/tests/test_base.py b/tests/test_base.py index 582ac1b0..ad294055 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -2,7 +2,7 @@ import pytest import sqlalchemy as sa -from mock import ANY, call, patch +from mock import call, patch from pgsync.base import ( _pg_engine, @@ -152,9 +152,9 @@ def test_columns(self, connection): @patch("pgsync.base.logger") @patch("pgsync.sync.Base.execute") - @patch("pgsync.base.sa.DDL") + @patch("pgsync.base.sa.text") def test_truncate_table( - self, mock_ddl, mock_execute, mock_logger, connection + self, mock_text, mock_execute, mock_logger, connection ): pg_base = Base(connection.engine.url.database) pg_base.truncate_table("book") @@ -162,7 +162,7 @@ def test_truncate_table( "Truncating table: public.book" ) mock_execute.assert_called_once() - mock_ddl.assert_called_once_with( + mock_text.assert_called_once_with( 'TRUNCATE TABLE "public"."book" CASCADE' ) @@ -232,9 +232,7 @@ def test_replication_slots(self, connection): @patch("pgsync.base.logger") def test_create_replication_slot(self, mock_logger, connection): pg_base = Base(connection.engine.url.database) - row = pg_base.create_replication_slot("slot_name") - assert row[0] == "slot_name" - assert row[1] is not None + pg_base.create_replication_slot("slot_name") pg_base.drop_replication_slot("slot_name") calls = [ call("Creating replication slot: slot_name"), @@ -374,28 +372,26 @@ def test_drop_extension( mock_pg_execute.call_args_list == calls @patch("pgsync.base.logger") - @patch("pgsync.sync.Base.engine") - def test_drop_view(self, mock_engine, mock_logger, connection): + def test_drop_view(self, mock_logger, connection): pg_base = Base(connection.engine.url.database) - pg_base.drop_view("public") - calls = [ - call("Dropping view: public._view"), - call("Dropped view: public._view"), - ] - assert mock_logger.debug.call_args_list == calls - mock_engine.execute.assert_called_once_with(ANY) + with patch("pgsync.sync.Base.engine"): + pg_base.drop_view("public") + calls = [ + call("Dropping view: public._view"), + call("Dropped view: public._view"), + ] + assert mock_logger.debug.call_args_list == calls @patch("pgsync.base.logger") - @patch("pgsync.sync.Base.engine") - def test_refresh_view(self, mock_engine, mock_logger, connection): + def test_refresh_view(self, mock_logger, connection): pg_base = Base(connection.engine.url.database) - pg_base.refresh_view("foo", "public", concurrently=True) - calls = [ - call("Refreshing view: public.foo"), - call("Refreshed view: public.foo"), - ] - assert mock_logger.debug.call_args_list == calls - mock_engine.execute.assert_called_once_with(ANY) + with patch("pgsync.sync.Base.engine"): + pg_base.refresh_view("foo", "public", concurrently=True) + calls = [ + call("Refreshing view: public.foo"), + call("Refreshed view: public.foo"), + ] + assert mock_logger.debug.call_args_list == calls def test_parse_value(self, connection): pg_base = Base(connection.engine.url.database) @@ -457,10 +453,11 @@ def test_parse_logical_slot( def test_fetchone(self, connection): pg_base = Base(connection.engine.url.database, verbose=True) with patch("pgsync.base.compiled_query") as mock_compiled_query: - row = pg_base.fetchone("SELECT 1", label="foo", literal_binds=True) + statement = sa.text("SELECT 1") + row = pg_base.fetchone(statement, label="foo", literal_binds=True) assert row == (1,) mock_compiled_query.assert_called_once_with( - "SELECT 1", label="foo", literal_binds=True + statement, label="foo", literal_binds=True ) with pytest.raises(sa.exc.ProgrammingError): @@ -471,10 +468,11 @@ def test_fetchone(self, connection): def test_fetchall(self, connection): pg_base = Base(connection.engine.url.database, verbose=True) with patch("pgsync.base.compiled_query") as mock_compiled_query: - row = pg_base.fetchall("SELECT 1", label="foo", literal_binds=True) + statement = sa.text("SELECT 1") + row = pg_base.fetchall(statement, label="foo", literal_binds=True) assert row == [(1,)] mock_compiled_query.assert_called_once_with( - "SELECT 1", label="foo", literal_binds=True + statement, label="foo", literal_binds=True ) with pytest.raises(sa.exc.ProgrammingError): @@ -489,45 +487,59 @@ def test_count(self, connection, book_cls): def test_views(self, connection): pg_base = Base(connection.engine.url.database) - connection.engine.execute( - CreateView( - DEFAULT_SCHEMA, "mymatview", sa.select(1), materialized=True + with connection.engine.connect() as conn: + conn.execute( + CreateView( + DEFAULT_SCHEMA, + "mymatview", + sa.select(1), + materialized=True, + ) ) - ) - connection.engine.execute( - CreateView( - DEFAULT_SCHEMA, "myview", sa.select(1), materialized=False + conn.execute( + CreateView( + DEFAULT_SCHEMA, "myview", sa.select(1), materialized=False + ) ) - ) + conn.commit() views = pg_base._views(DEFAULT_SCHEMA) assert views == ["myview"] - connection.engine.execute( - DropView(DEFAULT_SCHEMA, "mymatview", materialized=True) - ) - connection.engine.execute( - DropView(DEFAULT_SCHEMA, "myview", materialized=False) - ) + with connection.engine.connect() as conn: + conn.execute( + DropView(DEFAULT_SCHEMA, "mymatview", materialized=True) + ) + conn.execute( + DropView(DEFAULT_SCHEMA, "myview", materialized=False) + ) + conn.commit() def test_materialized_views(self, connection): pg_base = Base(connection.engine.url.database) - connection.engine.execute( - CreateView( - DEFAULT_SCHEMA, "mymatview", sa.select(1), materialized=True + with connection.engine.connect() as conn: + conn.execute( + CreateView( + DEFAULT_SCHEMA, + "mymatview", + sa.select(1), + materialized=True, + ) ) - ) - connection.engine.execute( - CreateView( - DEFAULT_SCHEMA, "myview", sa.select(1), materialized=False + conn.execute( + CreateView( + DEFAULT_SCHEMA, "myview", sa.select(1), materialized=False + ) ) - ) + conn.commit() views = pg_base._materialized_views(DEFAULT_SCHEMA) assert views == ["mymatview"] - connection.engine.execute( - DropView(DEFAULT_SCHEMA, "mymatview", materialized=True) - ) - connection.engine.execute( - DropView(DEFAULT_SCHEMA, "myview", materialized=False) - ) + with connection.engine.connect() as conn: + conn.execute( + DropView(DEFAULT_SCHEMA, "mymatview", materialized=True) + ) + conn.execute( + DropView(DEFAULT_SCHEMA, "myview", materialized=False) + ) + conn.commit() def test_pg_execute(self, connection): with patch("pgsync.base.logger") as mock_logger: @@ -537,15 +549,14 @@ def test_pg_execute(self, connection): options={"isolation_level": "AUTOCOMMIT"}, ) mock_logger.exception.assert_not_called() - with patch("pgsync.base.logger") as mock_logger: - with pytest.raises(Exception) as excinfo: - pg_execute( - connection.engine, - sa.select(1), - options={None: "AUTOCOMMIT"}, - ) - mock_logger.exception.assert_called_once() - assert "must be strings" in str(excinfo.value) + + with pytest.raises(Exception) as excinfo: + pg_execute( + connection.engine, + sa.select(1), + options={None: "AUTOCOMMIT"}, + ) + assert "must be strings" in str(excinfo.value) def test_pg_engine(self, connection): with pytest.raises(ValueError) as excinfo: diff --git a/tests/test_settings.py b/tests/test_settings.py index 614b8488..6a078f3a 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -27,7 +27,7 @@ def test_postgres_url(mocker): mocker.patch("logging.config.dictConfig") engine = _pg_engine("wheel") mock_get_postgres_url.assert_called_once() - url = "postgresql://kermit:frog@some-host:5432/wheel" + url = "postgresql://kermit:***@some-host:5432/wheel" assert str(engine.engine.url) == url diff --git a/tests/test_sync_nested_children.py b/tests/test_sync_nested_children.py index 8d9a81dc..bece03fc 100644 --- a/tests/test_sync_nested_children.py +++ b/tests/test_sync_nested_children.py @@ -213,9 +213,6 @@ def data( with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = sync.database cursor.execute(f"UNLISTEN {channel}") diff --git a/tests/test_sync_root.py b/tests/test_sync_root.py index 1e51d124..51d986b4 100644 --- a/tests/test_sync_root.py +++ b/tests/test_sync_root.py @@ -1,6 +1,5 @@ """Tests for `pgsync` package.""" import mock -import psycopg2 import pytest from pgsync.base import subtransactions @@ -47,9 +46,6 @@ def data(self, sync, book_cls, publisher_cls): with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = sync.database cursor.execute(f"UNLISTEN {channel}") @@ -67,9 +63,6 @@ def data(self, sync, book_cls, publisher_cls): with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = session.connection().engine.url.database cursor.execute(f"UNLISTEN {channel}") diff --git a/tests/test_trigger.py b/tests/test_trigger.py index ab9256ab..f86492b4 100644 --- a/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -1,5 +1,6 @@ """Trigger tests.""" import pytest +import sqlalchemy as sa from pgsync.base import Base from pgsync.trigger import CREATE_TRIGGER_TEMPLATE @@ -107,7 +108,7 @@ def test_trigger_primary_key_function(self, connection): f"JOIN pg_attribute ON attrelid = indrelid AND attnum = ANY(indkey) " # noqa E501 f"WHERE indrelid = '{table_name}'::regclass AND indisprimary" ) - rows = pg_base.fetchall(query)[0] + rows = pg_base.fetchall(sa.text(query))[0] assert list(rows)[0] == primary_keys def test_trigger_foreign_key_function(self, connection): @@ -129,7 +130,7 @@ def test_trigger_foreign_key_function(self, connection): f"WHERE constraint_catalog=current_catalog AND " f"table_name='{table_name}' AND position_in_unique_constraint NOTNULL " # noqa E501 ) - rows = pg_base.fetchall(query)[0] + rows = pg_base.fetchall(sa.text(query))[0] if rows[0]: assert sorted(rows[0]) == sorted(foreign_keys) else: diff --git a/tests/test_unique_behaviour.py b/tests/test_unique_behaviour.py index 84bb6c65..5e3f6d21 100644 --- a/tests/test_unique_behaviour.py +++ b/tests/test_unique_behaviour.py @@ -1,6 +1,5 @@ """Tests for `pgsync` package.""" -import psycopg2 import pytest from pgsync.base import subtransactions @@ -45,9 +44,6 @@ def data( ] with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = sync.database cursor.execute(f"UNLISTEN {channel}") @@ -72,9 +68,6 @@ def data( with subtransactions(session): conn = session.connection().engine.connect().connection - conn.set_isolation_level( - psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT - ) cursor = conn.cursor() channel = session.connection().engine.url.database cursor.execute(f"UNLISTEN {channel}") diff --git a/tests/test_utils.py b/tests/test_utils.py index e75afe42..27b24fd1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -122,7 +122,7 @@ def test_compiled_query_with_label( ): pg_base = Base(connection.engine.url.database) model = pg_base.models("book", "public") - statement = sa.select([model.c.isbn]).select_from(model) + statement = sa.select(*[model.c.isbn]).select_from(model) compiled_query(statement, label="foo", literal_binds=True) mock_logger.debug.assert_called_once_with( "\x1b[4mfoo:\x1b[0m\nSELECT book_1.isbn\n" @@ -137,7 +137,7 @@ def test_compiled_query_without_label( ): pg_base = Base(connection.engine.url.database) model = pg_base.models("book", "public") - statement = sa.select([model.c.isbn]).select_from(model) + statement = sa.select(*[model.c.isbn]).select_from(model) compiled_query(statement, literal_binds=True) mock_logger.debug.assert_called_once_with( "SELECT book_1.isbn\nFROM public.book AS book_1" diff --git a/tests/test_view.py b/tests/test_view.py index ad7dd374..36cc7fe6 100644 --- a/tests/test_view.py +++ b/tests/test_view.py @@ -35,6 +35,10 @@ def data(self, sync, book_cls): with subtransactions(session): session.add_all(books) yield books + + with subtransactions(session): + session.query(book_cls).delete() + session.connection().engine.connect().close() session.connection().engine.dispose() sync.search_client.close() @@ -42,9 +46,13 @@ def data(self, sync, book_cls): def test_create_materialized_view(self, connection): """Test create materialized view.""" view = "test_mat_view" - connection.engine.execute( - CreateView(DEFAULT_SCHEMA, view, sa.select(1), materialized=True) - ) + with connection.engine.connect() as conn: + conn.execute( + CreateView( + DEFAULT_SCHEMA, view, sa.select(1), materialized=True + ) + ) + conn.commit() assert ( is_view(connection.engine, DEFAULT_SCHEMA, view, materialized=True) is True @@ -55,16 +63,20 @@ def test_create_materialized_view(self, connection): ) is False ) - connection.engine.execute( - DropView(DEFAULT_SCHEMA, view, materialized=True) - ) + with connection.engine.connect() as conn: + conn.execute(DropView(DEFAULT_SCHEMA, view, materialized=True)) + conn.commit() def test_create_view(self, connection): """Test create non-materialized view.""" view = "test_view" - connection.engine.execute( - CreateView(DEFAULT_SCHEMA, view, sa.select(1), materialized=False) - ) + with connection.engine.connect() as conn: + conn.execute( + CreateView( + DEFAULT_SCHEMA, view, sa.select(1), materialized=False + ) + ) + conn.commit() assert ( is_view( connection.engine, DEFAULT_SCHEMA, view, materialized=False @@ -75,23 +87,27 @@ def test_create_view(self, connection): is_view(connection.engine, DEFAULT_SCHEMA, view, materialized=True) is False ) - connection.engine.execute( - DropView(DEFAULT_SCHEMA, view, materialized=False) - ) + with connection.engine.connect() as conn: + conn.execute(DropView(DEFAULT_SCHEMA, view, materialized=False)) + conn.commit() def test_drop_materialized_view(self, connection): """Test drop materialized view.""" view = "test_view_drop" - connection.engine.execute( - CreateView(DEFAULT_SCHEMA, view, sa.select(1), materialized=True) - ) + with connection.engine.connect() as conn: + conn.execute( + CreateView( + DEFAULT_SCHEMA, view, sa.select(1), materialized=True + ) + ) + conn.commit() assert ( is_view(connection.engine, DEFAULT_SCHEMA, view, materialized=True) is True ) - connection.engine.execute( - DropView(DEFAULT_SCHEMA, view, materialized=True) - ) + with connection.engine.connect() as conn: + conn.execute(DropView(DEFAULT_SCHEMA, view, materialized=True)) + conn.commit() assert ( is_view(connection.engine, DEFAULT_SCHEMA, view, materialized=True) is False @@ -106,18 +122,22 @@ def test_drop_materialized_view(self, connection): def test_drop_view(self, connection): """Test drop non-materialized view.""" view = "test_view_drop" - connection.engine.execute( - CreateView(DEFAULT_SCHEMA, view, sa.select(1), materialized=False) - ) + with connection.engine.connect() as conn: + conn.execute( + CreateView( + DEFAULT_SCHEMA, view, sa.select(1), materialized=False + ) + ) + conn.commit() assert ( is_view( connection.engine, DEFAULT_SCHEMA, view, materialized=False ) is True ) - connection.engine.execute( - DropView(DEFAULT_SCHEMA, view, materialized=False) - ) + with connection.engine.connect() as conn: + conn.execute(DropView(DEFAULT_SCHEMA, view, materialized=False)) + conn.commit() assert ( is_view(connection.engine, DEFAULT_SCHEMA, view, materialized=True) is False @@ -134,17 +154,20 @@ def test_refresh_view(self, connection, sync, book_cls, data): """Test refresh materialized view.""" view = "test_view_refresh" pg_base = Base(connection.engine.url.database) + model = pg_base.models("book", "public") - statement = sa.select([model.c.isbn]).select_from(model) - connection.engine.execute( - CreateView(DEFAULT_SCHEMA, view, statement, materialized=True) - ) - assert [ - result.isbn - for result in connection.engine.execute( - sa.text(f"SELECT * FROM {view}") + statement = sa.select(*[model.c.isbn]).select_from(model) + with connection.engine.connect() as conn: + conn.execute( + CreateView(DEFAULT_SCHEMA, view, statement, materialized=True) ) - ][0] == "abc" + conn.commit() + + with connection.engine.connect() as conn: + assert [ + result.isbn + for result in conn.execute(sa.text(f"SELECT * FROM {view}")) + ][0] == "abc" session = sync.session with subtransactions(session): @@ -154,26 +177,27 @@ def test_refresh_view(self, connection, sync, book_cls, data): .values(isbn="xyz") ) - # the value should still be abc - assert [ - result.isbn - for result in connection.engine.execute( - sa.text(f"SELECT * FROM {view}") - ) - ][0] == "abc" + with connection.engine.connect() as conn: + # the value should still be abc + assert [ + result.isbn + for result in conn.execute(sa.text(f"SELECT * FROM {view}")) + ][0] == "abc" - connection.engine.execute(RefreshView(DEFAULT_SCHEMA, view)) + with connection.engine.connect() as conn: + conn.execute(RefreshView(DEFAULT_SCHEMA, view)) + conn.commit() - # the value should now be xyz - assert [ - result.isbn - for result in connection.engine.execute( - sa.text(f"SELECT * FROM {view}") - ) - ][0] == "xyz" - connection.engine.execute( - DropView(DEFAULT_SCHEMA, view, materialized=True) - ) + with connection.engine.connect() as conn: + # the value should now be xyz + assert [ + result.isbn + for result in conn.execute(sa.text(f"SELECT * FROM {view}")) + ][0] == "xyz" + + with connection.engine.connect() as conn: + conn.execute(DropView(DEFAULT_SCHEMA, view, materialized=True)) + conn.commit() @pytest.mark.usefixtures("table_creator") def test_index(self, connection, sync, book_cls, data): @@ -182,9 +206,11 @@ def test_index(self, connection, sync, book_cls, data): "book", schema=DEFAULT_SCHEMA ) assert indices == [] - connection.engine.execute( - CreateIndex("my_index", DEFAULT_SCHEMA, "book", ["isbn"]) - ) + with connection.engine.connect() as conn: + conn.execute( + CreateIndex("my_index", DEFAULT_SCHEMA, "book", ["isbn"]) + ) + conn.commit() indices = sa.inspect(connection.engine).get_indexes( "book", schema=DEFAULT_SCHEMA ) @@ -198,7 +224,9 @@ def test_index(self, connection, sync, book_cls, data): "postgresql_include": [], }, } - connection.engine.execute(DropIndex("my_index")) + with connection.engine.connect() as conn: + conn.execute(DropIndex("my_index")) + conn.commit() indices = sa.inspect(connection.engine).get_indexes( "book", schema=DEFAULT_SCHEMA )