From ed2a0d585aade18811d50fb312790e775508cca6 Mon Sep 17 00:00:00 2001 From: Pedro Crespo-Valero <32402063+pcrespov@users.noreply.github.com> Date: Tue, 10 Dec 2024 19:43:57 +0100 Subject: [PATCH] first round --- .../garbage_collector/_tasks_users.py | 6 +- .../simcore_service_webserver/users/_api.py | 20 +-- .../users/{_db.py => _users_repository.py} | 137 +++++++++++------- .../simcore_service_webserver/users/api.py | 28 ++-- .../tests/unit/with_dbs/03/test_users_api.py | 3 +- 5 files changed, 113 insertions(+), 81 deletions(-) rename services/web/server/src/simcore_service_webserver/users/{_db.py => _users_repository.py} (62%) diff --git a/services/web/server/src/simcore_service_webserver/garbage_collector/_tasks_users.py b/services/web/server/src/simcore_service_webserver/garbage_collector/_tasks_users.py index 48d781aee8de..e99f9c4a225e 100644 --- a/services/web/server/src/simcore_service_webserver/garbage_collector/_tasks_users.py +++ b/services/web/server/src/simcore_service_webserver/garbage_collector/_tasks_users.py @@ -8,14 +8,12 @@ from collections.abc import AsyncIterator, Callable from aiohttp import web -from aiopg.sa.engine import Engine from models_library.users import UserID from servicelib.logging_utils import get_log_record_extra, log_context from tenacity import retry from tenacity.before_sleep import before_sleep_log from tenacity.wait import wait_exponential -from ..db.plugin import get_database_engine from ..login.utils import notify_user_logout from ..security.api import clean_auth_policy_cache from ..users.api import update_expired_users @@ -60,10 +58,8 @@ async def _update_expired_users(app: web.Application): """ It is resilient, i.e. if update goes wrong, it waits a bit and retries """ - engine: Engine = get_database_engine(app) - assert engine # nosec - if updated := await update_expired_users(engine): + if updated := await update_expired_users(app): # expired users might be cached in the auth. If so, any request # with this user-id will get thru producing unexpected side-effects await clean_auth_policy_cache(app) diff --git a/services/web/server/src/simcore_service_webserver/users/_api.py b/services/web/server/src/simcore_service_webserver/users/_api.py index 458366367f5f..b0091c77f39e 100644 --- a/services/web/server/src/simcore_service_webserver/users/_api.py +++ b/services/web/server/src/simcore_service_webserver/users/_api.py @@ -10,10 +10,10 @@ from simcore_postgres_database.models.users import UserStatus from ..db.plugin import get_database_engine -from . import _db, _schemas -from ._db import get_user_or_raise -from ._db import list_user_permissions as db_list_of_permissions -from ._db import update_user_status +from . import _schemas, _users_repository +from ._users_repository import get_user_or_raise +from ._users_repository import list_user_permissions as db_list_of_permissions +from ._users_repository import update_user_status from .exceptions import AlreadyPreRegisteredError from .schemas import Permission @@ -73,13 +73,13 @@ async def search_users( app: web.Application, email_glob: str, *, include_products: bool = False ) -> list[_schemas.UserProfile]: # NOTE: this search is deploy-wide i.e. independent of the product! - rows = await _db.search_users_and_get_profile( + rows = await _users_repository.search_users_and_get_profile( get_database_engine(app), email_like=_glob_to_sql_like(email_glob) ) async def _list_products_or_none(user_id): if user_id is not None and include_products: - products = await _db.get_user_products( + products = await _users_repository.get_user_products( get_database_engine(app), user_id=user_id ) return [_.product_name for _ in products] @@ -136,7 +136,7 @@ async def pre_register_user( if key in details: details[f"pre_{key}"] = details.pop(key) - await _db.new_user_details( + await _users_repository.new_user_details( get_database_engine(app), email=profile.email, created_by=creator_user_id, @@ -152,8 +152,10 @@ async def pre_register_user( async def get_user_invoice_address( app: web.Application, user_id: UserID ) -> UserInvoiceAddress: - user_billing_details: UserBillingDetails = await _db.get_user_billing_details( - get_database_engine(app), user_id=user_id + user_billing_details: UserBillingDetails = ( + await _users_repository.get_user_billing_details( + get_database_engine(app), user_id=user_id + ) ) _user_billing_country = pycountry.countries.lookup(user_billing_details.country) _user_billing_country_alpha_2_format = _user_billing_country.alpha_2 diff --git a/services/web/server/src/simcore_service_webserver/users/_db.py b/services/web/server/src/simcore_service_webserver/users/_users_repository.py similarity index 62% rename from services/web/server/src/simcore_service_webserver/users/_db.py rename to services/web/server/src/simcore_service_webserver/users/_users_repository.py index 2071034d2e6c..a2979fed4f4d 100644 --- a/services/web/server/src/simcore_service_webserver/users/_db.py +++ b/services/web/server/src/simcore_service_webserver/users/_users_repository.py @@ -2,9 +2,6 @@ import sqlalchemy as sa from aiohttp import web -from aiopg.sa.connection import SAConnection -from aiopg.sa.engine import Engine -from aiopg.sa.result import ResultProxy, RowProxy from models_library.users import GroupID, UserBillingDetails, UserID from simcore_postgres_database.models.groups import groups, user_to_groups from simcore_postgres_database.models.products import products @@ -16,11 +13,17 @@ GroupExtraPropertiesNotFoundError, GroupExtraPropertiesRepo, ) +from simcore_postgres_database.utils_repos import ( + pass_or_acquire_connection, + transaction_context, +) from simcore_postgres_database.utils_users import UsersRepo from simcore_service_webserver.users.exceptions import UserNotFoundError +from sqlalchemy.engine.row import Row +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from ..db.models import user_to_groups -from ..db.plugin import get_database_engine +from ..db.plugin import get_asyncpg_engine from .exceptions import BillingDetailsNotFoundError from .schemas import Permission @@ -28,47 +31,60 @@ async def get_user_or_raise( - engine: Engine, *, user_id: UserID, return_column_names: list[str] | None = _ALL -) -> RowProxy: + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + user_id: UserID, + return_column_names: list[str] | None = _ALL, +) -> Row: if return_column_names == _ALL: return_column_names = list(users.columns.keys()) assert return_column_names is not None # nosec assert set(return_column_names).issubset(users.columns.keys()) # nosec - async with engine.acquire() as conn: - row: RowProxy | None = await ( - await conn.execute( - sa.select(*(users.columns[name] for name in return_column_names)).where( - users.c.id == user_id - ) + async with pass_or_acquire_connection(engine, connection) as conn: + result = await conn.stream( + sa.select(*(users.columns[name] for name in return_column_names)).where( + users.c.id == user_id ) - ).first() + ) + row = await result.first() if row is None: raise UserNotFoundError(uid=user_id) return row -async def get_users_ids_in_group(conn: SAConnection, gid: GroupID) -> set[UserID]: - result: set[UserID] = set() - query_result = await conn.execute( - sa.select(user_to_groups.c.uid).where(user_to_groups.c.gid == gid) - ) - async for entry in query_result: - result.add(entry[0]) - return result +async def get_users_ids_in_group( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + group_id: GroupID, +) -> set[UserID]: + async with pass_or_acquire_connection(engine, connection) as conn: + result = await conn.stream( + sa.select(user_to_groups.c.uid).where(user_to_groups.c.gid == group_id) + ) + return {row.uid async for row in result} async def list_user_permissions( - app: web.Application, *, user_id: UserID, product_name: str + app: web.Application, + connection: AsyncConnection | None = None, + *, + user_id: UserID, + product_name: str, ) -> list[Permission]: override_services_specifications = Permission( name="override_services_specifications", allowed=False, ) with contextlib.suppress(GroupExtraPropertiesNotFoundError): - async with get_database_engine(app).acquire() as conn: + async with pass_or_acquire_connection( + get_asyncpg_engine(app), connection + ) as conn: user_group_extra_properties = ( + # TODO: adapt to asyncpg await GroupExtraPropertiesRepo.get_aggregated_properties_for_user( conn, user_id=user_id, product_name=product_name ) @@ -80,34 +96,43 @@ async def list_user_permissions( return [override_services_specifications] -async def do_update_expired_users(conn: SAConnection) -> list[UserID]: - result: ResultProxy = await conn.execute( - users.update() - .values(status=UserStatus.EXPIRED) - .where( - (users.c.expires_at.is_not(None)) - & (users.c.status == UserStatus.ACTIVE) - & (users.c.expires_at < sa.sql.func.now()) +async def do_update_expired_users( + engine: AsyncEngine, + connection: AsyncConnection | None = None, +) -> list[UserID]: + async with transaction_context(engine, connection) as conn: + result = await conn.stream( + users.update() + .values(status=UserStatus.EXPIRED) + .where( + (users.c.expires_at.is_not(None)) + & (users.c.status == UserStatus.ACTIVE) + & (users.c.expires_at < sa.sql.func.now()) + ) + .returning(users.c.id) ) - .returning(users.c.id) - ) - if rows := await result.fetchall(): - return [r.id for r in rows] - return [] + return [row.id async for row in result] async def update_user_status( - engine: Engine, *, user_id: UserID, new_status: UserStatus + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + user_id: UserID, + new_status: UserStatus, ): - async with engine.acquire() as conn: + async with transaction_context(engine, connection) as conn: await conn.execute( users.update().values(status=new_status).where(users.c.id == user_id) ) async def search_users_and_get_profile( - engine: Engine, *, email_like: str -) -> list[RowProxy]: + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + email_like: str, +) -> list[Row]: users_alias = sa.alias(users, name="users_alias") @@ -117,7 +142,7 @@ async def search_users_and_get_profile( .label("invited_by") ) - async with engine.acquire() as conn: + async with pass_or_acquire_connection(engine, connection) as conn: columns = ( users.c.first_name, users.c.last_name, @@ -159,12 +184,17 @@ async def search_users_and_get_profile( .where(users.c.email.like(email_like)) ) - result = await conn.execute(sa.union(left_outer_join, right_outer_join)) - return await result.fetchall() or [] + result = await conn.stream(sa.union(left_outer_join, right_outer_join)) + return [row async for row in result] -async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]: - async with engine.acquire() as conn: +async def get_user_products( + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + user_id: UserID, +) -> list[Row]: + async with pass_or_acquire_connection(engine, connection) as conn: product_name_subq = ( sa.select(products.c.name) .where(products.c.group_id == groups.c.gid) @@ -186,14 +216,19 @@ async def get_user_products(engine: Engine, user_id: UserID) -> list[RowProxy]: .where(users.c.id == user_id) .order_by(groups.c.gid) ) - result = await conn.execute(query) - return await result.fetchall() or [] + result = await conn.stream(query) + return [row async for row in result] async def new_user_details( - engine: Engine, email: str, created_by: UserID, **other_values + engine: AsyncEngine, + connection: AsyncConnection | None = None, + *, + email: str, + created_by: UserID, + **other_values, ) -> None: - async with engine.acquire() as conn: + async with transaction_context(engine, connection) as conn: await conn.execute( sa.insert(users_pre_registration_details).values( created_by=created_by, pre_email=email, **other_values @@ -202,13 +237,13 @@ async def new_user_details( async def get_user_billing_details( - engine: Engine, user_id: UserID + engine: AsyncEngine, connection: AsyncConnection | None = None, *, user_id: UserID ) -> UserBillingDetails: """ Raises: BillingDetailsNotFoundError """ - async with engine.acquire() as conn: + async with pass_or_acquire_connection(engine, connection) as conn: user_billing_details = await UsersRepo.get_billing_details(conn, user_id) if not user_billing_details: raise BillingDetailsNotFoundError(user_id=user_id) diff --git a/services/web/server/src/simcore_service_webserver/users/api.py b/services/web/server/src/simcore_service_webserver/users/api.py index 7fc2c138204e..e3bd94f1f29e 100644 --- a/services/web/server/src/simcore_service_webserver/users/api.py +++ b/services/web/server/src/simcore_service_webserver/users/api.py @@ -12,7 +12,6 @@ import simcore_postgres_database.errors as db_errors import sqlalchemy as sa from aiohttp import web -from aiopg.sa.engine import Engine from aiopg.sa.result import RowProxy from models_library.api_schemas_webserver.users import ( ProfileGet, @@ -29,11 +28,11 @@ GroupExtraPropertiesNotFoundError, ) -from ..db.plugin import get_database_engine +from ..db.plugin import get_asyncpg_engine, get_database_engine from ..groups.models import convert_groups_db_to_schema from ..login.storage import AsyncpgStorage, get_plugin_storage from ..security.api import clean_auth_policy_cache -from . import _db +from . import _users_repository from ._api import get_user_credentials, get_user_invoice_address, set_user_as_deleted from ._models import ToUserUpdateDB from ._preferences_api import get_frontend_user_preferences_aggregation @@ -216,8 +215,8 @@ async def get_user_name_and_email( Returns: (user, email) """ - row = await _db.get_user_or_raise( - get_database_engine(app), + row = await _users_repository.get_user_or_raise( + get_asyncpg_engine(app), user_id=_parse_as_user(user_id), return_column_names=["name", "email"], ) @@ -242,8 +241,8 @@ async def get_user_display_and_id_names( Raises: UserNotFoundError """ - row = await _db.get_user_or_raise( - get_database_engine(app), + row = await _users_repository.get_user_or_raise( + get_asyncpg_engine(app), user_id=_parse_as_user(user_id), return_column_names=["name", "email", "first_name", "last_name"], ) @@ -318,7 +317,9 @@ async def get_user(app: web.Application, user_id: UserID) -> dict[str, Any]: """ :raises UserNotFoundError: """ - row = await _db.get_user_or_raise(engine=get_database_engine(app), user_id=user_id) + row = await _users_repository.get_user_or_raise( + engine=get_asyncpg_engine(app), user_id=user_id + ) return dict(row) @@ -332,14 +333,13 @@ async def get_user_id_from_gid(app: web.Application, primary_gid: int) -> UserID async def get_users_in_group(app: web.Application, gid: GroupID) -> set[UserID]: - engine = get_database_engine(app) - async with engine.acquire() as conn: - return await _db.get_users_ids_in_group(conn, gid) + return await _users_repository.get_users_ids_in_group( + get_asyncpg_engine(app), group_id=gid + ) -async def update_expired_users(engine: Engine) -> list[UserID]: - async with engine.acquire() as conn: - return await _db.do_update_expired_users(conn) +async def update_expired_users(app: web.Application) -> list[UserID]: + return await _users_repository.do_update_expired_users(get_asyncpg_engine(app)) assert set_user_as_deleted # nosec diff --git a/services/web/server/tests/unit/with_dbs/03/test_users_api.py b/services/web/server/tests/unit/with_dbs/03/test_users_api.py index 89b5ddea4747..d43b09f4f114 100644 --- a/services/web/server/tests/unit/with_dbs/03/test_users_api.py +++ b/services/web/server/tests/unit/with_dbs/03/test_users_api.py @@ -11,7 +11,6 @@ from pytest_simcore.helpers.monkeypatch_envs import EnvVarsDict, setenvs_from_dict from pytest_simcore.helpers.webserver_login import NewUser from servicelib.aiohttp import status -from servicelib.aiohttp.application_keys import APP_AIOPG_ENGINE_KEY from simcore_postgres_database.models.users import UserStatus from simcore_service_webserver.users.api import ( get_user_name_and_email, @@ -67,7 +66,7 @@ async def _rq_login(): await assert_status(r1, status.HTTP_200_OK) # apply update - expired = await update_expired_users(client.app[APP_AIOPG_ENGINE_KEY]) + expired = await update_expired_users(client.app) if has_expired: assert expired == [user["id"]] else: