From 458cdb2d896ad9cbe205db0c9ddd4c9b999177be Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Thu, 6 Feb 2025 01:11:30 +0000 Subject: [PATCH] Revert "perf(engine): Improved DAG scheduling algorithm (#827)" This reverts commit 2fd1036b16263ccd19e9405da39283c3fc02cfb2. --- tracecat/dsl/scheduler.py | 39 ++++++++------------------------------- 1 file changed, 8 insertions(+), 31 deletions(-) diff --git a/tracecat/dsl/scheduler.py b/tracecat/dsl/scheduler.py index f3e3ad366..9bd275f0f 100644 --- a/tracecat/dsl/scheduler.py +++ b/tracecat/dsl/scheduler.py @@ -3,7 +3,6 @@ from collections.abc import Coroutine from typing import Any -from temporalio import workflow from temporalio.exceptions import ApplicationError from tracecat.contexts import ctx_logger @@ -221,34 +220,23 @@ async def start(self) -> dict[str, TaskExceptionInfo] | None: for task_ref, indegree in self.indegrees.items(): if indegree == 0: self.queue.put_nowait(task_ref) - - pending_tasks: set[asyncio.Task[None]] = set() - while not self.task_exceptions and ( - not self.queue.empty() - or len(self.completed_tasks) < len(self.tasks) - or pending_tasks + not self.queue.empty() or len(self.completed_tasks) < len(self.tasks) ): self.logger.debug( "Waiting for tasks.", qsize=self.queue.qsize(), n_visited=len(self.completed_tasks), n_tasks=len(self.tasks), - n_pending=len(pending_tasks), ) + try: + task_ref = await asyncio.wait_for( + self.queue.get(), timeout=self._queue_wait_timeout + ) + except TimeoutError: + continue - # Clean up completed tasks - done_tasks = {t for t in pending_tasks if t.done()} - pending_tasks.difference_update(done_tasks) - - if not self.queue.empty(): - task_ref = await self.queue.get() - task = asyncio.create_task(self._schedule_task(task_ref)) - pending_tasks.add(task) - elif pending_tasks: - # Wait for at least one pending task to complete - await workflow.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED) - + asyncio.create_task(self._schedule_task(task_ref)) if self.task_exceptions: self.logger.warning( "DSLScheduler got task exceptions, stopping...", @@ -257,17 +245,6 @@ async def start(self) -> dict[str, TaskExceptionInfo] | None: n_visited=len(self.completed_tasks), n_tasks=len(self.tasks), ) - # Cancel all pending tasks and wait for them to complete - for task in pending_tasks: - if not task.done(): - task.cancel() - - if pending_tasks: - try: - await asyncio.gather(*pending_tasks, return_exceptions=True) - except Exception as e: - self.logger.warning("Error while canceling tasks", error=e) - return self.task_exceptions self.logger.info( "All tasks completed",