Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(agent-api): fix the list of steps accessable by a subworkflow #1131

Merged
merged 13 commits into from
Feb 7, 2025
Merged
4 changes: 3 additions & 1 deletion agents-api/agents_api/common/interceptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from temporalio import workflow
from temporalio.activity import _CompleteAsyncError as CompleteAsyncError
from temporalio.exceptions import ApplicationError, FailureError, TemporalError
from temporalio.exceptions import ActivityError, ApplicationError, FailureError, TemporalError
from temporalio.service import RPCError
from temporalio.worker import (
ActivityInboundInterceptor,
Expand Down Expand Up @@ -241,6 +241,8 @@ def handle_execution_with_errors_sync[I, T](
except PASSTHROUGH_EXCEPTIONS:
raise
except BaseException as e:
while isinstance(e, ActivityError) and getattr(e, "__cause__", None):
e = e.__cause__
if not is_retryable_error(e):
raise ApplicationError(
str(e),
Expand Down
6 changes: 6 additions & 0 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from temporalio import workflow
from temporalio.exceptions import ApplicationError

from ..utils.workflows import get_workflow_name

with workflow.unsafe.imports_passed_through():
from pydantic import BaseModel, Field, computed_field
from pydantic_partial import create_partial_model
Expand Down Expand Up @@ -225,9 +227,13 @@ async def get_inputs(self) -> tuple[list[Any], list[str | None]]:
limit=1000,
direction="asc",
) # type: ignore[not-callable]
assert len(transitions) > 0, "No transitions found"
Ahmad-mtos marked this conversation as resolved.
Show resolved Hide resolved
inputs = []
labels = []
workflow = get_workflow_name(transitions[-1])
transitions = [t for t in transitions if get_workflow_name(t) == workflow]
for transition in transitions:
# NOTE: The length hack should be refactored in case we want to implement multi-step control steps
if transition.next and transition.next.step >= len(inputs):
inputs.append(transition.output)
labels.append(transition.step_label)
Expand Down
22 changes: 22 additions & 0 deletions agents-api/agents_api/common/utils/workflows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from ...autogen.openapi_model import Transition

PAR_PREFIX = "PAR:"
SEPARATOR = "`"

def get_workflow_name(transition: Transition) -> str:
workflow_str = transition.current.workflow
if workflow_str.startswith(PAR_PREFIX):
# Extract between PAR:` and first ` after "workflow"
start_index = len(PAR_PREFIX) + len(SEPARATOR)
assert len(workflow_str) > start_index and SEPARATOR in workflow_str[start_index:], (
"Workflow string is too short or missing backtick"
)
workflow_str = workflow_str[start_index:].split(SEPARATOR)[0]
elif workflow_str.startswith(SEPARATOR):
# Extract between backticks
start_index = len(SEPARATOR)
assert len(workflow_str) > start_index and SEPARATOR in workflow_str[start_index:], (
"Workflow string is too short or missing backtick"
)
workflow_str = workflow_str[start_index:].split(SEPARATOR)[0]
return workflow_str
23 changes: 16 additions & 7 deletions agents-api/agents_api/workflows/task_execution/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
StepContext,
)
from ...env import task_max_parallelism, temporal_heartbeat_timeout
from ...common.utils.workflows import PAR_PREFIX, SEPARATOR

T = TypeVar("T")

Expand Down Expand Up @@ -100,7 +101,9 @@ async def execute_switch_branch(
workflow.logger.info(f"Switch step: Chose branch {index}")
chosen_branch = switch[index]

case_wf_name = f"`{context.cursor.workflow}`[{context.cursor.step}].case"
seprated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR
Ahmad-mtos marked this conversation as resolved.
Show resolved Hide resolved

case_wf_name = f"{seprated_workflow_name}[{context.cursor.step}].case"

case_task = task.model_copy()
case_task.workflows = [
Expand Down Expand Up @@ -138,7 +141,9 @@ async def execute_if_else_branch(
if chosen_branch is None:
chosen_branch = EvaluateStep(evaluate={"output": "_"})

if_else_wf_name = f"`{context.cursor.workflow}`[{context.cursor.step}].if_else"
seprated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR

if_else_wf_name = f"{seprated_workflow_name}[{context.cursor.step}].if_else"
if_else_wf_name += ".then" if condition else ".else"

if_else_task = task.model_copy()
Expand Down Expand Up @@ -174,7 +179,9 @@ async def execute_foreach_step(
results = []

for i, item in enumerate(items):
foreach_wf_name = f"`{context.cursor.workflow}`[{context.cursor.step}].foreach[{i}]"
seprated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR

foreach_wf_name = f"{seprated_workflow_name}[{context.cursor.step}].foreach[{i}]"
foreach_task = task.model_copy()
foreach_task.workflows = [
Workflow(name=foreach_wf_name, steps=[do_step]),
Expand Down Expand Up @@ -213,7 +220,9 @@ async def execute_map_reduce_step(
reduce = "$ results + [_]" if reduce is None else reduce

for i, item in enumerate(items):
workflow_name = f"`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}]"
seprated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR

workflow_name = f"{seprated_workflow_name}[{context.cursor.step}].mapreduce[{i}]"
map_reduce_task = task.model_copy()
map_reduce_task.workflows = [
Workflow(name=workflow_name, steps=[map_defn]),
Expand Down Expand Up @@ -281,9 +290,9 @@ async def execute_map_reduce_step_parallel(
for j, item in enumerate(batch):
# Parallel batch workflow name
# Note: Added PAR: prefix to easily identify parallel batches in logs
workflow_name = (
f"PAR:`{context.cursor.workflow}`[{context.cursor.step}].mapreduce[{i}][{j}]"
)
seprated_workflow_name = SEPARATOR + context.cursor.workflow + SEPARATOR

workflow_name = f"{PAR_PREFIX}{seprated_workflow_name}[{context.cursor.step}].mapreduce[{i}][{j}]"
map_reduce_task = task.model_copy()
map_reduce_task.workflows = [
Workflow(name=workflow_name, steps=[map_defn]),
Expand Down
4 changes: 3 additions & 1 deletion agents-api/agents_api/workflows/task_execution/transition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import timedelta

from temporalio import workflow
from temporalio.exceptions import ApplicationError
from temporalio.exceptions import ActivityError, ApplicationError

with workflow.unsafe.imports_passed_through():
from ...activities import task_steps
Expand Down Expand Up @@ -61,6 +61,8 @@ async def transition(
return transition_request

except Exception as e:
while isinstance(e, ActivityError) and getattr(e, "__cause__", None):
e = e.__cause__
workflow.logger.error(f"Error in transition: {e!s}")
msg = f"Error in transition: {e}"
raise ApplicationError(msg) from e
76 changes: 75 additions & 1 deletion agents-api/tests/test_prepare_for_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Agent,
TaskSpecDef,
ToolCallStep,
Transition,
TransitionTarget,
Workflow,
)
Expand All @@ -13,7 +14,8 @@
StepContext,
)
from agents_api.common.utils.datetime import utcnow
from ward import test
from agents_api.common.utils.workflows import get_workflow_name
from ward import raises, test


@test("utility: prepare_for_step - underscore")
Expand Down Expand Up @@ -86,3 +88,75 @@ async def _():
assert result["steps"]["first step"]["output"] == {"y": "2"}
assert result["steps"]["second step"]["input"] == {"y": "2"}
assert result["steps"]["second step"]["output"] == {"z": "3"}


@test("utility: get_workflow_name")
async def _():
transition = Transition(
id=uuid.uuid4(),
execution_id=uuid.uuid4(),
output=None,
created_at=utcnow(),
updated_at=utcnow(),
type="step",
current=TransitionTarget(workflow="main", step=0),
next=TransitionTarget(workflow="main", step=1),
)

transition.current = TransitionTarget(workflow="main", step=0)
Ahmad-mtos marked this conversation as resolved.
Show resolved Hide resolved
transition.next = TransitionTarget(workflow="main", step=1)
assert get_workflow_name(transition) == "main"

transition.current = TransitionTarget(workflow="`main`[0].if_else.then", step=0)
transition.next = None
assert get_workflow_name(transition) == "main"

transition.current = TransitionTarget(workflow="subworkflow", step=0)
transition.next = TransitionTarget(workflow="subworkflow", step=1)
assert get_workflow_name(transition) == "subworkflow"

transition.current = TransitionTarget(workflow="`subworkflow`[0].if_else.then", step=0)
transition.next = TransitionTarget(workflow="`subworkflow`[0].if_else.else", step=1)
assert get_workflow_name(transition) == "subworkflow"

transition.current = TransitionTarget(workflow="PAR:`main`[2].mapreduce[0][2],0", step=0)
transition.next = None
assert get_workflow_name(transition) == "main"

transition.current = TransitionTarget(
workflow="PAR:`subworkflow`[2].mapreduce[0][3],0", step=0
)
transition.next = None
assert get_workflow_name(transition) == "subworkflow"


@test("utility: get_workflow_name - raises")
async def _():
transition = Transition(
id=uuid.uuid4(),
execution_id=uuid.uuid4(),
output=None,
created_at=utcnow(),
updated_at=utcnow(),
type="step",
current=TransitionTarget(workflow="main", step=0),
next=TransitionTarget(workflow="main", step=1),
)

with raises(AssertionError):
transition.current = TransitionTarget(workflow="`main[2].mapreduce[0][2],0", step=0)
get_workflow_name(transition)

with raises(AssertionError):
transition.current = TransitionTarget(workflow="PAR:`", step=0)
get_workflow_name(transition)

with raises(AssertionError):
transition.current = TransitionTarget(workflow="`", step=0)
get_workflow_name(transition)

with raises(AssertionError):
transition.current = TransitionTarget(
workflow="PAR:`subworkflow[2].mapreduce[0][3],0", step=0
)
get_workflow_name(transition)