Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added hierarchical lock for applying data updates #745

Merged
merged 13 commits into from
Feb 2, 2025
4 changes: 2 additions & 2 deletions packages/opal-client/opal_client/callbacks/reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ async def report_update_results(
status=result.status,
error=error_content,
)
except:
logger.exception("Failed to execute report_update_results")
except Exception as e:
logger.exception(f"Failed to execute report_update_results: {e}")
orweis marked this conversation as resolved.
Show resolved Hide resolved
8 changes: 4 additions & 4 deletions packages/opal-client/opal_client/data/fetcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple

from opal_client.config import opal_client_config
from opal_client.policy_store.base_policy_store_client import JsonableValue
Expand Down Expand Up @@ -58,8 +58,8 @@ async def stop(self):
await self._engine.terminate_workers()

async def handle_url(
self, url: str, config: FetcherConfig, data: Optional[JsonableValue]
):
self, url: str, config: dict, data: Optional[JsonableValue]
) -> Optional[JsonableValue]:
"""Helper function wrapping self._engine.handle_url."""
if data is not None:
logger.info("Data provided inline for url: {url}", url=url)
Expand Down Expand Up @@ -107,7 +107,7 @@ async def handle_urls(
results_with_url_and_config = [
(url, config, result)
for (url, config, data), result in zip(urls, results)
if result is not None
if result is not None # FIXME ignores None results
orweis marked this conversation as resolved.
Show resolved Hide resolved
]

# return results
Expand Down
631 changes: 396 additions & 235 deletions packages/opal-client/opal_client/data/updater.py

Large diffs are not rendered by default.

40 changes: 27 additions & 13 deletions packages/opal-client/opal_client/tests/data_updater_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,18 @@ async def test_data_updater(server):
proc.terminate()

# test PATCH update event via API
entries = [
DataSourceEntry(
url="",
data=PATCH_DATA_UPDATE,
dst_path="/",
topics=DATA_TOPICS,
save_method="PATCH",
)
]
update = DataUpdate(
reason="Test_Patch", entries=entries, callback=UpdateCallback(callbacks=[])
reason="Test_Patch",
entries=[
DataSourceEntry(
url="",
data=PATCH_DATA_UPDATE,
dst_path="/",
topics=DATA_TOPICS,
save_method="PATCH",
)
],
callback=UpdateCallback(callbacks=[]),
)

headers = {"content-type": "application/json"}
Expand All @@ -218,13 +219,26 @@ async def test_data_updater(server):
)
assert res.status_code == 200
# value field is not specified for add operation should fail
entries[0].data = [{"op": "add", "path": "/"}]
res = requests.post(
DATA_UPDATE_ROUTE,
data=json.dumps(update, default=pydantic_encoder),
data=json.dumps(
{
"reason": "Test_Patch",
"entries": [
{
"url": "",
"data": [{"op": "add", "path": "/"}],
"dst_path": "/",
"topics": DATA_TOPICS,
"save_method": "PATCH",
}
],
},
default=pydantic_encoder,
),
headers=headers,
)
assert res.status_code == 422
assert res.status_code == 422, res.text


@pytest.mark.asyncio
Expand Down
31 changes: 28 additions & 3 deletions packages/opal-common/opal_common/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import asyncio
import sys
from functools import partial
from typing import Any, Callable, Coroutine, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Coroutine, Optional, Set, Tuple, TypeVar

import loguru
from loguru import logger

if sys.version_info < (3, 10):
from typing_extensions import ParamSpec
Expand Down Expand Up @@ -94,16 +95,40 @@ async def stop_queue_handling(self):

class TasksPool:
def __init__(self):
self._tasks: List[asyncio.Task] = []
self._tasks: Set[asyncio.Task] = set()
self._running = True

def _cleanup_task(self, done_task):
self._tasks.remove(done_task)

def add_task(self, f):
if not self._running:
raise RuntimeError("TasksPool is already shutdown")
t = asyncio.create_task(f)
self._tasks.append(t)
self._tasks.add(t)
t.add_done_callback(self._cleanup_task)

async def shutdown(self, force: bool = False):
"""Wait for them to finish.

:param force: If True, cancel all tasks immediately.
"""
self._running = False
if force:
for t in self._tasks:
t.cancel()

