diff --git a/masterbase/app.py b/masterbase/app.py index 0c124d8..ba001e9 100644 --- a/masterbase/app.py +++ b/masterbase/app.py @@ -13,7 +13,7 @@ from litestar.exceptions import NotAuthorizedException, PermissionDeniedException from litestar.handlers import WebsocketListener from litestar.handlers.base import BaseRouteHandler -from litestar.response import Redirect +from litestar.response import Redirect, Stream from sqlalchemy import Engine, create_engine from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine @@ -163,11 +163,12 @@ def late_bytes(request: Request, api_key: str, data: dict[str, str]) -> dict[str @get("/demodata", guards=[valid_key_guard], sync_to_thread=False) -def demodata(request: Request, api_key: str, session_id: str) -> bytes: +def demodata(request: Request, api_key: str, session_id: str) -> Stream: """Return the demo.""" engine = request.app.state.engine - data = demodata_helper(engine, api_key, session_id) - return data + bytestream = demodata_helper(engine, api_key, session_id) + headers = {"Content-Disposition": f'attachment; filename="{session_id}.dem"'} + return Stream(bytestream, media_type=MediaType.TEXT, headers=headers) class DemoHandler(WebsocketListener): @@ -333,7 +334,7 @@ def provision_handler(request: Request) -> str: app = Litestar( on_startup=[get_db_connection, get_async_db_connection], - route_handlers=[session_id, close_session, DemoHandler, provision, provision_handler, late_bytes], + route_handlers=[session_id, close_session, DemoHandler, provision, provision_handler, late_bytes, demodata], on_shutdown=[close_db_connection, close_async_db_connection], ) diff --git a/masterbase/lib.py b/masterbase/lib.py index 6e5c204..f39fbef 100644 --- a/masterbase/lib.py +++ b/masterbase/lib.py @@ -2,7 +2,7 @@ import os from datetime import datetime, timezone -from typing import IO +from typing import IO, Generator from uuid import uuid4 from xml.etree import ElementTree @@ -21,7 +21,7 @@ def make_db_uri(is_async: bool = False) -> str: user = os.environ["POSTGRES_USER"] password = os.environ["POSTGRES_PASSWORD"] host = os.environ.get("POSTGRES_HOST", "localhost") - port = os.environ.get("POSTGRES_PORT", "8050") + port = os.environ.get("POSTGRES_PORT", "5432") prefix = "postgresql" if is_async: prefix = f"{prefix}+asyncpg" @@ -263,24 +263,20 @@ def late_bytes_helper(engine: Engine, api_key: str, late_bytes: bytes, current_t conn.commit() -def demodata_helper(engine: Engine, api_key: str, session_id: str) -> bytes: - """Return demo data as bytes.""" +def demodata_helper(engine: Engine, api_key: str, session_id: str) -> Generator[bytes, None, None]: + """Yield demo data page by page.""" sql = """ - SELECT loid, STRING_AGG(data, '' ORDER BY pageno) AS all_data + SELECT pageno, data FROM pg_largeobject JOIN demo_sessions demo ON demo.demo_oid = pg_largeobject.loid WHERE demo.session_id = :session_id - GROUP BY loid; + ORDER BY pageno; """ with engine.connect() as conn: result = conn.execute(sa.text(sql), dict(session_id=session_id)) - row = result.fetchone() - - if row is not None: - return row[0].encode("utf-8") if row[0] else b"" - - return b"" + for row in result: + yield row[1].tobytes() def check_steam_id_has_api_key(engine: Engine, steam_id: str) -> str | None: