Skip to content

Commit

Permalink
♻️ Migrates folders and workspaces repositories to asyncpg (ITISFound…
Browse files Browse the repository at this point in the history
  • Loading branch information
pcrespov authored Nov 11, 2024
1 parent 2af7f21 commit 8f182d3
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ async def update_folder(

folder_db = await folders_db.update(
app,
folder_id=folder_id,
folders_id_or_ids=folder_id,
name=name,
parent_folder_id=parent_folder_id,
product_name=product_name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,16 @@
from simcore_postgres_database.models.folders_v2 import folders_v2
from simcore_postgres_database.models.projects import projects
from simcore_postgres_database.models.projects_to_folders import projects_to_folders
from simcore_postgres_database.utils_repos import (
pass_or_acquire_connection,
transaction_context,
)
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncConnection
from sqlalchemy.orm import aliased
from sqlalchemy.sql import asc, desc, select

from ..db.plugin import get_database_engine
from ..db.plugin import get_asyncpg_engine
from .errors import FolderAccessForbiddenError, FolderNotFoundError

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -55,6 +60,7 @@ def as_dict_exclude_unset(**params) -> dict[str, Any]:

async def create(
app: web.Application,
connection: AsyncConnection | None = None,
*,
created_by_gid: GroupID,
folder_name: str,
Expand All @@ -67,8 +73,8 @@ async def create(
user_id is not None and workspace_id is not None
), "Both user_id and workspace_id cannot be provided at the same time. Please provide only one."

async with get_database_engine(app).acquire() as conn:
result = await conn.execute(
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
result = await conn.stream(
folders_v2.insert()
.values(
name=folder_name,
Expand All @@ -88,6 +94,7 @@ async def create(

async def list_(
app: web.Application,
connection: AsyncConnection | None = None,
*,
content_of_folder_id: FolderID | None,
user_id: UserID | None,
Expand Down Expand Up @@ -142,18 +149,17 @@ async def list_(
list_query = base_query.order_by(desc(getattr(folders_v2.c, order_by.field)))
list_query = list_query.offset(offset).limit(limit)

async with get_database_engine(app).acquire() as conn:
count_result = await conn.execute(count_query)
total_count = await count_result.scalar()
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
total_count = await conn.scalar(count_query)

result = await conn.execute(list_query)
rows = await result.fetchall() or []
results: list[FolderDB] = [FolderDB.from_orm(row) for row in rows]
return cast(int, total_count), results
result = await conn.stream(list_query)
folders: list[FolderDB] = [FolderDB.from_orm(row) async for row in result]
return cast(int, total_count), folders


async def get(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
product_name: ProductName,
Expand All @@ -167,8 +173,8 @@ async def get(
)
)

async with get_database_engine(app).acquire() as conn:
result = await conn.execute(query)
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
result = await conn.stream(query)
row = await result.first()
if row is None:
raise FolderAccessForbiddenError(
Expand All @@ -179,6 +185,7 @@ async def get(

async def get_for_user_or_workspace(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
product_name: ProductName,
Expand All @@ -203,8 +210,8 @@ async def get_for_user_or_workspace(
else:
query = query.where(folders_v2.c.workspace_id == workspace_id)

async with get_database_engine(app).acquire() as conn:
result = await conn.execute(query)
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:
result = await conn.stream(query)
row = await result.first()
if row is None:
raise FolderAccessForbiddenError(
Expand All @@ -213,8 +220,10 @@ async def get_for_user_or_workspace(
return FolderDB.from_orm(row)


async def _update_impl(
async def update(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folders_id_or_ids: FolderID | set[FolderID],
product_name: ProductName,
# updatable columns
Expand Down Expand Up @@ -247,64 +256,22 @@ async def _update_impl(
# single-update
query = query.where(folders_v2.c.folder_id == folders_id_or_ids)

async with get_database_engine(app).acquire() as conn:
result = await conn.execute(query)
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
result = await conn.stream(query)
row = await result.first()
if row is None:
raise FolderNotFoundError(reason=f"Folder {folders_id_or_ids} not found.")
return FolderDB.from_orm(row)


async def update_batch(
app: web.Application,
*folder_id: FolderID,
product_name: ProductName,
# updatable columns
name: str | UnSet = _unset,
parent_folder_id: FolderID | None | UnSet = _unset,
trashed_at: datetime | None | UnSet = _unset,
trashed_explicitly: bool | UnSet = _unset,
) -> FolderDB:
return await _update_impl(
app=app,
folders_id_or_ids=set(folder_id),
product_name=product_name,
name=name,
parent_folder_id=parent_folder_id,
trashed_at=trashed_at,
trashed_explicitly=trashed_explicitly,
)


async def update(
app: web.Application,
*,
folder_id: FolderID,
product_name: ProductName,
# updatable columns
name: str | UnSet = _unset,
parent_folder_id: FolderID | None | UnSet = _unset,
trashed_at: datetime | None | UnSet = _unset,
trashed_explicitly: bool | UnSet = _unset,
) -> FolderDB:
return await _update_impl(
app=app,
folders_id_or_ids=folder_id,
product_name=product_name,
name=name,
parent_folder_id=parent_folder_id,
trashed_at=trashed_at,
trashed_explicitly=trashed_explicitly,
)


async def delete_recursively(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
product_name: ProductName,
) -> None:
async with get_database_engine(app).acquire() as conn, conn.begin():
async with transaction_context(get_asyncpg_engine(app), connection) as conn:
# Step 1: Define the base case for the recursive CTE
base_query = select(
folders_v2.c.folder_id, folders_v2.c.parent_folder_id
Expand All @@ -330,10 +297,9 @@ async def delete_recursively(

# Step 4: Execute the query to get all descendants
final_query = select(folder_hierarchy_cte)
result = await conn.execute(final_query)
rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
await result.fetchall() or []
)
result = await conn.stream(final_query)
# list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
rows = [row async for row in result]

# Sort folders so that child folders come first
sorted_folders = sorted(
Expand All @@ -347,6 +313,7 @@ async def delete_recursively(

async def get_projects_recursively_only_if_user_is_owner(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
private_workspace_user_id_or_none: UserID | None,
Expand All @@ -361,7 +328,8 @@ async def get_projects_recursively_only_if_user_is_owner(
or the `users_to_groups` table for private workspace projects.
"""

async with get_database_engine(app).acquire() as conn, conn.begin():
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:

# Step 1: Define the base case for the recursive CTE
base_query = select(
folders_v2.c.folder_id, folders_v2.c.parent_folder_id
Expand All @@ -370,6 +338,7 @@ async def get_projects_recursively_only_if_user_is_owner(
& (folders_v2.c.product_name == product_name)
)
folder_hierarchy_cte = base_query.cte(name="folder_hierarchy", recursive=True)

# Step 2: Define the recursive case
folder_alias = aliased(folders_v2)
recursive_query = select(
Expand All @@ -380,16 +349,15 @@ async def get_projects_recursively_only_if_user_is_owner(
folder_alias.c.parent_folder_id == folder_hierarchy_cte.c.folder_id,
)
)

# Step 3: Combine base and recursive cases into a CTE
folder_hierarchy_cte = folder_hierarchy_cte.union_all(recursive_query)

# Step 4: Execute the query to get all descendants
final_query = select(folder_hierarchy_cte)
result = await conn.execute(final_query)
rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
await result.fetchall() or []
)

folder_ids = [item[0] for item in rows]
result = await conn.stream(final_query)
# list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
folder_ids = [item[0] async for item in result]

query = (
select(projects_to_folders.c.project_uuid)
Expand All @@ -402,19 +370,19 @@ async def get_projects_recursively_only_if_user_is_owner(
if private_workspace_user_id_or_none is not None:
query = query.where(projects.c.prj_owner == user_id)

result = await conn.execute(query)

rows = await result.fetchall() or []
return [ProjectID(row[0]) for row in rows]
result = await conn.stream(query)
return [ProjectID(row[0]) async for row in result]


async def get_folders_recursively(
app: web.Application,
connection: AsyncConnection | None = None,
*,
folder_id: FolderID,
product_name: ProductName,
) -> list[FolderID]:
async with get_database_engine(app).acquire() as conn, conn.begin():
async with pass_or_acquire_connection(get_asyncpg_engine(app), connection) as conn:

# Step 1: Define the base case for the recursive CTE
base_query = select(
folders_v2.c.folder_id, folders_v2.c.parent_folder_id
Expand All @@ -440,9 +408,5 @@ async def get_folders_recursively(

# Step 4: Execute the query to get all descendants
final_query = select(folder_hierarchy_cte)
result = await conn.execute(final_query)
rows = ( # list of tuples [(folder_id, parent_folder_id), ...] ex. [(1, None), (2, 1)]
await result.fetchall() or []
)

return [FolderID(row[0]) for row in rows]
result = await conn.stream(final_query)
return [FolderID(row[0]) async for row in result]
Loading

0 comments on commit 8f182d3

Please sign in to comment.