results = await asyncio.gather(
*self._tasks,
return_exceptions=True,
)
for result in results:
if isinstance(result, Exception):
logger.exception(
"Error on task during shutdown of TasksPool: {result}",
result=result,
)


async def repeated_call(
func: Coroutine,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def queue_url(
self,
url: str,
callback: Coroutine,
config: Union[FetcherConfig, dict] = None,
config: Union[FetcherConfig, dict, None] = None,
fetcher="HttpFetchProvider",
) -> FetchEvent:
"""Simplified default fetching handler for queuing a fetch task.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import asyncio
from contextlib import asynccontextmanager
from typing import Set


class HierarchicalLock:
    """A hierarchical lock for asyncio.

    - If a path is locked, no ancestor or descendant path can be locked.
    - Conversely, if a child path is locked, the parent path cannot be locked
      until all child paths are released.

    Paths are compared segment-wise using "/" as the separator, so sibling
    paths that merely share a string prefix (e.g. "/data/users" and
    "/data/userscore") do NOT conflict with each other.
    """

    # Path segment separator used for ancestor/descendant comparison.
    _SEPARATOR = "/"

    def __init__(self):
        # Set of currently locked string paths.
        self._locked_paths: Set[str] = set()
        # Map of tasks to their acquired paths, used for re-entrancy and
        # self-deadlock detection.
        self._task_locks: dict[asyncio.Task, Set[str]] = {}
        # Internal lock synchronizing access to the state above.
        self._lock = asyncio.Lock()
        # Condition (sharing self._lock) used to wake waiters on release.
        self._cond = asyncio.Condition(self._lock)

    @staticmethod
    def _is_conflicting(p1: str, p2: str) -> bool:
        """Return True if the two paths are equal or ancestor/descendant.

        Segment-aware: "/a" conflicts with "/a/b" and "/a", but NOT with
        "/ab" — a plain prefix test would wrongly serialize such siblings.
        """
        if p1 == p2:
            return True
        sep = HierarchicalLock._SEPARATOR
        return p1.startswith(p2.rstrip(sep) + sep) or p2.startswith(
            p1.rstrip(sep) + sep
        )

    async def acquire(self, path: str):
        """Acquire the lock for the given hierarchical path.

        Waits until no *other* task holds a conflicting (equal, ancestor,
        or descendant) path.

        Raises:
            RuntimeError: if called outside a task, or if the current task
                already holds a conflicting path — waiting in that case
                could never succeed, so we fail fast instead of
                deadlocking forever.
        """
        task = asyncio.current_task()
        if task is None:
            raise RuntimeError("acquire() must be called from within a task.")

        async with self._lock:
            # Fail fast on re-entrant/self-conflicting acquisition: the
            # conflict is with a path *we* hold, so no release by another
            # task could ever unblock us.
            for held_path in self._task_locks.get(task, set()):
                if self._is_conflicting(path, held_path):
                    raise RuntimeError(
                        f"Task {task} cannot acquire '{path}' while holding "
                        f"conflicting path '{held_path}'."
                    )

            # Wait until no conflicting path is held by any other task.
            while any(self._is_conflicting(path, lp) for lp in self._locked_paths):
                await self._cond.wait()

            # Record ownership of the path.
            self._locked_paths.add(path)
            self._task_locks.setdefault(task, set()).add(path)

    async def release(self, path: str):
        """Release the lock for the given path and notify waiting tasks.

        Raises:
            RuntimeError: if called outside a task, if the path is not
                locked at all, or if the current task does not hold it.
        """
        task = asyncio.current_task()
        if task is None:
            raise RuntimeError("release() must be called from within a task.")

        async with self._lock:
            if path not in self._locked_paths:
                raise RuntimeError(f"Cannot release path '{path}' that is not locked.")

            if path not in self._task_locks.get(task, set()):
                raise RuntimeError(
                    f"Task {task} cannot release lock on '{path}' it does not hold."
                )

            # Remove the path from locked paths and task locks.
            self._locked_paths.remove(path)
            self._task_locks[task].remove(path)
            if not self._task_locks[task]:
                del self._task_locks[task]

            # Wake all waiters; any whose conflict just cleared will proceed.
            self._cond.notify_all()

    @asynccontextmanager
    async def lock(self, path: str) -> "HierarchicalLock":
        """Async context manager: acquire *path* on entry, release on exit."""
        await self.acquire(path)
        try:
            yield self
        finally:
            await self.release(path)
Loading
Loading