From a3b030222eac4c2ba35ddbed34e05d994b671745 Mon Sep 17 00:00:00 2001 From: Lucas ONeil Date: Fri, 20 Dec 2024 10:28:02 -0800 Subject: [PATCH] Buffer messages to wait for reconnect Signed-off-by: Lucas ONeil --- oidc-controller/api/routers/acapy_handler.py | 17 ++--- oidc-controller/api/routers/oidc.py | 7 +- .../api/routers/presentation_request.py | 9 +-- oidc-controller/api/routers/socketio.py | 75 ++++++++++++++++--- .../api/templates/verified_credentials.html | 10 +++ 5 files changed, 85 insertions(+), 33 deletions(-) diff --git a/oidc-controller/api/routers/acapy_handler.py b/oidc-controller/api/routers/acapy_handler.py index a4a8a534..74714ac8 100644 --- a/oidc-controller/api/routers/acapy_handler.py +++ b/oidc-controller/api/routers/acapy_handler.py @@ -11,7 +11,7 @@ from ..db.session import get_db from ..core.config import settings -from ..routers.socketio import sio, connections_reload +from ..routers.socketio import buffered_emit, connections_reload logger: structlog.typing.FilteringBoundLogger = structlog.getLogger(__name__) @@ -39,9 +39,6 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db # Get the saved websocket session pid = str(auth_session.id) - connections = connections_reload() - sid = connections.get(pid) - logger.debug(f"sid: {sid} found for pid: {pid}") if webhook_body["state"] == "presentation-received": logger.info("presentation-received") @@ -51,12 +48,10 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db if webhook_body["verified"] == "true": auth_session.proof_status = AuthSessionState.VERIFIED auth_session.presentation_exchange = webhook_body["by_format"] - if sid: - await sio.emit("status", {"status": "verified"}, to=sid) + await buffered_emit("status", {"status": "verified"}, to_pid=pid) else: auth_session.proof_status = AuthSessionState.FAILED - if sid: - await sio.emit("status", {"status": "failed"}, to=sid) + await buffered_emit("status", {"status": "failed"}, to_pid=pid) await AuthSessionCRUD(db).patch( str(auth_session.id), AuthSessionPatch(**auth_session.model_dump()) @@ -67,8 +62,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db logger.info("ABANDONED") logger.info(webhook_body["error_msg"]) auth_session.proof_status = AuthSessionState.ABANDONED - if sid: - await sio.emit("status", {"status": "abandoned"}, to=sid) + await buffered_emit("status", {"status": "abandoned"}, to_pid=pid) await AuthSessionCRUD(db).patch( str(auth_session.id), AuthSessionPatch(**auth_session.model_dump()) @@ -93,8 +87,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db ): logger.info("EXPIRED") auth_session.proof_status = AuthSessionState.EXPIRED - if sid: - await sio.emit("status", {"status": "expired"}, to=sid) + await buffered_emit("status", {"status": "expired"}, to_pid=pid) await AuthSessionCRUD(db).patch( str(auth_session.id), AuthSessionPatch(**auth_session.model_dump()) diff --git a/oidc-controller/api/routers/oidc.py b/oidc-controller/api/routers/oidc.py index 2d4d7073..84b589d5 100644 --- a/oidc-controller/api/routers/oidc.py +++ b/oidc-controller/api/routers/oidc.py @@ -31,7 +31,7 @@ from ..db.session import get_db # Access to the websocket -from ..routers.socketio import connections_reload, sio +from ..routers.socketio import buffered_emit, connections_reload from ..verificationConfigs.crud import VerificationConfigCRUD from ..verificationConfigs.helpers import VariableSubstitutionError @@ -58,8 +58,6 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)): auth_session = await AuthSessionCRUD(db).get(pid) pid = str(auth_session.id) - connections = connections_reload() - sid = connections.get(pid) """ Check if proof is expired. But only if the proof has not been started. @@ -75,8 +73,7 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)): str(auth_session.id), AuthSessionPatch(**auth_session.model_dump()) ) # Send message through the websocket. - if sid: - await sio.emit("status", {"status": "expired"}, to=sid) + await buffered_emit("status", {"status": "expired"}, to_pid=pid) return {"proof_status": auth_session.proof_status} diff --git a/oidc-controller/api/routers/presentation_request.py b/oidc-controller/api/routers/presentation_request.py index 9bc753bb..8331efeb 100644 --- a/oidc-controller/api/routers/presentation_request.py +++ b/oidc-controller/api/routers/presentation_request.py @@ -9,7 +9,7 @@ from ..authSessions.models import AuthSession, AuthSessionState from ..core.config import settings -from ..routers.socketio import sio, connections_reload +from ..routers.socketio import buffered_emit, connections_reload from ..routers.oidc import gen_deep_link from ..db.session import get_db @@ -49,16 +49,11 @@ async def send_connectionless_proof_req( pres_exch_id ) - # Get the websocket session - connections = connections_reload() - sid = connections.get(str(auth_session.id)) - # If the qrcode has been scanned, toggle the verified flag if auth_session.proof_status is AuthSessionState.NOT_STARTED: auth_session.proof_status = AuthSessionState.PENDING await AuthSessionCRUD(db).patch(auth_session.id, auth_session) - if sid: - await sio.emit("status", {"status": "pending"}, to=sid) + await buffered_emit("status", {"status": "pending"}, to_pid=auth_session.id) msg = auth_session.presentation_request_msg diff --git a/oidc-controller/api/routers/socketio.py b/oidc-controller/api/routers/socketio.py index a0d6a3d2..6a08e9f7 100644 --- a/oidc-controller/api/routers/socketio.py +++ b/oidc-controller/api/routers/socketio.py @@ -1,13 +1,14 @@ import socketio # For using websockets import logging +import time logger = logging.getLogger(__name__) - connections = {} +message_buffers = {} +buffer_timeout = 60 # Timeout in seconds sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") - sio_app = socketio.ASGIApp(socketio_server=sio, socketio_path="/ws/socket.io") @@ -18,18 +19,74 @@ async def connect(sid, socket): @sio.event async def initialize(sid, data): - global connections - # Store websocket session matched to the presentation exchange id - connections[data.get("pid")] = sid + global connections, message_buffers + pid = data.get("pid") + connections[pid] = sid + # Initialize buffer if it doesn't exist + if pid not in message_buffers: + message_buffers[pid] = [] @sio.event async def disconnect(sid): - global connections + global connections, message_buffers logger.info(f">>> disconnect : sid={sid}") - # Remove websocket session from the store - if len(connections) > 0: - connections = {k: v for k, v in connections.items() if v != sid} + # Find the pid associated with the sid + pid = next((k for k, v in connections.items() if v == sid), None) + if pid: + # Remove pid from connections + del connections[pid] + + +async def buffered_emit(event, data, to_pid=None): + global connections, message_buffers + + connections = connections_reload() + sid = connections.get(to_pid) + logger.debug(f"sid: {sid} found for pid: {to_pid}") + + if sid: + try: + await sio.emit(event, data, room=sid) + except: + # If send fails, buffer the message + buffer_message(to_pid, event, data) + else: + # Buffer the message if the target is not connected + buffer_message(to_pid, event, data) + + +def buffer_message(pid, event, data): + global message_buffers + current_time = time.time() + if pid not in message_buffers: + message_buffers[pid] = [] + # Add message with timestamp and event name + message_buffers[pid].append((event, data, current_time)) + # Clean up old messages + message_buffers[pid] = [ + (msg_event, msg_data, timestamp) + for msg_event, msg_data, timestamp in message_buffers[pid] + if current_time - timestamp <= buffer_timeout + ] + + +@sio.event +async def fetch_buffered_messages(sid, pid): + global message_buffers + current_time = time.time() + if pid in message_buffers: + # Filter messages that are still valid (i.e., within the buffer_timeout) + valid_messages = [ + (msg_event, msg_data, timestamp) + for msg_event, msg_data, timestamp in message_buffers[pid] + if current_time - timestamp <= buffer_timeout + ] + # Emit each valid message + for event, data, _ in valid_messages: + await sio.emit(event, data, room=sid) + # Reassign the valid_messages back to message_buffers[pid] to clean up old messages + message_buffers[pid] = valid_messages def connections_reload(): diff --git a/oidc-controller/api/templates/verified_credentials.html b/oidc-controller/api/templates/verified_credentials.html index 474d560d..97cf2c35 100644 --- a/oidc-controller/api/templates/verified_credentials.html +++ b/oidc-controller/api/templates/verified_credentials.html @@ -112,6 +112,14 @@

Continue with:

> DEBUG Disconnect Web Socket + +
@@ -383,6 +391,8 @@
`Socket connecting. SID: ${this.socket.id}. PID: {{pid}}. Recovered? ${this.socket.recovered} ` ); this.socket.emit("initialize", { pid: "{{pid}}" }); + // Emit the `fetch_buffered_messages` event with `pid` as a string using Jinja templating + this.socket.emit('fetch_buffered_messages', '{{ pid }}'); }); this.socket.on("connect_error", (error) => {