Skip to content

Commit

Permalink
first round
Browse files Browse the repository at this point in the history
  • Loading branch information
pcrespov committed Dec 10, 2024
1 parent bf07906 commit ed2a0d5
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,12 @@
from collections.abc import AsyncIterator, Callable

from aiohttp import web
from aiopg.sa.engine import Engine
from models_library.users import UserID
from servicelib.logging_utils import get_log_record_extra, log_context
from tenacity import retry
from tenacity.before_sleep import before_sleep_log
from tenacity.wait import wait_exponential

from ..db.plugin import get_database_engine
from ..login.utils import notify_user_logout
from ..security.api import clean_auth_policy_cache
from ..users.api import update_expired_users
Expand Down Expand Up @@ -60,10 +58,8 @@ async def _update_expired_users(app: web.Application):
"""
It is resilient, i.e. if update goes wrong, it waits a bit and retries
"""
engine: Engine = get_database_engine(app)
assert engine # nosec

if updated := await update_expired_users(engine):
if updated := await update_expired_users(app):
# expired users might be cached in the auth. If so, any request
# with this user-id will get thru producing unexpected side-effects
await clean_auth_policy_cache(app)
Expand Down
20 changes: 11 additions & 9 deletions services/web/server/src/simcore_service_webserver/users/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from simcore_postgres_database.models.users import UserStatus

from ..db.plugin import get_database_engine
from . import _db, _schemas
from ._db import get_user_or_raise
from ._db import list_user_permissions as db_list_of_permissions
from ._db import update_user_status
from . import _schemas, _users_repository
from ._users_repository import get_user_or_raise
from ._users_repository import list_user_permissions as db_list_of_permissions
from ._users_repository import update_user_status
from .exceptions import AlreadyPreRegisteredError
from .schemas import Permission

