diff --git a/agents-api/agents_api/common/interceptors.py b/agents-api/agents_api/common/interceptors.py index e600639e5..66c69a26e 100644 --- a/agents-api/agents_api/common/interceptors.py +++ b/agents-api/agents_api/common/interceptors.py @@ -13,7 +13,7 @@ from temporalio import workflow from temporalio.activity import _CompleteAsyncError as CompleteAsyncError -from temporalio.exceptions import ApplicationError, FailureError, TemporalError +from temporalio.exceptions import ActivityError, ApplicationError, FailureError, TemporalError from temporalio.service import RPCError from temporalio.worker import ( ActivityInboundInterceptor, @@ -241,6 +241,8 @@ def handle_execution_with_errors_sync[I, T]( except PASSTHROUGH_EXCEPTIONS: raise except BaseException as e: + while isinstance(e, ActivityError) and getattr(e, "__cause__", None): + e = e.__cause__ if not is_retryable_error(e): raise ApplicationError( str(e), diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index fdb13877e..61a121aa4 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -3,6 +3,8 @@ from temporalio import workflow from temporalio.exceptions import ApplicationError +from ..utils.workflows import get_workflow_name + with workflow.unsafe.imports_passed_through(): from pydantic import BaseModel, Field, computed_field from pydantic_partial import create_partial_model @@ -225,9 +227,13 @@ async def get_inputs(self) -> tuple[list[Any], list[str | None]]: limit=1000, direction="asc", ) # type: ignore[not-callable] + assert len(transitions) > 0, "No transitions found" inputs = [] labels = [] + workflow = get_workflow_name(transitions[-1]) + transitions = [t for t in transitions if get_workflow_name(t) == workflow] for transition in transitions: + # NOTE: The length hack should be refactored in case we want to implement multi-step control steps if transition.next and transition.next.step >= len(inputs): inputs.append(transition.output) labels.append(transition.step_label) diff --git a/agents-api/agents_api/common/utils/workflows.py b/agents-api/agents_api/common/utils/workflows.py new file mode 100644 index 000000000..3e081c643 --- /dev/null +++ b/agents-api/agents_api/common/utils/workflows.py @@ -0,0 +1,23 @@ +from ...autogen.openapi_model import Transition + +PAR_PREFIX = "PAR:" +SEPARATOR = "`" + + +def get_workflow_name(transition: Transition) -> str: + workflow_str = transition.current.workflow + if workflow_str.startswith(PAR_PREFIX): + # Extract between PAR:` and first ` after "workflow" + start_index = len(PAR_PREFIX) + len(SEPARATOR) + assert len(workflow_str) > start_index and SEPARATOR in workflow_str[start_index:], ( + "Workflow string is too short or missing backtick" + ) + workflow_str = workflow_str[start_index:].split(SEPARATOR)[0] + elif workflow_str.startswith(SEPARATOR): + # Extract between backticks + start_index = len(SEPARATOR) + assert len(workflow_str) > start_index and SEPARATOR in workflow_str[start_index:], ( + "Workflow string is too short or missing backtick" + ) + workflow_str = workflow_str[start_index:].split(SEPARATOR)[0] + return workflow_str diff --git a/agents-api/agents_api/workflows/task_execution/helpers.py b/agents-api/agents_api/workflows/task_execution/helpers.py index 10b7898f6..b840d164a 100644 --- a/agents-api/agents_api/workflows/task_execution/helpers.py +++ b/agents-api/agents_api/workflows/task_execution/helpers.py @@ -20,6 +20,7 @@ ExecutionInput, StepContext, ) + from ...common.utils.workflows import PAR_PREFIX, SEPARATOR from ...env import task_max_parallelism, temporal_heartbeat_timeout T = TypeVar("T") @@ -100,7 +101,9 @@ async def execute_switch_branch( workflow.logger.info(f"Switch step: Chose branch {index}") chosen_branch = switch[index] - case_wf_name = f"`{context.cursor.workflow}`[{context.cursor.step}].case" + seprated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR + + case_wf_name = f"{seprated_workflow_name}[{context.cursor.step}].case" case_task = task.model_copy() case_task.workflows = [ @@ -138,7 +141,9 @@ async def execute_if_else_branch( if chosen_branch is None: chosen_branch = EvaluateStep(evaluate={"output": "_"}) - if_else_wf_name = f"`{context.cursor.workflow}`[{context.cursor.step}].if_else" + separated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR + + if_else_wf_name = f"{separated_workflow_name}[{context.cursor.step}].if_else" if_else_wf_name += ".then" if condition else ".else" if_else_task = task.model_copy() @@ -174,7 +179,9 @@ async def execute_foreach_step( results = [] for i, item in enumerate(items): - foreach_wf_name = f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]" + separated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR + + foreach_wf_name = f"{separated_workflow_name}[{context.cursor.step}].foreach[{i}]" foreach_task = task.model_copy() foreach_task.workflows = [ Workflow(name=foreach_wf_name, steps=[do_step]), @@ -213,7 +220,9 @@ async def execute_map_reduce_step( reduce = "$ results + [_]" if reduce is None else reduce for i, item in enumerate(items): - workflow_name = f"`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}]" + separated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR + + workflow_name = f"{separated_workflow_name}[{context.cursor.step}].mapreduce[{i}]" map_reduce_task = task.model_copy() map_reduce_task.workflows = [ Workflow(name=workflow_name, steps=[map_defn]), @@ -281,9 +290,9 @@ async def execute_map_reduce_step_parallel( for j, item in enumerate(batch): # Parallel batch workflow name # Note: Added PAR: prefix to easily identify parallel batches in logs - workflow_name = ( - f"PAR:`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}][{j}]" - ) + separated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR + + workflow_name = f"{PAR_PREFIX}{separated_workflow_name}[{context.cursor.step}].mapreduce[{i}][{j}]" map_reduce_task = task.model_copy() map_reduce_task.workflows = [ Workflow(name=workflow_name, steps=[map_defn]), diff --git a/agents-api/agents_api/workflows/task_execution/transition.py b/agents-api/agents_api/workflows/task_execution/transition.py index 8dea5737e..5bd93210d 100644 --- a/agents-api/agents_api/workflows/task_execution/transition.py +++ b/agents-api/agents_api/workflows/task_execution/transition.py @@ -1,7 +1,7 @@ from datetime import timedelta from temporalio import workflow -from temporalio.exceptions import ApplicationError +from temporalio.exceptions import ActivityError, ApplicationError with workflow.unsafe.imports_passed_through(): from ...activities import task_steps @@ -61,6 +61,8 @@ async def transition( return transition_request except Exception as e: + while isinstance(e, ActivityError) and getattr(e, "__cause__", None): + e = e.__cause__ workflow.logger.error(f"Error in transition: {e!s}") msg = f"Error in transition: {e}" raise ApplicationError(msg) from e diff --git a/agents-api/tests/test_prepare_for_step.py b/agents-api/tests/test_prepare_for_step.py index 1a2fc7a67..f61cf0546 100644 --- a/agents-api/tests/test_prepare_for_step.py +++ b/agents-api/tests/test_prepare_for_step.py @@ -5,6 +5,7 @@ Agent, TaskSpecDef, ToolCallStep, + Transition, TransitionTarget, Workflow, ) @@ -13,7 +14,8 @@ StepContext, ) from agents_api.common.utils.datetime import utcnow -from ward import test +from agents_api.common.utils.workflows import get_workflow_name +from ward import raises, test @test("utility: prepare_for_step - underscore") @@ -86,3 +88,75 @@ async def _(): assert result["steps"]["first step"]["output"] == {"y": "2"} assert result["steps"]["second step"]["input"] == {"y": "2"} assert result["steps"]["second step"]["output"] == {"z": "3"} + + +@test("utility: get_workflow_name") +async def _(): + transition = Transition( + id=uuid.uuid4(), + execution_id=uuid.uuid4(), + output=None, + created_at=utcnow(), + updated_at=utcnow(), + type="step", + current=TransitionTarget(workflow="main", step=0), + next=TransitionTarget(workflow="main", step=1), + ) + + transition.current = TransitionTarget(workflow="main", step=0) + transition.next = TransitionTarget(workflow="main", step=1) + assert get_workflow_name(transition) == "main" + + transition.current = TransitionTarget(workflow="`main`[0].if_else.then", step=0) + transition.next = None + assert get_workflow_name(transition) == "main" + + transition.current = TransitionTarget(workflow="subworkflow", step=0) + transition.next = TransitionTarget(workflow="subworkflow", step=1) + assert get_workflow_name(transition) == "subworkflow" + + transition.current = TransitionTarget(workflow="`subworkflow`[0].if_else.then", step=0) + transition.next = TransitionTarget(workflow="`subworkflow`[0].if_else.else", step=1) + assert get_workflow_name(transition) == "subworkflow" + + transition.current = TransitionTarget(workflow="PAR:`main`[2].mapreduce[0][2],0", step=0) + transition.next = None + assert get_workflow_name(transition) == "main" + + transition.current = TransitionTarget( + workflow="PAR:`subworkflow`[2].mapreduce[0][3],0", step=0 + ) + transition.next = None + assert get_workflow_name(transition) == "subworkflow" + + +@test("utility: get_workflow_name - raises") +async def _(): + transition = Transition( + id=uuid.uuid4(), + execution_id=uuid.uuid4(), + output=None, + created_at=utcnow(), + updated_at=utcnow(), + type="step", + current=TransitionTarget(workflow="main", step=0), + next=TransitionTarget(workflow="main", step=1), + ) + + with raises(AssertionError): + transition.current = TransitionTarget(workflow="`main[2].mapreduce[0][2],0", step=0) + get_workflow_name(transition) + + with raises(AssertionError): + transition.current = TransitionTarget(workflow="PAR:`", step=0) + get_workflow_name(transition) + + with raises(AssertionError): + transition.current = TransitionTarget(workflow="`", step=0) + get_workflow_name(transition) + + with raises(AssertionError): + transition.current = TransitionTarget( + workflow="PAR:`subworkflow[2].mapreduce[0][3],0", step=0 + ) + get_workflow_name(transition)