diff --git a/packages/opal-client/opal_client/callbacks/reporter.py b/packages/opal-client/opal_client/callbacks/reporter.py index 5e206d3b3..264f45b51 100644 --- a/packages/opal-client/opal_client/callbacks/reporter.py +++ b/packages/opal-client/opal_client/callbacks/reporter.py @@ -1,5 +1,5 @@ import json -from typing import List, Optional +from typing import Any, Awaitable, Callable, Dict, List, Optional import aiohttp from opal_client.callbacks.register import CallbackConfig, CallbacksRegister @@ -9,6 +9,8 @@ from opal_common.logger import logger from opal_common.schemas.data import DataUpdateReport +GetUserDataHandler = Callable[[DataUpdateReport], Awaitable[Dict[str, Any]]] + class CallbacksReporter: """can send a report to callbacks registered on the callback register.""" @@ -18,6 +20,7 @@ def __init__( ) -> None: self._register = register self._fetcher = data_fetcher or DataFetcher() + self._get_user_data_handler: Optional[GetUserDataHandler] = None async def start(self): await self._fetcher.start() @@ -25,6 +28,11 @@ async def start(self): async def stop(self): await self._fetcher.stop() + def set_user_data_handler(self, handler: GetUserDataHandler): + if self._get_user_data_handler is not None: + logger.warning("set_user_data_handler called and already have a handler.") + self._get_user_data_handler = handler + async def report_update_results( self, report: DataUpdateReport, @@ -33,6 +41,9 @@ async def report_update_results( try: # all the urls that will be eventually called by the fetcher urls = [] + if self._get_user_data_handler is not None: + report = report.copy() + report.user_data = await self._get_user_data_handler(report) report_data = report.json() # first we add the callback urls from the callback register diff --git a/packages/opal-client/opal_client/data/updater.py b/packages/opal-client/opal_client/data/updater.py index 544a6dc45..b26eb128e 100644 --- a/packages/opal-client/opal_client/data/updater.py +++ b/packages/opal-client/opal_client/data/updater.py @@ -508,3 +508,7 @@ async def _set_policy_data( await tx.set_policy_data(data, path=path) else: await tx.patch_policy_data(data, path=path) + + @property + def callbacks_reporter(self) -> CallbacksReporter: + return self._callbacks_reporter diff --git a/packages/opal-client/opal_client/policy/updater.py b/packages/opal-client/opal_client/policy/updater.py index 20ee44bad..10e0dbd91 100644 --- a/packages/opal-client/opal_client/policy/updater.py +++ b/packages/opal-client/opal_client/policy/updater.py @@ -352,3 +352,11 @@ async def handle_policy_updates(self): break except Exception: logger.exception("Failed to update policy") + + @property + def topics(self) -> List[str]: + return self._topics + + @property + def callbacks_reporter(self) -> CallbacksReporter: + return self._callbacks_reporter diff --git a/packages/opal-common/opal_common/schemas/data.py b/packages/opal-common/opal_common/schemas/data.py index 3d5833e22..37378a984 100644 --- a/packages/opal-common/opal_common/schemas/data.py +++ b/packages/opal-common/opal_common/schemas/data.py @@ -176,3 +176,4 @@ class DataUpdateReport(BaseModel): reports: List[DataEntryReport] # in case this is a policy update, the new hash committed the policy store. policy_hash: Optional[str] = None + user_data: Dict[str, Any] = {}