Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Structure/Task Context Improvements #1259

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Request/response debug logging to all Prompt Drivers.
- `BaseEventListener.flush_events()` to flush events from an Event Listener.
- Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing.
- `BaseTask.task_outputs` to get a dictionary of all task outputs. This has been added to `Workflow.context` and `Pipeline.context`.

### Changed

- **BREAKING**: `BaseEventListener.publish_event` `flush` argument. Use `BaseEventListener.flush_events()` instead.
- `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
- `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`.
- `Pipeline.context["parent_output"]` has changed type from `str | None` to `BaseArtifact | None`.
- `_DefaultsConfig.logging_config` and `Defaults.drivers_config` are now lazily instantiated.
- `BaseTask.add_parent`/`BaseTask.add_child` now only add the parent/child task to the structure if it is not already present.
- `BaseEventListener.flush_events()` to flush events from an Event Listener.
- `BaseEventListener` no longer requires a thread lock for batching events.
- Updated `TookitTask` system prompt to retry/fix actions when using native tool calling.
- Updated `ToolkitTask` system prompt to retry/fix actions when using native tool calling.

### Fixed

Expand Down
7 changes: 4 additions & 3 deletions docs/griptape-framework/structures/pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ You can access the final output of the Pipeline by using the [output](../../refe

Pipelines have access to the following [context](../../reference/griptape/structures/pipeline.md#griptape.structures.pipeline.Pipeline.context) variables in addition to the [base context](./tasks.md#context).

- `parent_output`: output from the parent.
- `parent`: parent task.
- `child`: child task.
- `task_outputs`: dictionary containing mapping of all task IDs to their outputs.
- `parent_output`: output from the parent task if one exists, otherwise `None`.
- `parent`: parent task if one exists, otherwise `None`.
- `child`: child task if one exists, otherwise `None`.

## Pipeline

Expand Down
7 changes: 4 additions & 3 deletions docs/griptape-framework/structures/workflows.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ You can access the final output of the Workflow by using the [output](../../refe

Workflows have access to the following [context](../../reference/griptape/structures/workflow.md#griptape.structures.workflow.Workflow.context) variables in addition to the [base context](./tasks.md#context):

- `parent_outputs`: dictionary containing mapping of parent IDs to their outputs.
- `task_outputs`: dictionary containing mapping of all task IDs to their outputs.
- `parent_outputs`: dictionary containing mapping of parent task IDs to their outputs.
- `parents_output_text`: string containing the concatenated outputs of all parent tasks.
- `parents`: parent tasks referenceable by IDs.
- `children`: child tasks referenceable by IDs.
- `parents`: dictionary containing mapping of parent task IDs to their task objects.
- `children`: dictionary containing mapping of child task IDs to their task objects.

## Workflow

Expand Down
3 changes: 2 additions & 1 deletion griptape/structures/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def context(self, task: BaseTask) -> dict[str, Any]:

context.update(
{
"parent_output": task.parents[0].output.to_text() if task.parents and task.parents[0].output else None,
"parent_output": task.parents[0].output if task.parents else None,
"task_outputs": self.task_outputs,
"parent": task.parents[0] if task.parents else None,
"child": task.children[0] if task.children else None,
},
Expand Down
4 changes: 4 additions & 0 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def output(self) -> BaseArtifact:
raise ValueError("Structure's output Task has no output. Run the Structure to generate output.")
return self.output_task.output

@property
def task_outputs(self) -> dict[str, Optional[BaseArtifact]]:
return {task.id: task.output for task in self.tasks}

@property
def finished_tasks(self) -> list[BaseTask]:
return [s for s in self.tasks if s.is_finished()]
Expand Down
1 change: 1 addition & 0 deletions griptape/structures/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def context(self, task: BaseTask) -> dict[str, Any]:

context.update(
{
"task_outputs": self.task_outputs,
"parent_outputs": task.parent_outputs,
"parents_output_text": task.parents_output_text,
"parents": {parent.id: parent for parent in task.parents},
Expand Down
4 changes: 2 additions & 2 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ def children(self) -> list[BaseTask]:
raise ValueError("Structure must be set to access children")

@property
def parent_outputs(self) -> dict[str, str]:
return {parent.id: parent.output.to_text() if parent.output else "" for parent in self.parents}
def parent_outputs(self) -> dict[str, BaseArtifact]:
return {parent.id: parent.output for parent in self.parents if parent.output}

@property
def parents_output_text(self) -> str:
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/structures/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,16 @@ def finished_tasks(self):
def test_fail_fast(self):
with pytest.raises(ValueError):
Agent(prompt_driver=MockPromptDriver(), fail_fast=True)

def test_task_outputs(self):
task = PromptTask("test prompt")
agent = Agent(prompt_driver=MockPromptDriver())

agent.add_task(task)

assert len(agent.task_outputs) == 1
assert agent.task_outputs[task.id] is None
agent.run("hello")

assert len(agent.task_outputs) == 1
assert agent.task_outputs[task.id] == task.output
15 changes: 14 additions & 1 deletion tests/unit/structures/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,8 @@ def test_context(self):

context = pipeline.context(task)

assert context["parent_output"] == parent.output.to_text()
assert context["parent_output"] == parent.output
assert context["task_outputs"] == pipeline.task_outputs
assert context["structure"] == pipeline
assert context["parent"] == parent
assert context["child"] == child
Expand Down Expand Up @@ -398,3 +399,15 @@ def test_add_duplicate_task_directly(self):

with pytest.raises(ValueError, match=f"Duplicate task with id {task.id} found."):
pipeline.run()

def test_task_outputs(self):
task = PromptTask("test")
pipeline = Pipeline()

pipeline + [task]

assert len(pipeline.task_outputs) == 1
assert pipeline.task_outputs[task.id] is None
pipeline.run()
assert len(pipeline.task_outputs) == 1
assert pipeline.task_outputs[task.id] == task.output
17 changes: 15 additions & 2 deletions tests/unit/structures/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,13 +737,14 @@ def test_context(self):

context = workflow.context(task)

assert context["parent_outputs"] == {parent.id: ""}
assert context["parent_outputs"] == {}

workflow.run()

context = workflow.context(task)

assert context["parent_outputs"] == {parent.id: parent.output.to_text()}
assert context["task_outputs"] == workflow.task_outputs
assert context["parent_outputs"] == {parent.id: parent.output}
assert context["parents_output_text"] == "mock output"
assert context["structure"] == workflow
assert context["parents"] == {parent.id: parent}
Expand Down Expand Up @@ -966,3 +967,15 @@ def _validate_topology_4(workflow) -> None:
publish_website = workflow.find_task("publish_website")
assert publish_website.parent_ids == ["compare_movies"]
assert publish_website.child_ids == ["summarize_to_slack"]

def test_task_outputs(self):
task = PromptTask("test")
workflow = Workflow(tasks=[task])

assert len(workflow.task_outputs) == 1
assert workflow.task_outputs[task.id] is None

workflow.run()

assert len(workflow.task_outputs) == 1
assert workflow.task_outputs[task.id].value == "mock output"
5 changes: 2 additions & 3 deletions tests/unit/tasks/test_base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ def test_parent_outputs(self, task):

parent_3.output = None
assert child.parent_outputs == {
parent_1.id: parent_1.output.to_text(),
parent_2.id: parent_2.output.to_text(),
parent_3.id: "",
parent_1.id: parent_1.output,
parent_2.id: parent_2.output,
}

def test_parents_output(self, task):
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/tasks/test_base_text_input_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_full_context(self):
context = subtask.full_context

assert context["foo"] == "bar"
assert context["parent_output"] == parent.output.to_text()
assert context["parent_output"] == parent.output
assert context["structure"] == pipeline
assert context["parent"] == parent
assert context["child"] == child
Expand Down
Loading