From 726a70587ae99f74b5e959c2806e570e2b8ed67c Mon Sep 17 00:00:00 2001 From: Ivan Koldakov Date: Wed, 24 Jan 2024 21:52:19 +0100 Subject: [PATCH] Extend get method to accept different fields --- app/repositories/base.py | 8 ++++++-- app/repositories/models.py | 33 +++++++++++++++++++++++++-------- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/app/repositories/base.py b/app/repositories/base.py index 623a3e8..be54b3e 100644 --- a/app/repositories/base.py +++ b/app/repositories/base.py @@ -74,10 +74,14 @@ def to_dict( async def get( cls, session: AsyncSession, - id_: int, + val: int | str, /, + *, + field: InstrumentedAttribute = None, ) -> T: - cursor: Result = await session.execute(select(cls).where(cls.id == id_)) + if field is None: + field = cls.id + cursor: Result = await session.execute(select(cls).where(field == val)) try: return cursor.scalars().one() except NoResultFound as err: diff --git a/app/repositories/models.py b/app/repositories/models.py index 7f272c0..48952bc 100644 --- a/app/repositories/models.py +++ b/app/repositories/models.py @@ -18,6 +18,7 @@ from sqlalchemy.exc import NoResultFound from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload +from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.sql.elements import BinaryExpression from app.configs import settings @@ -131,12 +132,16 @@ class Season(Base): async def get( cls, session: AsyncSession, - id_: int, + val: int | str, /, + *, + field: InstrumentedAttribute = None, ) -> "Season": + if field is None: + field = Season.id cursor: Result = await session.execute( select(Season) - .where(Season.id == id_) + .where(field == val) .options(selectinload(Season.episodes)) ) try: @@ -211,12 +216,16 @@ class Episode(Base): async def get( cls, session: AsyncSession, - id_: int, + val: int | str, /, + *, + field: InstrumentedAttribute = None, ) -> "Episode": + if field is None: + field = Episode.id cursor: Result = await session.execute( select(Episode) - .where(Episode.id == id_) + .where(field == val) .options(selectinload(Episode.season)) ) try: @@ -280,11 +289,15 @@ class Character(Base): async def get( cls, session: AsyncSession, - id_: int, + val: int | str, /, + *, + field: InstrumentedAttribute = None, ) -> "Character": + if field is None: + field = Character.id cursor: Result = await session.execute( - select(Character).where(Character.id == id_) + select(Character).where(field == val) ) try: return cursor.scalars().one() @@ -392,10 +405,14 @@ class User(Base): async def get( cls, session: AsyncSession, - id_: int, + val: int | str, /, + *, + field: InstrumentedAttribute = None, ) -> "User": - cursor: Result = await session.execute(select(User).where(User.id == id_)) + if field is None: + field = User.id + cursor: Result = await session.execute(select(User).where(field == val)) try: return cursor.scalars().one() except NoResultFound as err: