Skip to content

Commit

Permalink
Revert "perf(engine): Improved DAG scheduling algorithm (#827)"
Browse files Browse the repository at this point in the history
This reverts commit 2fd1036.
  • Loading branch information
daryllimyt committed Feb 6, 2025
1 parent 5372aee commit 458cdb2
Showing 1 changed file with 8 additions and 31 deletions.
39 changes: 8 additions & 31 deletions tracecat/dsl/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Coroutine
from typing import Any

from temporalio import workflow
from temporalio.exceptions import ApplicationError

from tracecat.contexts import ctx_logger
Expand Down Expand Up @@ -221,34 +220,23 @@ async def start(self) -> dict[str, TaskExceptionInfo] | None:
for task_ref, indegree in self.indegrees.items():
if indegree == 0:
self.queue.put_nowait(task_ref)

pending_tasks: set[asyncio.Task[None]] = set()

while not self.task_exceptions and (
not self.queue.empty()
or len(self.completed_tasks) < len(self.tasks)
or pending_tasks
not self.queue.empty() or len(self.completed_tasks) < len(self.tasks)
):
self.logger.debug(
"Waiting for tasks.",
qsize=self.queue.qsize(),
n_visited=len(self.completed_tasks),
n_tasks=len(self.tasks),
n_pending=len(pending_tasks),
)
try:
task_ref = await asyncio.wait_for(
self.queue.get(), timeout=self._queue_wait_timeout
)
except TimeoutError:
continue

# Clean up completed tasks
done_tasks = {t for t in pending_tasks if t.done()}
pending_tasks.difference_update(done_tasks)

if not self.queue.empty():
task_ref = await self.queue.get()
task = asyncio.create_task(self._schedule_task(task_ref))
pending_tasks.add(task)
elif pending_tasks:
# Wait for at least one pending task to complete
await workflow.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED)

asyncio.create_task(self._schedule_task(task_ref))
if self.task_exceptions:
self.logger.warning(
"DSLScheduler got task exceptions, stopping...",
Expand All @@ -257,17 +245,6 @@ async def start(self) -> dict[str, TaskExceptionInfo] | None:
n_visited=len(self.completed_tasks),
n_tasks=len(self.tasks),
)
# Cancel all pending tasks and wait for them to complete
for task in pending_tasks:
if not task.done():
task.cancel()

if pending_tasks:
try:
await asyncio.gather(*pending_tasks, return_exceptions=True)
except Exception as e:
self.logger.warning("Error while canceling tasks", error=e)

return self.task_exceptions
self.logger.info(
"All tasks completed",
Expand Down

0 comments on commit 458cdb2

Please sign in to comment.