Skip to content

Commit

Permalink
Extend get method to accept different fields
Browse files Browse the repository at this point in the history
  • Loading branch information
koldakov committed Jan 24, 2024
1 parent e986cba commit 726a705
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
8 changes: 6 additions & 2 deletions app/repositories/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,14 @@ def to_dict(
async def get(
cls,
session: AsyncSession,
id_: int,
val: int | str,
/,
*,
field: InstrumentedAttribute = None,
) -> T:
cursor: Result = await session.execute(select(cls).where(cls.id == id_))
if field is None:
field = cls.id
cursor: Result = await session.execute(select(cls).where(field == val))
try:
return cursor.scalars().one()
except NoResultFound as err:
Expand Down
33 changes: 25 additions & 8 deletions app/repositories/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from sqlalchemy.exc import NoResultFound
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.sql.elements import BinaryExpression

from app.configs import settings
Expand Down Expand Up @@ -131,12 +132,16 @@ class Season(Base):
async def get(
cls,
session: AsyncSession,
id_: int,
val: int | str,
/,
*,
field: InstrumentedAttribute = None,
) -> "Season":
if field is None:
field = Season.id
cursor: Result = await session.execute(
select(Season)
.where(Season.id == id_)
.where(field == val)
.options(selectinload(Season.episodes))
)
try:
Expand Down Expand Up @@ -211,12 +216,16 @@ class Episode(Base):
async def get(
cls,
session: AsyncSession,
id_: int,
val: int | str,
/,
*,
field: InstrumentedAttribute = None,
) -> "Episode":
if field is None:
field = Episode.id
cursor: Result = await session.execute(
select(Episode)
.where(Episode.id == id_)
.where(field == val)
.options(selectinload(Episode.season))
)
try:
Expand Down Expand Up @@ -280,11 +289,15 @@ class Character(Base):
async def get(
cls,
session: AsyncSession,
id_: int,
val: int | str,
/,
*,
field: InstrumentedAttribute = None,
) -> "Character":
if field is None:
field = Character.id
cursor: Result = await session.execute(
select(Character).where(Character.id == id_)
select(Character).where(field == val)
)
try:
return cursor.scalars().one()
Expand Down Expand Up @@ -392,10 +405,14 @@ class User(Base):
async def get(
cls,
session: AsyncSession,
id_: int,
val: int | str,
/,
*,
field: InstrumentedAttribute = None,
) -> "User":
cursor: Result = await session.execute(select(User).where(User.id == id_))
if field is None:
field = User.id
cursor: Result = await session.execute(select(User).where(field == val))
try:
return cursor.scalars().one()
except NoResultFound as err:
Expand Down

0 comments on commit 726a705

Please sign in to comment.