Expand Down Expand Up @@ -73,13 +73,13 @@ async def search_users(
app: web.Application, email_glob: str, *, include_products: bool = False
) -> list[_schemas.UserProfile]:
# NOTE: this search is deploy-wide i.e. independent of the product!
rows = await _db.search_users_and_get_profile(
rows = await _users_repository.search_users_and_get_profile(
get_database_engine(app), email_like=_glob_to_sql_like(email_glob)
)

async def _list_products_or_none(user_id):
if user_id is not None and include_products:
products = await _db.get_user_products(
products = await _users_repository.get_user_products(
get_database_engine(app), user_id=user_id
)
return [_.product_name for _ in products]
Expand Down Expand Up @@ -136,7 +136,7 @@ async def pre_register_user(
if key in details:
details[f"pre_{key}"] = details.pop(key)

await _db.new_user_details(
await _users_repository.new_user_details(
get_database_engine(app),
email=profile.email,
created_by=creator_user_id,
Expand All @@ -152,8 +152,10 @@ async def pre_register_user(
async def get_user_invoice_address(
app: web.Application, user_id: UserID
) -> UserInvoiceAddress:
user_billing_details: UserBillingDetails = await _db.get_user_billing_details(
get_database_engine(app), user_id=user_id
user_billing_details: UserBillingDetails = (
await _users_repository.get_user_billing_details(
get_database_engine(app), user_id=user_id
)
)
_user_billing_country = pycountry.countries.lookup(user_billing_details.country)
_user_billing_country_alpha_2_format = _user_billing_country.alpha_2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

import sqlalchemy as sa
from aiohttp import web
from aiopg.sa.connection import SAConnection
from aiopg.sa.engine import Engine
from aiopg.sa.result import ResultProxy, RowProxy
from models_library.users import GroupID, UserBillingDetails, UserID
from simcore_postgres_database.models.groups import groups, user_to_groups
from simcore_postgres_database.models.products import products
Expand All @@ -16,59 +13,78 @@
GroupExtraPropertiesNotFoundError,
GroupExtraPropertiesRepo,
)
from simcore_postgres_database.utils_repos import (
pass_or_acquire_connection,
transaction_context,
)
from simcore_postgres_database.utils_users import UsersRepo
from simcore_service_webserver.users.exceptions import UserNotFoundError
from sqlalchemy.engine.row import Row
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine

from ..db.models import user_to_groups
from ..db.plugin import get_database_engine
from ..db.plugin import get_asyncpg_engine
from .exceptions import BillingDetailsNotFoundError
from .schemas import Permission

_ALL = None


async def get_user_or_raise(
engine: Engine, *, user_id: UserID, return_column_names: list[str] | None = _ALL
) -> RowProxy:
engine: AsyncEngine,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
return_column_names: list[str] | None = _ALL,
) -> Row:
if return_column_names == _ALL:
return_column_names = list(users.columns.keys())

assert return_column_names is not None # nosec
assert set(return_column_names).issubset(users.columns.keys()) # nosec

async with engine.acquire() as conn:
row: RowProxy | None = await (
await conn.execute(
sa.select(*(users.columns[name] for name in return_column_names)).where(
users.c.id == user_id
)
async with pass_or_acquire_connection(engine, connection) as conn:
result = await conn.stream(
sa.select(*(users.columns[name] for name in return_column_names)).where(
users.c.id == user_id
)
).first()
)
row = await result.first()
if row is None:
raise UserNotFoundError(uid=user_id)
return row


async def get_users_ids_in_group(conn: SAConnection, gid: GroupID) -> set[UserID]:
result: set[UserID] = set()
query_result = await conn.execute(
sa.select(user_to_groups.c.uid).where(user_to_groups.c.gid == gid)
)
async for entry in query_result:
result.add(entry[0])
return result
async def get_users_ids_in_group(
engine: AsyncEngine,
connection: AsyncConnection | None = None,
*,
group_id: GroupID,
) -> set[UserID]:
async with pass_or_acquire_connection(engine, connection) as conn:
result = await conn.stream(
sa.select(user_to_groups.c.uid).where(user_to_groups.c.gid == group_id)
)
return {row.uid async for row in result}


async def list_user_permissions(
app: web.Application, *, user_id: UserID, product_name: str
app: web.Application,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
product_name: str,
) -> list[Permission]:
override_services_specifications = Permission(
name="override_services_specifications",
allowed=False,
)
with contextlib.suppress(GroupExtraPropertiesNotFoundError):
async with get_database_engine(app).acquire() as conn:
async with pass_or_acquire_connection(
get_asyncpg_engine(app), connection
) as conn:
user_group_extra_properties = (
# TODO: adapt to asyncpg
await GroupExtraPropertiesRepo.get_aggregated_properties_for_user(
conn, user_id=user_id, product_name=product_name
)
Expand All @@ -80,34 +96,43 @@ async def list_user_permissions(
return [override_services_specifications]


async def do_update_expired_users(conn: SAConnection) -> list[UserID]:
result: ResultProxy = await conn.execute(
users.update()
.values(status=UserStatus.EXPIRED)
.where(
(users.c.expires_at.is_not(None))
& (users.c.status == UserStatus.ACTIVE)
& (users.c.expires_at < sa.sql.func.now())
async def do_update_expired_users(
engine: AsyncEngine,
connection: AsyncConnection | None = None,
) -> list[UserID]:
async with transaction_context(engine, connection) as conn:
result = await conn.stream(
users.update()
.values(status=UserStatus.EXPIRED)
.where(
(users.c.expires_at.is_not(None))
& (users.c.status == UserStatus.ACTIVE)
& (users.c.expires_at < sa.sql.func.now())
)
.returning(users.c.id)
)
.returning(users.c.id)
)
if rows := await result.fetchall():
return [r.id for r in rows]
return []
return [row.id async for row in result]


async def update_user_status(
engine: Engine, *, user_id: UserID, new_status: UserStatus
engine: AsyncEngine,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
new_status: UserStatus,
):
async with engine.acquire() as conn:
async with transaction_context(engine, connection) as conn:
await conn.execute(
users.update().values(status=new_status).where(users.c.id == user_id)
)


async def search_users_and_get_profile(
engine: Engine, *, email_like: str
) -> list[RowProxy]:
engine: AsyncEngine,
connection: AsyncConnection | None = None,
*,
email_like: str,
) -> list[Row]:

users_alias = sa.alias(users, name="users_alias")

Expand All @@ -117,7 +142,7 @@ async def search_users_and_get_profile(
.label("invited_by")
)

async with engine.acquire() as conn:
async with pass_or_acquire_connection(engine, connection) as conn:
columns = (
users.c.first_name,
users.c.last_name,
Expand Down Expand Up @@ -159,12 +184,17 @@ async def search_users_and_get_profile(
.where(users.c.email.like(email_like))
)

result = await conn.execute(sa.union(left_outer_join, right_outer_join))
return await result.fetchall() or []
result = await conn.stream(sa.union(left_outer_join, right_outer_join))
return [row async for row in result]


async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]:
async with engine.acquire() as conn:
async def get_user_products(
engine: AsyncEngine,
connection: AsyncConnection | None = None,
*,
user_id: UserID,
) -> list[Row]:
async with pass_or_acquire_connection(engine, connection) as conn:
product_name_subq = (
sa.select(products.c.name)
.where(products.c.group_id == groups.c.gid)
Expand All @@ -186,14 +216,19 @@ async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]:
.where(users.c.id == user_id)
.order_by(groups.c.gid)
)
result = await conn.execute(query)
return await result.fetchall() or []
result = await conn.stream(query)
return [row async for row in result]


async def new_user_details(
engine: Engine, email: str, created_by: UserID, **other_values
engine: AsyncEngine,
connection: AsyncConnection | None = None,
*,
email: str,
created_by: UserID,
**other_values,
) -> None:
async with engine.acquire() as conn:
async with transaction_context(engine, connection) as conn:
await conn.execute(
sa.insert(users_pre_registration_details).values(
created_by=created_by, pre_email=email, **other_values
Expand All @@ -202,13 +237,13 @@ async def new_user_details(


async def get_user_billing_details(
engine: Engine, user_id: UserID
engine: AsyncEngine, connection: AsyncConnection | None = None, *, user_id: UserID
) -> UserBillingDetails:
"""
Raises:
BillingDetailsNotFoundError
"""
async with engine.acquire() as conn:
async with pass_or_acquire_connection(engine, connection) as conn:
user_billing_details = await UsersRepo.get_billing_details(conn, user_id)
if not user_billing_details:
raise BillingDetailsNotFoundError(user_id=user_id)
Expand Down
28 changes: 14 additions & 14 deletions services/web/server/src/simcore_service_webserver/users/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import simcore_postgres_database.errors as db_errors
import sqlalchemy as sa
from aiohttp import web
from aiopg.sa.engine import Engine
from aiopg.sa.result import RowProxy
from models_library.api_schemas_webserver.users import (
ProfileGet,
Expand All @@ -29,11 +28,11 @@
GroupExtraPropertiesNotFoundError,
)

from ..db.plugin import get_database_engine
from ..db.plugin import get_asyncpg_engine, get_database_engine
from ..groups.models import convert_groups_db_to_schema
from ..login.storage import AsyncpgStorage, get_plugin_storage
from ..security.api import clean_auth_policy_cache
from . import _db
from . import _users_repository
from ._api import get_user_credentials, get_user_invoice_address, set_user_as_deleted
from ._models import ToUserUpdateDB
from ._preferences_api import get_frontend_user_preferences_aggregation
Expand Down Expand Up @@ -216,8 +215,8 @@ async def get_user_name_and_email(
Returns:
(user, email)
"""
row = await _db.get_user_or_raise(
get_database_engine(app),
row = await _users_repository.get_user_or_raise(
get_asyncpg_engine(app),
user_id=_parse_as_user(user_id),
return_column_names=["name", "email"],
)
Expand All @@ -242,8 +241,8 @@ async def get_user_display_and_id_names(
Raises:
UserNotFoundError
"""
row = await _db.get_user_or_raise(
get_database_engine(app),
row = await _users_repository.get_user_or_raise(
get_asyncpg_engine(app),
user_id=_parse_as_user(user_id),
return_column_names=["name", "email", "first_name", "last_name"],
)
Expand Down Expand Up @@ -318,7 +317,9 @@ async def get_user(app: web.Application, user_id: UserID) -> dict[str, Any]:
"""
:raises UserNotFoundError:
"""
row = await _db.get_user_or_raise(engine=get_database_engine(app), user_id=user_id)
row = await _users_repository.get_user_or_raise(
engine=get_asyncpg_engine(app), user_id=user_id
)
return dict(row)


Expand All @@ -332,14 +333,13 @@ async def get_user_id_from_gid(app: web.Application, primary_gid: int) -> UserID


async def get_users_in_group(app: web.Application, gid: GroupID) -> set[UserID]:
engine = get_database_engine(app)
async with engine.acquire() as conn:
return await _db.get_users_ids_in_group(conn, gid)
return await _users_repository.get_users_ids_in_group(
get_asyncpg_engine(app), group_id=gid
)


async def update_expired_users(engine: Engine) -> list[UserID]:
async with engine.acquire() as conn:
return await _db.do_update_expired_users(conn)
async def update_expired_users(app: web.Application) -> list[UserID]:
return await _users_repository.do_update_expired_users(get_asyncpg_engine(app))


assert set_user_as_deleted # nosec
Expand Down
Loading

0 comments on commit ed2a0d5

Please sign in to comment.