From e284f36686ac3f228d84af8e5719ce832be3900d Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 11 Dec 2024 16:29:07 -0800 Subject: [PATCH] Change conversation_memory_strategy from enum to string --- CHANGELOG.md | 2 +- .../structures/conversation-memory.md | 2 +- .../structures/src/conversation_memory_per_task.py | 2 +- .../conversation_memory_per_task_with_disabled.py | 2 +- griptape/schemas/base_schema.py | 1 - griptape/structures/structure.py | 13 ++++--------- griptape/tasks/prompt_task.py | 2 +- tests/unit/structures/test_agent.py | 2 +- tests/unit/structures/test_pipeline.py | 2 +- tests/unit/structures/test_structure.py | 8 ++++---- tests/unit/structures/test_workflow.py | 2 +- 11 files changed, 16 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c213088f5..4e944ef29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `PromptTask.conversation_memory` for setting the Conversation Memory on a Prompt Task. -- `Structure.conversation_memory_strategy` for setting whether Conversation Memory Runs should be created on a per-Structure or per-Task basis. Default is `Structure.ConversationMemoryStrategy.PER_STRUCTURE`. +- `Structure.conversation_memory_strategy` for setting whether Conversation Memory Runs should be created on a per-Structure or per-Task basis. Default is `per_structure`. ## [1.0.0] - 2024-12-09 diff --git a/docs/griptape-framework/structures/conversation-memory.md b/docs/griptape-framework/structures/conversation-memory.md index 6e1b393d6..98f316b9b 100644 --- a/docs/griptape-framework/structures/conversation-memory.md +++ b/docs/griptape-framework/structures/conversation-memory.md @@ -46,7 +46,7 @@ In this example, the `improve` Task is "forgotten" after the Structure's run is #### Per Task -You can change when Conversation Memory Runs are created by modifying [Structure.conversation_memory_strategy](../../reference/griptape/structures/structure.md#griptape.structures.Structure.conversation_memory_strategy) from the default [PER_STRUCTURE](../../reference/griptape/structures/structure.md#griptape.structures.structure.ConversationMemoryStrategy.PER_STRUCTURE) to [PER_TASK](../../reference/griptape/structures/structure.md#griptape.structures.structure.ConversationMemoryStrategy.PER_TASK). +You can change when Conversation Memory Runs are created by modifying [Structure.conversation_memory_strategy](../../reference/griptape/structures/structure.md#griptape.structures.Structure.conversation_memory_strategy) from the default `per_structure` to `per_task`. ```python --8<-- "docs/griptape-framework/structures/src/conversation_memory_per_task.py" diff --git a/docs/griptape-framework/structures/src/conversation_memory_per_task.py b/docs/griptape-framework/structures/src/conversation_memory_per_task.py index 72d9c663d..3d3607418 100644 --- a/docs/griptape-framework/structures/src/conversation_memory_per_task.py +++ b/docs/griptape-framework/structures/src/conversation_memory_per_task.py @@ -2,7 +2,7 @@ from griptape.tasks import PromptTask pipeline = Pipeline( - conversation_memory_strategy=Pipeline.ConversationMemoryStrategy.PER_TASK, + conversation_memory_strategy="per_task", tasks=[ PromptTask("Respond to this request: {{ args[0] }}", id="input"), PromptTask("Improve the writing", id="improve"), diff --git a/docs/griptape-framework/structures/src/conversation_memory_per_task_with_disabled.py b/docs/griptape-framework/structures/src/conversation_memory_per_task_with_disabled.py index 4d7711f6a..7db12c422 100644 --- a/docs/griptape-framework/structures/src/conversation_memory_per_task_with_disabled.py +++ b/docs/griptape-framework/structures/src/conversation_memory_per_task_with_disabled.py @@ -2,7 +2,7 @@ from griptape.tasks import PromptTask pipeline = Pipeline( - conversation_memory_strategy=Pipeline.ConversationMemoryStrategy.PER_TASK, + conversation_memory_strategy="per_task", tasks=[ PromptTask("Respond to this request: {{ args[0] }}", id="input"), PromptTask("Improve the writing of this: {{ parent_output }}", id="improve", conversation_memory=None), diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 2b50fc125..4a049752b 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -208,7 +208,6 @@ def _resolve_types(cls, attrs_cls: type) -> None: "Sequence": Sequence, "TaskMemory": TaskMemory, "State": BaseTask.State, - "ConversationMemoryStrategy": Structure.ConversationMemoryStrategy, "BaseConversationMemory": BaseConversationMemory, "BaseArtifactStorage": BaseArtifactStorage, # Third party modules diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index eb18adde7..3b9c36317 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -2,8 +2,7 @@ import uuid from abc import ABC, abstractmethod -from enum import Enum -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union from attrs import Factory, define, field @@ -24,10 +23,6 @@ @define class Structure(RuleMixin, SerializableMixin, RunnableMixin["Structure"], ABC): - class ConversationMemoryStrategy(Enum): - PER_STRUCTURE = 1 - PER_TASK = 2 - id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) _tasks: list[Union[BaseTask, list[BaseTask]]] = field( factory=list, kw_only=True, alias="tasks", metadata={"serializable": True} @@ -37,8 +32,8 @@ class ConversationMemoryStrategy(Enum): kw_only=True, metadata={"serializable": True}, ) - conversation_memory_strategy: ConversationMemoryStrategy = field( - default=ConversationMemoryStrategy.PER_STRUCTURE, kw_only=True, metadata={"serializable": True} + conversation_memory_strategy: Literal["per_structure", "per_task"] = field( + default="per_structure", kw_only=True, metadata={"serializable": True} ) task_memory: TaskMemory = field( default=Factory(lambda self: TaskMemory(), takes_self=True), @@ -172,7 +167,7 @@ def after_run(self) -> None: if self.output_task is not None: if ( - self.conversation_memory_strategy == self.ConversationMemoryStrategy.PER_STRUCTURE + self.conversation_memory_strategy == "per_structure" and self.conversation_memory is not None and self.input_task is not None and self.output_task.output is not None diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index b73bb39b8..e670af4b8 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -107,7 +107,7 @@ def after_run(self) -> None: structure = self.structure if ( structure is not None - and structure.conversation_memory_strategy == structure.ConversationMemoryStrategy.PER_TASK + and structure.conversation_memory_strategy == "per_task" and self.conversation_memory is not None and self.output is not None ): diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index fbf401e13..d512cf465 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -279,7 +279,7 @@ def test_to_dict(self): "meta": agent.conversation_memory.meta, "max_runs": agent.conversation_memory.max_runs, }, - "conversation_memory_strategy": str(agent.conversation_memory_strategy), + "conversation_memory_strategy": agent.conversation_memory_strategy, } assert agent.to_dict() == expected_agent_dict diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index 0f2e598a3..10071decb 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -436,7 +436,7 @@ def test_to_dict(self): "meta": pipeline.conversation_memory.meta, "max_runs": pipeline.conversation_memory.max_runs, }, - "conversation_memory_strategy": str(pipeline.conversation_memory_strategy), + "conversation_memory_strategy": pipeline.conversation_memory_strategy, "fail_fast": pipeline.fail_fast, } assert pipeline.to_dict() == expected_pipeline_dict diff --git a/tests/unit/structures/test_structure.py b/tests/unit/structures/test_structure.py index 4835200f6..d7183e755 100644 --- a/tests/unit/structures/test_structure.py +++ b/tests/unit/structures/test_structure.py @@ -1,6 +1,6 @@ import pytest -from griptape.structures import Agent, Pipeline, Structure +from griptape.structures import Agent, Pipeline from griptape.tasks import PromptTask @@ -20,7 +20,7 @@ def test_output(self): def test_conversation_mode_per_structure(self): pipeline = Pipeline( - conversation_memory_strategy=Structure.ConversationMemoryStrategy.PER_STRUCTURE, + conversation_memory_strategy="per_structure", tasks=[PromptTask("1"), PromptTask("2")], ) @@ -32,7 +32,7 @@ def test_conversation_mode_per_structure(self): def test_conversation_mode_per_task(self): pipeline = Pipeline( - conversation_memory_strategy=Structure.ConversationMemoryStrategy.PER_TASK, + conversation_memory_strategy="per_task", tasks=[PromptTask("1"), PromptTask("2")], ) @@ -46,7 +46,7 @@ def test_conversation_mode_per_task(self): def test_conversation_mode_per_task_no_memory(self): pipeline = Pipeline( - conversation_memory_strategy=Structure.ConversationMemoryStrategy.PER_TASK, + conversation_memory_strategy="per_task", tasks=[PromptTask(conversation_memory=None), PromptTask("2")], ) diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 7b957d8fc..480b7df39 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -1004,7 +1004,7 @@ def test_to_dict(self): "meta": workflow.conversation_memory.meta, "max_runs": workflow.conversation_memory.max_runs, }, - "conversation_memory_strategy": str(workflow.conversation_memory_strategy), + "conversation_memory_strategy": workflow.conversation_memory_strategy, "fail_fast": workflow.fail_fast, } assert workflow.to_dict() == expected_workflow_dict