diff --git a/services/web/server/src/simcore_service_webserver/folders/_folders_api.py b/services/web/server/src/simcore_service_webserver/folders/_folders_api.py index 0344124abb6..043527d2def 100644 --- a/services/web/server/src/simcore_service_webserver/folders/_folders_api.py +++ b/services/web/server/src/simcore_service_webserver/folders/_folders_api.py @@ -262,7 +262,7 @@ async def update_folder( folder_db = await folders_db.update( app, - folder_id=folder_id, + folders_id_or_ids=folder_id, name=name, parent_folder_id=parent_folder_id, product_name=product_name, diff --git a/services/web/server/src/simcore_service_webserver/folders/_folders_db.py b/services/web/server/src/simcore_service_webserver/folders/_folders_db.py index 0ee44c17199..561bcb64c9e 100644 --- a/services/web/server/src/simcore_service_webserver/folders/_folders_db.py +++ b/services/web/server/src/simcore_service_webserver/folders/_folders_db.py @@ -19,11 +19,16 @@ from simcore_postgres_database.models.folders_v2 import folders_v2 from simcore_postgres_database.models.projects import projects from simcore_postgres_database.models.projects_to_folders import projects_to_folders +from simcore_postgres_database.utils_repos import ( + pass_or_acquire_connection, + transaction_context, +) from sqlalchemy import func +from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.orm import aliased from sqlalchemy.sql import asc, desc, select -from ..db.plugin import get_database_engine +from ..db.plugin import get_asyncpg_engine from .errors import FolderAccessForbiddenError, FolderNotFoundError _logger = logging.getLogger(__name__) @@ -55,6 +60,7 @@ def as_dict_exclude_unset(**params) -> dict[str, Any]: async def create( app: web.Application, + connection: AsyncConnection | None = None, *, created_by_gid: GroupID, folder_name: str, @@ -67,8 +73,8 @@ async def create( user_id is not None and workspace_id is not None ), "Both user_id and workspace_id cannot be provided at the same time. Please provide only one." - async with get_database_engine(app).acquire() as conn: - result = await conn.execute( + async with transaction_context(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream( folders_v2.insert() .values( name=folder_name, @@ -88,6 +94,7 @@ async def create( async def list_( app: web.Application, + connection: AsyncConnection | None = None, *, content_of_folder_id: FolderID | None, user_id: UserID | None, @@ -142,18 +149,17 @@ async def list_( list_query = base_query.order_by(desc(getattr(folders_v2.c, order_by.field))) list_query = list_query.offset(offset).limit(limit) - async with get_database_engine(app).acquire() as conn: - count_result = await conn.execute(count_query) - total_count = await count_result.scalar() + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + total_count = await conn.scalar(count_query) - result = await conn.execute(list_query) - rows = await result.fetchall() or [] - results: list[FolderDB] = [FolderDB.from_orm(row) for row in rows] - return cast(int, total_count), results + result = await conn.stream(list_query) + folders: list[FolderDB] = [FolderDB.from_orm(row) async for row in result] + return cast(int, total_count), folders async def get( app: web.Application, + connection: AsyncConnection | None = None, *, folder_id: FolderID, product_name: ProductName, @@ -167,8 +173,8 @@ async def get( ) ) - async with get_database_engine(app).acquire() as conn: - result = await conn.execute(query) + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream(query) row = await result.first() if row is None: raise FolderAccessForbiddenError( @@ -179,6 +185,7 @@ async def get( async def get_for_user_or_workspace( app: web.Application, + connection: AsyncConnection | None = None, *, folder_id: FolderID, product_name: ProductName, @@ -203,8 +210,8 @@ async def get_for_user_or_workspace( else: query = query.where(folders_v2.c.workspace_id == workspace_id) - async with get_database_engine(app).acquire() as conn: - result = await conn.execute(query) + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream(query) row = await result.first() if row is None: raise FolderAccessForbiddenError( @@ -213,8 +220,10 @@ async def get_for_user_or_workspace( return FolderDB.from_orm(row) -async def _update_impl( +async def update( app: web.Application, + connection: AsyncConnection | None = None, + *, folders_id_or_ids: FolderID | set[FolderID], product_name: ProductName, # updatable columns @@ -247,64 +256,22 @@ async def _update_impl( # single-update query = query.where(folders_v2.c.folder_id == folders_id_or_ids) - async with get_database_engine(app).acquire() as conn: - result = await conn.execute(query) + async with transaction_context(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream(query) row = await result.first() if row is None: raise FolderNotFoundError(reason=f"Folder {folders_id_or_ids} not found.") return FolderDB.from_orm(row) -async def update_batch( - app: web.Application, - *folder_id: FolderID, - product_name: ProductName, - # updatable columns - name: str | UnSet = _unset, - parent_folder_id: FolderID | None | UnSet = _unset, - trashed_at: datetime | None | UnSet = _unset, - trashed_explicitly: bool | UnSet = _unset, -) -> FolderDB: - return await _update_impl( - app=app, - folders_id_or_ids=set(folder_id), - product_name=product_name, - name=name, - parent_folder_id=parent_folder_id, - trashed_at=trashed_at, - trashed_explicitly=trashed_explicitly, - ) - - -async def update( - app: web.Application, - *, - folder_id: FolderID, - product_name: ProductName, - # updatable columns - name: str | UnSet = _unset, - parent_folder_id: FolderID | None | UnSet = _unset, - trashed_at: datetime | None | UnSet = _unset, - trashed_explicitly: bool | UnSet = _unset, -) -> FolderDB: - return await _update_impl( - app=app, - folders_id_or_ids=folder_id, - product_name=product_name, - name=name, - parent_folder_id=parent_folder_id, - trashed_at=trashed_at, - trashed_explicitly=trashed_explicitly, - ) - - async def delete_recursively( app: web.Application, + connection: AsyncConnection | None = None, *, folder_id: FolderID, product_name: ProductName, ) -> None: - async with get_database_engine(app).acquire() as conn, conn.begin(): + async with transaction_context(get_asyncpg_engine(app), connection) as conn: # Step 1: Define the base case for the recursive CTE base_query = select( folders_v2.c.folder_id, folders_v2.c.parent_folder_id @@ -330,10 +297,9 @@ async def delete_recursively( # Step 4: Execute the query to get all descendants final_query = select(folder_hierarchy_cte) - result = await conn.execute(final_query) - rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)] - await result.fetchall() or [] - ) + result = await conn.stream(final_query) + # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)] + rows = [row async for row in result] # Sort folders so that child folders come first sorted_folders = sorted( @@ -347,6 +313,7 @@ async def delete_recursively( async def get_projects_recursively_only_if_user_is_owner( app: web.Application, + connection: AsyncConnection | None = None, *, folder_id: FolderID, private_workspace_user_id_or_none: UserID | None, @@ -361,7 +328,8 @@ async def get_projects_recursively_only_if_user_is_owner( or the `users_to_groups` table for private workspace projects. """ - async with get_database_engine(app).acquire() as conn, conn.begin(): + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + # Step 1: Define the base case for the recursive CTE base_query = select( folders_v2.c.folder_id, folders_v2.c.parent_folder_id @@ -370,6 +338,7 @@ async def get_projects_recursively_only_if_user_is_owner( & (folders_v2.c.product_name == product_name) ) folder_hierarchy_cte = base_query.cte(name="folder_hierarchy", recursive=True) + # Step 2: Define the recursive case folder_alias = aliased(folders_v2) recursive_query = select( @@ -380,16 +349,15 @@ async def get_projects_recursively_only_if_user_is_owner( folder_alias.c.parent_folder_id == folder_hierarchy_cte.c.folder_id, ) ) + # Step 3: Combine base and recursive cases into a CTE folder_hierarchy_cte = folder_hierarchy_cte.union_all(recursive_query) + # Step 4: Execute the query to get all descendants final_query = select(folder_hierarchy_cte) - result = await conn.execute(final_query) - rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)] - await result.fetchall() or [] - ) - - folder_ids = [item[0] for item in rows] + result = await conn.stream(final_query) + # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)] + folder_ids = [item[0] async for item in result] query = ( select(projects_to_folders.c.project_uuid) @@ -402,19 +370,19 @@ async def get_projects_recursively_only_if_user_is_owner( if private_workspace_user_id_or_none is not None: query = query.where(projects.c.prj_owner == user_id) - result = await conn.execute(query) - - rows = await result.fetchall() or [] - return [ProjectID(row[0]) for row in rows] + result = await conn.stream(query) + return [ProjectID(row[0]) async for row in result] async def get_folders_recursively( app: web.Application, + connection: AsyncConnection | None = None, *, folder_id: FolderID, product_name: ProductName, ) -> list[FolderID]: - async with get_database_engine(app).acquire() as conn, conn.begin(): + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + # Step 1: Define the base case for the recursive CTE base_query = select( folders_v2.c.folder_id, folders_v2.c.parent_folder_id @@ -440,9 +408,5 @@ async def get_folders_recursively( # Step 4: Execute the query to get all descendants final_query = select(folder_hierarchy_cte) - result = await conn.execute(final_query) - rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)] - await result.fetchall() or [] - ) - - return [FolderID(row[0]) for row in rows] + result = await conn.stream(final_query) + return [FolderID(row[0]) async for row in result] diff --git a/services/web/server/src/simcore_service_webserver/folders/_trash_api.py b/services/web/server/src/simcore_service_webserver/folders/_trash_api.py index 1cad0415161..b3e1823369a 100644 --- a/services/web/server/src/simcore_service_webserver/folders/_trash_api.py +++ b/services/web/server/src/simcore_service_webserver/folders/_trash_api.py @@ -7,7 +7,10 @@ from models_library.products import ProductName from models_library.projects import ProjectID from models_library.users import UserID +from simcore_postgres_database.utils_repos import transaction_context +from sqlalchemy.ext.asyncio import AsyncConnection +from ..db.plugin import get_asyncpg_engine from ..projects._trash_api import trash_project, untrash_project from ..workspaces.api import check_user_workspace_access from . import _folders_db @@ -55,6 +58,7 @@ async def _check_exists_and_access( async def _folders_db_update( app: web.Application, + connection: AsyncConnection | None = None, *, product_name: ProductName, folder_id: FolderID, @@ -63,7 +67,8 @@ async def _folders_db_update( # EXPLICIT un/trash await _folders_db.update( app, - folder_id=folder_id, + connection, + folders_id_or_ids=folder_id, product_name=product_name, trashed_at=trashed_at, trashed_explicitly=trashed_at is not None, @@ -73,15 +78,16 @@ async def _folders_db_update( child_folders: set[FolderID] = { f for f in await _folders_db.get_folders_recursively( - app, folder_id=folder_id, product_name=product_name + app, connection, folder_id=folder_id, product_name=product_name ) if f != folder_id } if child_folders: - await _folders_db.update_batch( + await _folders_db.update( app, - *child_folders, + connection, + folders_id_or_ids=child_folders, product_name=product_name, trashed_at=trashed_at, trashed_explicitly=False, @@ -104,40 +110,40 @@ async def trash_folder( # Trash trashed_at = arrow.utcnow().datetime - _logger.debug( - "TODO: Unit of work for all folders and projects and fails if force_stop_first=%s is False", - force_stop_first, - ) - - # 1. Trash folder and children - await _folders_db_update( - app, - folder_id=folder_id, - product_name=product_name, - trashed_at=trashed_at, - ) - - # 2. Trash all child projects that I am an owner - child_projects: list[ - ProjectID - ] = await _folders_db.get_projects_recursively_only_if_user_is_owner( - app, - folder_id=folder_id, - private_workspace_user_id_or_none=user_id if workspace_is_private else None, - user_id=user_id, - product_name=product_name, - ) + async with transaction_context(get_asyncpg_engine(app)) as connection: - for project_id in child_projects: - await trash_project( + # 1. Trash folder and children + await _folders_db_update( app, + connection, + folder_id=folder_id, product_name=product_name, + trashed_at=trashed_at, + ) + + # 2. Trash all child projects that I am an owner + child_projects: list[ + ProjectID + ] = await _folders_db.get_projects_recursively_only_if_user_is_owner( + app, + connection, + folder_id=folder_id, + private_workspace_user_id_or_none=user_id if workspace_is_private else None, user_id=user_id, - project_id=project_id, - force_stop_first=force_stop_first, - explicit=False, + product_name=product_name, ) + for project_id in child_projects: + await trash_project( + app, + # NOTE: this needs to be included in the unit-of-work, i.e. connection, + product_name=product_name, + user_id=user_id, + project_id=project_id, + force_stop_first=force_stop_first, + explicit=False, + ) + async def untrash_folder( app: web.Application, diff --git a/services/web/server/src/simcore_service_webserver/workspaces/_groups_db.py b/services/web/server/src/simcore_service_webserver/workspaces/_groups_db.py index daeba51ae80..019ec5530b0 100644 --- a/services/web/server/src/simcore_service_webserver/workspaces/_groups_db.py +++ b/services/web/server/src/simcore_service_webserver/workspaces/_groups_db.py @@ -13,10 +13,15 @@ from simcore_postgres_database.models.workspaces_access_rights import ( workspaces_access_rights, ) +from simcore_postgres_database.utils_repos import ( + pass_or_acquire_connection, + transaction_context, +) from sqlalchemy import func, literal_column +from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.sql import select -from ..db.plugin import get_database_engine +from ..db.plugin import get_asyncpg_engine from .errors import WorkspaceGroupNotFoundError _logger = logging.getLogger(__name__) @@ -41,15 +46,16 @@ class Config: async def create_workspace_group( app: web.Application, + connection: AsyncConnection | None = None, + *, workspace_id: WorkspaceID, group_id: GroupID, - *, read: bool, write: bool, delete: bool, ) -> WorkspaceGroupGetDB: - async with get_database_engine(app).acquire() as conn: - result = await conn.execute( + async with transaction_context(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream( workspaces_access_rights.insert() .values( workspace_id=workspace_id, @@ -68,6 +74,8 @@ async def create_workspace_group( async def list_workspace_groups( app: web.Application, + connection: AsyncConnection | None = None, + *, workspace_id: WorkspaceID, ) -> list[WorkspaceGroupGetDB]: stmt = ( @@ -83,14 +91,15 @@ async def list_workspace_groups( .where(workspaces_access_rights.c.workspace_id == workspace_id) ) - async with get_database_engine(app).acquire() as conn: - result = await conn.execute(stmt) - rows = await result.fetchall() or [] - return [WorkspaceGroupGetDB.from_orm(row) for row in rows] + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream(stmt) + return [WorkspaceGroupGetDB.from_orm(row) async for row in result] async def get_workspace_group( app: web.Application, + connection: AsyncConnection | None = None, + *, workspace_id: WorkspaceID, group_id: GroupID, ) -> WorkspaceGroupGetDB: @@ -110,8 +119,8 @@ async def get_workspace_group( ) ) - async with get_database_engine(app).acquire() as conn: - result = await conn.execute(stmt) + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream(stmt) row = await result.first() if row is None: raise WorkspaceGroupNotFoundError( @@ -122,15 +131,16 @@ async def get_workspace_group( async def update_workspace_group( app: web.Application, + connection: AsyncConnection | None = None, + *, workspace_id: WorkspaceID, group_id: GroupID, - *, read: bool, write: bool, delete: bool, ) -> WorkspaceGroupGetDB: - async with get_database_engine(app).acquire() as conn: - result = await conn.execute( + async with transaction_context(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream( workspaces_access_rights.update() .values( read=read, @@ -153,10 +163,12 @@ async def update_workspace_group( async def delete_workspace_group( app: web.Application, + connection: AsyncConnection | None = None, + *, workspace_id: WorkspaceID, group_id: GroupID, ) -> None: - async with get_database_engine(app).acquire() as conn: + async with transaction_context(get_asyncpg_engine(app), connection) as conn: await conn.execute( workspaces_access_rights.delete().where( (workspaces_access_rights.c.workspace_id == workspace_id) diff --git a/services/web/server/src/simcore_service_webserver/workspaces/_workspaces_db.py b/services/web/server/src/simcore_service_webserver/workspaces/_workspaces_db.py index 23de15c3b19..a959843a969 100644 --- a/services/web/server/src/simcore_service_webserver/workspaces/_workspaces_db.py +++ b/services/web/server/src/simcore_service_webserver/workspaces/_workspaces_db.py @@ -22,11 +22,16 @@ from simcore_postgres_database.models.workspaces_access_rights import ( workspaces_access_rights, ) +from simcore_postgres_database.utils_repos import ( + pass_or_acquire_connection, + transaction_context, +) from sqlalchemy import asc, desc, func from sqlalchemy.dialects.postgresql import BOOLEAN, INTEGER +from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.sql import Subquery, select -from ..db.plugin import get_database_engine +from ..db.plugin import get_asyncpg_engine from .errors import WorkspaceAccessForbiddenError, WorkspaceNotFoundError _logger = logging.getLogger(__name__) @@ -45,14 +50,16 @@ async def create_workspace( app: web.Application, + connection: AsyncConnection | None = None, + *, product_name: ProductName, owner_primary_gid: GroupID, name: str, description: str | None, thumbnail: str | None, ) -> WorkspaceDB: - async with get_database_engine(app).acquire() as conn: - result = await conn.execute( + async with transaction_context(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream( workspaces.insert() .values( name=name, @@ -69,7 +76,7 @@ async def create_workspace( return WorkspaceDB.from_orm(row) -access_rights_subquery = ( +_access_rights_subquery = ( select( workspaces_access_rights.c.workspace_id, func.jsonb_object_agg( @@ -116,6 +123,7 @@ def _create_my_access_rights_subquery(user_id: UserID) -> Subquery: async def list_workspaces_for_user( app: web.Application, + connection: AsyncConnection | None = None, *, user_id: UserID, product_name: ProductName, @@ -128,11 +136,11 @@ async def list_workspaces_for_user( base_query = ( select( *_SELECTION_ARGS, - access_rights_subquery.c.access_rights, + _access_rights_subquery.c.access_rights, my_access_rights_subquery.c.my_access_rights, ) .select_from( - workspaces.join(access_rights_subquery).join(my_access_rights_subquery) + workspaces.join(_access_rights_subquery).join(my_access_rights_subquery) ) .where(workspaces.c.product_name == product_name) ) @@ -148,21 +156,21 @@ async def list_workspaces_for_user( list_query = base_query.order_by(desc(getattr(workspaces.c, order_by.field))) list_query = list_query.offset(offset).limit(limit) - async with get_database_engine(app).acquire() as conn: - count_result = await conn.execute(count_query) - total_count = await count_result.scalar() + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + total_count = await conn.scalar(count_query) - result = await conn.execute(list_query) - rows = await result.fetchall() or [] - results: list[UserWorkspaceAccessRightsDB] = [ - UserWorkspaceAccessRightsDB.from_orm(row) for row in rows + result = await conn.stream(list_query) + items: list[UserWorkspaceAccessRightsDB] = [ + UserWorkspaceAccessRightsDB.from_orm(row) async for row in result ] - return cast(int, total_count), results + return cast(int, total_count), items async def get_workspace_for_user( app: web.Application, + connection: AsyncConnection | None = None, + *, user_id: UserID, workspace_id: WorkspaceID, product_name: ProductName, @@ -172,11 +180,11 @@ async def get_workspace_for_user( base_query = ( select( *_SELECTION_ARGS, - access_rights_subquery.c.access_rights, + _access_rights_subquery.c.access_rights, my_access_rights_subquery.c.my_access_rights, ) .select_from( - workspaces.join(access_rights_subquery).join(my_access_rights_subquery) + workspaces.join(_access_rights_subquery).join(my_access_rights_subquery) ) .where( (workspaces.c.workspace_id == workspace_id) @@ -184,8 +192,8 @@ async def get_workspace_for_user( ) ) - async with get_database_engine(app).acquire() as conn: - result = await conn.execute(base_query) + async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream(base_query) row = await result.first() if row is None: raise WorkspaceAccessForbiddenError( @@ -196,14 +204,16 @@ async def get_workspace_for_user( async def update_workspace( app: web.Application, + connection: AsyncConnection | None = None, + *, workspace_id: WorkspaceID, name: str, description: str | None, thumbnail: str | None, product_name: ProductName, ) -> WorkspaceDB: - async with get_database_engine(app).acquire() as conn: - result = await conn.execute( + async with transaction_context(get_asyncpg_engine(app), connection) as conn: + result = await conn.stream( workspaces.update() .values( name=name, @@ -225,10 +235,12 @@ async def update_workspace( async def delete_workspace( app: web.Application, + connection: AsyncConnection | None = None, + *, workspace_id: WorkspaceID, product_name: ProductName, ) -> None: - async with get_database_engine(app).acquire() as conn: + async with transaction_context(get_asyncpg_engine(app), connection) as conn: await conn.execute( workspaces.delete().where( (workspaces.c.workspace_id == workspace_id)