From 26768e7cdc8f5eebd0276fb8761b9bff1d4ff286 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Mon, 21 Oct 2024 14:42:37 -0700 Subject: [PATCH] Standardize callables --- CHANGELOG.md | 21 ++++++++++++ MIGRATION.md | 5 +++ docs/examples/src/multi_agent_workflow_1.py | 4 +-- .../drivers/src/sql_drivers_3.py | 2 +- .../drivers/src/structure_run_drivers_1.py | 4 +-- .../drivers/src/structure_run_drivers_2.py | 2 +- .../engines/summary-engines.md | 2 +- docs/griptape-framework/misc/events.md | 12 +++---- docs/griptape-framework/misc/src/events_1.py | 4 +-- docs/griptape-framework/misc/src/events_6.py | 4 +-- .../structures/src/tasks_10.py | 2 +- .../structures/src/tasks_16.py | 4 +-- griptape/artifacts/base_artifact.py | 2 +- griptape/drivers/sql/snowflake_sql_driver.py | 12 +++---- .../local_structure_run_driver.py | 8 ++--- .../vector/local_vector_store_driver.py | 4 +-- .../extraction/csv_extraction_engine.py | 14 ++++---- .../extraction/json_extraction_engine.py | 12 +++---- .../query/translate_query_rag_module.py | 4 +-- .../footnote_prompt_response_rag_module.py | 2 +- .../response/prompt_response_rag_module.py | 4 +-- .../text_loader_retrieval_rag_module.py | 4 +-- .../vector_store_retrieval_rag_module.py | 4 +-- .../engines/summary/prompt_summary_engine.py | 10 +++--- griptape/events/event_listener.py | 12 +++---- griptape/loaders/csv_loader.py | 4 +-- griptape/loaders/sql_loader.py | 4 +-- .../structure/summary_conversation_memory.py | 8 ++--- griptape/mixins/futures_executor_mixin.py | 4 +-- griptape/rules/json_schema_rule.py | 6 ++-- griptape/tasks/code_execution_task.py | 4 +-- griptape/tasks/prompt_task.py | 4 +-- griptape/tasks/tool_task.py | 2 +- griptape/tasks/toolkit_task.py | 10 +++--- griptape/tools/vector_store/tool.py | 6 ++-- griptape/utils/chat.py | 32 +++++++++---------- griptape/utils/stream.py | 2 +- tests/integration/rules/test_rule.py | 2 +- .../tasks/test_csv_extraction_task.py | 2 +- .../tasks/test_json_extraction_task.py | 2 +- tests/integration/tasks/test_prompt_task.py | 4 ++- tests/integration/tasks/test_rag_task.py | 2 +- .../tasks/test_text_summary_task.py | 2 +- tests/integration/tasks/test_tool_task.py | 4 ++- tests/integration/tasks/test_toolkit_task.py | 2 +- .../integration/tools/test_calculator_tool.py | 2 +- .../tools/test_file_manager_tool.py | 2 +- .../tools/test_google_docs_tool.py | 2 +- .../tools/test_google_drive_tool.py | 2 +- tests/mocks/mock_event_listener_driver.py | 12 +++---- .../test_base_audio_transcription_driver.py | 2 +- .../test_base_event_listener_driver.py | 8 ++--- .../test_base_image_generation_driver.py | 8 ++--- .../test_base_image_query_driver.py | 2 +- ...est_hugging_face_pipeline_prompt_driver.py | 18 +++++------ .../drivers/sql/test_snowflake_sql_driver.py | 12 +++---- .../test_local_structure_run_driver.py | 4 +-- .../test_base_audio_transcription_driver.py | 2 +- ...est_footnote_prompt_response_rag_module.py | 2 +- .../test_prompt_response_rag_module.py | 2 +- tests/unit/events/test_event_bus.py | 8 ++--- tests/unit/events/test_event_listener.py | 2 +- .../mixins/test_futures_executor_mixin.py | 2 +- tests/unit/structures/test_pipeline.py | 4 +-- tests/unit/structures/test_workflow.py | 4 +-- tests/unit/tasks/test_base_task.py | 6 ++-- tests/unit/tasks/test_code_execution_task.py | 6 ++-- tests/unit/tasks/test_structure_run_task.py | 4 +-- tests/unit/tools/test_structure_run_tool.py | 4 +-- tests/unit/tools/test_vector_store_tool.py | 4 +-- tests/unit/utils/test_chat.py | 16 +++++----- tests/utils/structure_tester.py | 2 +- 72 files changed, 216 insertions(+), 190 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c13c8955..9d1a34cbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,27 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Renamed `BaseTask.can_execute` to `BaseTool.can_run`. - **BREAKING**: Renamed `BaseTool.run` to `BaseTool.try_run`. - **BREAKING**: Renamed `BaseTool.execute` to `BaseTool.run`. +- **BREAKING**: Renamed callables throughout the framework for consistency: + - Renamed `LocalStructureRunDriver.structure_factory_fn` to `LocalStructureRunDriver.create_structure`. + - Renamed `SnowflakeSqlDriver.connection_func` to `SnowflakeSqlDriver.get_connection`. + - Renamed `CsvLoader.formatter_fn` to `CsvLoader.format_row`. + - Renamed `SqlLoader.formatter_fn` to `SqlLoader.format_row`. + - Renamed `CsvExtractionEngine.system_template_generator` to `CsvExtractionEngine.generate_system_template`. + - Renamed `CsvExtractionEngine.user_template_generator` to `CsvExtractionEngine.generate_user_template`. + - Renamed `JsonExtractionEngine.system_template_generator` to `JsonExtractionEngine.generate_system_template`. + - Renamed `JsonExtractionEngine.user_template_generator` to `JsonExtractionEngine.generate_user_template`. + - Renamed `PromptResponseRagModule.generate_system_template` to `PromptResponseRagModule.generate_system_template`. + - Renamed `PromptTask.generate_system_template` to `PromptTask.generate_system_template`. + - Renamed `ToolkitTask.generate_assistant_subtask_template` to `ToolkitTask.generate_assistant_subtask_template`. + - Renamed `JsonSchemaRule.template_generator` to `JsonSchemaRule.generate_template`. + - Renamed `ToolkitTask.generate_user_subtask_template` to `ToolkitTask.generate_user_subtask_template`. + - Renamed `TextLoaderRetrievalRagModule.process_query_output_fn` to `TextLoaderRetrievalRagModule.process_query_output`. + - Renamed `FuturesExecutorMixin.futures_executor_fn` to `FuturesExecutorMixin.create_futures_executor`. + - Renamed `VectorStoreTool.process_query_output_fn` to `VectorStoreTool.process_query_output`. + - Renamed `CodeExecutionTask.run_fn` to `CodeExecutionTask.on_run`. + - Renamed `Chat.input_fn` to `Chat.handle_input`. + - Renamed `Chat.output_fn` to `Chat.handle_output`. + - Renamed `EventListener.handler` to `EventListener.on_event`. - Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`. - `BaseTask.parent_outputs` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. - `Workflow.context["parent_outputs"]` type has changed from `dict[str, str | None]` to `dict[str, BaseArtifact]`. diff --git a/MIGRATION.md b/MIGRATION.md index 79bce7a33..1764a7787 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -35,6 +35,11 @@ Defaults.drivers_config = AnthropicDriversConfig( ) ``` +### Renamed Callables + +Many callables have been renamed for consistency. Update your code to use the new names using the [CHANGELOG.md](https://github.com/griptape-ai/griptape/pull/1275/files#diff-06572a96a58dc510037d5efa622f9bec8519bc1beab13c9f251e97e657a9d4ed) as the source of truth. + + ### Removed `CompletionChunkEvent` `CompletionChunkEvent` has been removed. There is now `BaseChunkEvent` with children `TextChunkEvent` and `ActionChunkEvent`. `BaseChunkEvent` can replace `completion_chunk_event.token` by doing `str(base_chunk_event)`. diff --git a/docs/examples/src/multi_agent_workflow_1.py b/docs/examples/src/multi_agent_workflow_1.py index ad9436a55..ea880435a 100644 --- a/docs/examples/src/multi_agent_workflow_1.py +++ b/docs/examples/src/multi_agent_workflow_1.py @@ -133,7 +133,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent: ), id="research", driver=LocalStructureRunDriver( - structure_factory_fn=build_researcher, + create_structure=build_researcher, ), ), ) @@ -150,7 +150,7 @@ def build_writer(role: str, goal: str, backstory: str) -> Agent: {{ parent_outputs["research"] }}""", ), driver=LocalStructureRunDriver( - structure_factory_fn=lambda writer=writer: build_writer( + create_structure=lambda writer=writer: build_writer( role=writer["role"], goal=writer["goal"], backstory=writer["backstory"], diff --git a/docs/griptape-framework/drivers/src/sql_drivers_3.py b/docs/griptape-framework/drivers/src/sql_drivers_3.py index 29ee4a818..cf1e7c1dc 100644 --- a/docs/griptape-framework/drivers/src/sql_drivers_3.py +++ b/docs/griptape-framework/drivers/src/sql_drivers_3.py @@ -17,6 +17,6 @@ def get_snowflake_connection() -> SnowflakeConnection: ) -driver = SnowflakeSqlDriver(connection_func=get_snowflake_connection) +driver = SnowflakeSqlDriver(get_connection=get_snowflake_connection) driver.execute_query("select * from people;") diff --git a/docs/griptape-framework/drivers/src/structure_run_drivers_1.py b/docs/griptape-framework/drivers/src/structure_run_drivers_1.py index a29bfbedf..5fc5b29fe 100644 --- a/docs/griptape-framework/drivers/src/structure_run_drivers_1.py +++ b/docs/griptape-framework/drivers/src/structure_run_drivers_1.py @@ -28,13 +28,13 @@ def build_joke_rewriter() -> Agent: tasks=[ StructureRunTask( driver=LocalStructureRunDriver( - structure_factory_fn=build_joke_teller, + create_structure=build_joke_teller, ), ), StructureRunTask( ("Rewrite this joke: {{ parent_output }}",), driver=LocalStructureRunDriver( - structure_factory_fn=build_joke_rewriter, + create_structure=build_joke_rewriter, ), ), ] diff --git a/docs/griptape-framework/drivers/src/structure_run_drivers_2.py b/docs/griptape-framework/drivers/src/structure_run_drivers_2.py index 6103a6507..bec40c6ee 100644 --- a/docs/griptape-framework/drivers/src/structure_run_drivers_2.py +++ b/docs/griptape-framework/drivers/src/structure_run_drivers_2.py @@ -15,7 +15,7 @@ StructureRunTask( ("Think of a question related to Retrieval Augmented Generation.",), driver=LocalStructureRunDriver( - structure_factory_fn=lambda: Agent( + create_structure=lambda: Agent( rules=[ Rule( value="You are an expert in Retrieval Augmented Generation.", diff --git a/docs/griptape-framework/engines/summary-engines.md b/docs/griptape-framework/engines/summary-engines.md index cfd99f3a8..573d56dad 100644 --- a/docs/griptape-framework/engines/summary-engines.md +++ b/docs/griptape-framework/engines/summary-engines.md @@ -9,7 +9,7 @@ Summary engines are used to summarize text and collections of [TextArtifact](../ ## Prompt -Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [system_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.system_template_generator), [user_template_generator](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.user_template_generator), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker). +Used to summarize texts with LLMs. You can set a custom [prompt_driver](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.prompt_driver), [generate_system_template](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.generate_system_template), [generate_user_template](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.generate_user_template), and [chunker](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.chunker). Use the [summarize_artifacts](../../reference/griptape/engines/summary/prompt_summary_engine.md#griptape.engines.summary.prompt_summary_engine.PromptSummaryEngine.summarize_artifacts) method to summarize a list of artifacts or [summarize_text](../../reference/griptape/engines/summary/base_summary_engine.md#griptape.engines.summary.base_summary_engine.BaseSummaryEngine.summarize_text) to summarize an arbitrary string. diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index b97b9de98..c33e3a3c6 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -159,21 +159,21 @@ Assistant: ... ``` -## `EventListenerDriver.handler` Return Value Behavior +## `EventListenerDriver.on_event` Return Value Behavior -The value that gets returned from the [`EventListener.handler`](../../reference/griptape/events/event_listener.md#griptape.events.event_listener.EventListener.handler) will determine what gets sent to the `event_listener_driver`. +The value that gets returned from the [`EventListener.on_event`](../../reference/griptape/events/event_listener.md#griptape.events.event_listener.EventListener.on_event) will determine what gets sent to the `event_listener_driver`. -### `EventListener.handler` is None +### `EventListener.on_event` is None -By default, the `EventListener.handler` function is `None`. Any events that the `EventListener` is listening for will get sent to the `event_listener_driver` as-is. +By default, the `EventListener.on_event` function is `None`. Any events that the `EventListener` is listening for will get sent to the `event_listener_driver` as-is. ### Return `BaseEvent` or `dict` -You can return a `BaseEvent` or `dict` object from `EventListener.handler`, and it will get sent to the `event_listener_driver`. +You can return a `BaseEvent` or `dict` object from `EventListener.on_event`, and it will get sent to the `event_listener_driver`. ### Return `None` -You can return `None` in the handler function to prevent the event from getting sent to the `event_listener_driver`. +You can return `None` in the on_event function to prevent the event from getting sent to the `event_listener_driver`. ```python --8<-- "docs/griptape-framework/misc/src/events_no_publish.py" diff --git a/docs/griptape-framework/misc/src/events_1.py b/docs/griptape-framework/misc/src/events_1.py index 993567cc6..ad9cb5647 100644 --- a/docs/griptape-framework/misc/src/events_1.py +++ b/docs/griptape-framework/misc/src/events_1.py @@ -12,14 +12,14 @@ from griptape.structures import Agent -def handler(event: BaseEvent) -> None: +def on_event(event: BaseEvent) -> None: print(event.__class__) EventBus.add_event_listeners( [ EventListener( - handler, + on_event, event_types=[ StartTaskEvent, FinishTaskEvent, diff --git a/docs/griptape-framework/misc/src/events_6.py b/docs/griptape-framework/misc/src/events_6.py index 0bfa9426b..4cc21fa75 100644 --- a/docs/griptape-framework/misc/src/events_6.py +++ b/docs/griptape-framework/misc/src/events_6.py @@ -2,14 +2,14 @@ from griptape.structures import Agent -def handler(event: BaseEvent) -> None: +def on_event(event: BaseEvent) -> None: if isinstance(event, StartPromptEvent): print("Prompt Stack Messages:") for message in event.prompt_stack.messages: print(f"{message.role}: {message.to_text()}") -EventBus.add_event_listeners([EventListener(handler=handler, event_types=[StartPromptEvent])]) +EventBus.add_event_listeners([EventListener(on_event=on_event, event_types=[StartPromptEvent])]) agent = Agent() diff --git a/docs/griptape-framework/structures/src/tasks_10.py b/docs/griptape-framework/structures/src/tasks_10.py index c94fa7919..e36a843df 100644 --- a/docs/griptape-framework/structures/src/tasks_10.py +++ b/docs/griptape-framework/structures/src/tasks_10.py @@ -14,7 +14,7 @@ def character_counter(task: CodeExecutionTask) -> BaseArtifact: pipeline.add_tasks( # take the first argument from the pipeline `run` method - CodeExecutionTask(run_fn=character_counter), + CodeExecutionTask(on_run=character_counter), # # take the output from the previous task and insert it into the prompt PromptTask("{{args[0]}} using {{ parent_output }} characters"), ) diff --git a/docs/griptape-framework/structures/src/tasks_16.py b/docs/griptape-framework/structures/src/tasks_16.py index 7496d2d9c..332187a00 100644 --- a/docs/griptape-framework/structures/src/tasks_16.py +++ b/docs/griptape-framework/structures/src/tasks_16.py @@ -112,7 +112,7 @@ def build_writer() -> Agent: """Perform a detailed examination of the newest developments in AI as of 2024. Pinpoint major trends, breakthroughs, and their implications for various industries.""", ), - driver=LocalStructureRunDriver(structure_factory_fn=build_researcher), + driver=LocalStructureRunDriver(create_structure=build_researcher), ), StructureRunTask( ( @@ -122,7 +122,7 @@ def build_writer() -> Agent: Keep the tone appealing and use simple language to make it less technical.""", "{{parent_output}}", ), - driver=LocalStructureRunDriver(structure_factory_fn=build_writer), + driver=LocalStructureRunDriver(create_structure=build_writer), ), ], ) diff --git a/griptape/artifacts/base_artifact.py b/griptape/artifacts/base_artifact.py index 61989ab54..4eb908251 100644 --- a/griptape/artifacts/base_artifact.py +++ b/griptape/artifacts/base_artifact.py @@ -25,7 +25,7 @@ class BaseArtifact(SerializableMixin, ABC): name: The name of the Artifact. Defaults to the id. value: The value of the Artifact. encoding: The encoding to use when encoding/decoding the value. - encoding_error_handler: The error handler to use when encoding/decoding the value. + encoding_error_handler: The error on_event to use when encoding/decoding the value. """ id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) diff --git a/griptape/drivers/sql/snowflake_sql_driver.py b/griptape/drivers/sql/snowflake_sql_driver.py index d1b4310b5..82d6a525c 100644 --- a/griptape/drivers/sql/snowflake_sql_driver.py +++ b/griptape/drivers/sql/snowflake_sql_driver.py @@ -15,16 +15,16 @@ @define class SnowflakeSqlDriver(BaseSqlDriver): - connection_func: Callable[[], SnowflakeConnection] = field(kw_only=True) + get_connection: Callable[[], SnowflakeConnection] = field(kw_only=True) _engine: Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False}) - @connection_func.validator # pyright: ignore[reportFunctionMemberAccess] - def validate_connection_func(self, _: Attribute, connection_func: Callable[[], SnowflakeConnection]) -> None: - snowflake_connection = connection_func() + @get_connection.validator # pyright: ignore[reportFunctionMemberAccess] + def validate_get_connection(self, _: Attribute, get_connection: Callable[[], SnowflakeConnection]) -> None: + snowflake_connection = get_connection() snowflake = import_optional_dependency("snowflake") if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection): - raise ValueError("The connection_func must return a SnowflakeConnection") + raise ValueError("The get_connection function must return a SnowflakeConnection") if not snowflake_connection.schema or not snowflake_connection.database: raise ValueError("Provide a schema and database for the Snowflake connection") @@ -32,7 +32,7 @@ def validate_connection_func(self, _: Attribute, connection_func: Callable[[], S def engine(self) -> Engine: return import_optional_dependency("sqlalchemy").create_engine( "snowflake://not@used/db", - creator=self.connection_func, + creator=self.get_connection, ) def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]: diff --git a/griptape/drivers/structure_run/local_structure_run_driver.py b/griptape/drivers/structure_run/local_structure_run_driver.py index c0049b29a..b2335e3c3 100644 --- a/griptape/drivers/structure_run/local_structure_run_driver.py +++ b/griptape/drivers/structure_run/local_structure_run_driver.py @@ -14,18 +14,18 @@ @define class LocalStructureRunDriver(BaseStructureRunDriver): - structure_factory_fn: Callable[[], Structure] = field(kw_only=True) + create_structure: Callable[[], Structure] = field(kw_only=True) def try_run(self, *args: BaseArtifact) -> BaseArtifact: old_env = os.environ.copy() try: os.environ.update(self.env) - structure_factory_fn = self.structure_factory_fn().run(*[arg.value for arg in args]) + structure = self.create_structure().run(*[arg.value for arg in args]) finally: os.environ.clear() os.environ.update(old_env) - if structure_factory_fn.output_task.output is not None: - return structure_factory_fn.output_task.output + if structure.output_task.output is not None: + return structure.output_task.output else: return InfoArtifact("No output found in response") diff --git a/griptape/drivers/vector/local_vector_store_driver.py b/griptape/drivers/vector/local_vector_store_driver.py index 36203d540..557937431 100644 --- a/griptape/drivers/vector/local_vector_store_driver.py +++ b/griptape/drivers/vector/local_vector_store_driver.py @@ -19,7 +19,7 @@ class LocalVectorStoreDriver(BaseVectorStoreDriver): entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict) persist_file: Optional[str] = field(default=None) - relatedness_fn: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y))) + calculate_relatedness: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y))) thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock())) def __attrs_post_init__(self) -> None: @@ -95,7 +95,7 @@ def query( entries = self.entries entries_and_relatednesses = [ - (entry, self.relatedness_fn(query_embedding, entry.vector)) for entry in list(entries.values()) + (entry, self.calculate_relatedness(query_embedding, entry.vector)) for entry in list(entries.values()) ] entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 7fb2a164b..7f4647d65 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -18,9 +18,9 @@ @define class CsvExtractionEngine(BaseExtractionEngine): column_names: list[str] = field(kw_only=True) - system_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) - user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) - formatter_fn: Callable[[dict], str] = field( + generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True) + generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True) + format_row: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) @@ -45,7 +45,7 @@ def text_to_csv_rows(self, text: str, column_names: list[str]) -> list[TextArtif with io.StringIO(text) as f: for row in csv.reader(f): - rows.append(TextArtifact(self.formatter_fn(dict(zip(column_names, [x.strip() for x in row]))))) + rows.append(TextArtifact(self.format_row(dict(zip(column_names, [x.strip() for x in row]))))) return rows @@ -57,11 +57,11 @@ def _extract_rec( rulesets: Optional[list[Ruleset]] = None, ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) - system_prompt = self.system_template_generator.render( + system_prompt = self.generate_system_template.render( column_names=self.column_names, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) - user_prompt = self.user_template_generator.render( + user_prompt = self.generate_user_template.render( text=artifacts_text, ) @@ -86,7 +86,7 @@ def _extract_rec( return rows else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.user_template_generator.render( + partial_text = self.generate_user_template.render( text=chunks[0].value, ) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index c817efd5f..f2c56a62b 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -21,10 +21,8 @@ class JsonExtractionEngine(BaseExtractionEngine): JSON_PATTERN = r"(?s)[^\[]*(\[.*\])" template_schema: dict = field(kw_only=True) - system_template_generator: J2 = field( - default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True - ) - user_template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True) + generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True) + generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True) def extract_artifacts( self, @@ -54,11 +52,11 @@ def _extract_rec( rulesets: Optional[list[Ruleset]] = None, ) -> list[JsonArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) - system_prompt = self.system_template_generator.render( + system_prompt = self.generate_system_template.render( json_template_schema=json.dumps(self.template_schema), rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) - user_prompt = self.user_template_generator.render( + user_prompt = self.generate_user_template.render( text=artifacts_text, ) @@ -82,7 +80,7 @@ def _extract_rec( return extractions else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.user_template_generator.render( + partial_text = self.generate_user_template.render( text=chunks[0].value, ) diff --git a/griptape/engines/rag/modules/query/translate_query_rag_module.py b/griptape/engines/rag/modules/query/translate_query_rag_module.py index f1f9ca0ec..e92d95e2b 100644 --- a/griptape/engines/rag/modules/query/translate_query_rag_module.py +++ b/griptape/engines/rag/modules/query/translate_query_rag_module.py @@ -17,7 +17,7 @@ class TranslateQueryRagModule(BaseQueryRagModule): prompt_driver: BasePromptDriver = field() language: str = field() generate_user_template: Callable[[str, str], str] = field( - default=Factory(lambda self: self.default_user_template_generator, takes_self=True), + default=Factory(lambda self: self.default_generate_user_template, takes_self=True), ) def run(self, context: RagContext) -> RagContext: @@ -28,5 +28,5 @@ def run(self, context: RagContext) -> RagContext: return context - def default_user_template_generator(self, query: str, language: str) -> str: + def default_generate_user_template(self, query: str, language: str) -> str: return J2("engines/rag/modules/query/translate/user.j2").render(query=query, language=language) diff --git a/griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py b/griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py index 3687d1942..ea07c5007 100644 --- a/griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/footnote_prompt_response_rag_module.py @@ -15,7 +15,7 @@ @define(kw_only=True) class FootnotePromptResponseRagModule(PromptResponseRagModule): - def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: + def default_generate_system_template(self, context: RagContext, artifacts: list[TextArtifact]) -> str: return J2("engines/rag/modules/response/footnote_prompt/system.j2").render( text_chunk_artifacts=artifacts, references=utils.references_from_artifacts(artifacts), diff --git a/griptape/engines/rag/modules/response/prompt_response_rag_module.py b/griptape/engines/rag/modules/response/prompt_response_rag_module.py index b62a0eba3..2e4f39947 100644 --- a/griptape/engines/rag/modules/response/prompt_response_rag_module.py +++ b/griptape/engines/rag/modules/response/prompt_response_rag_module.py @@ -22,7 +22,7 @@ class PromptResponseRagModule(BaseResponseRagModule, RuleMixin): answer_token_offset: int = field(default=400) metadata: Optional[str] = field(default=None) generate_system_template: Callable[[RagContext, list[TextArtifact]], str] = field( - default=Factory(lambda self: self.default_system_template_generator, takes_self=True), + default=Factory(lambda self: self.default_generate_system_template, takes_self=True), ) def run(self, context: RagContext) -> BaseArtifact: @@ -53,7 +53,7 @@ def run(self, context: RagContext) -> BaseArtifact: else: raise ValueError("Prompt driver did not return a TextArtifact") - def default_system_template_generator(self, context: RagContext, artifacts: list[TextArtifact]) -> str: + def default_generate_system_template(self, context: RagContext, artifacts: list[TextArtifact]) -> str: params: dict[str, Any] = {"text_chunks": [c.to_text() for c in artifacts]} if len(self.rulesets) > 0: diff --git a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py index 0348a2094..46128c9f6 100644 --- a/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/text_loader_retrieval_rag_module.py @@ -25,7 +25,7 @@ class TextLoaderRetrievalRagModule(BaseRetrievalRagModule): vector_store_driver: BaseVectorStoreDriver = field() source: Any = field() query_params: dict[str, Any] = field(factory=dict) - process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( + process_query_output: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), ) @@ -43,4 +43,4 @@ def run(self, context: RagContext) -> Sequence[TextArtifact]: self.vector_store_driver.upsert_text_artifacts({namespace: chunks}) - return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) + return self.process_query_output(self.vector_store_driver.query(context.query, **query_params)) diff --git a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py index ddff2549c..42ae5876f 100644 --- a/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py +++ b/griptape/engines/rag/modules/retrieval/vector_store_retrieval_rag_module.py @@ -22,11 +22,11 @@ class VectorStoreRetrievalRagModule(BaseRetrievalRagModule): default=Factory(lambda: Defaults.drivers_config.vector_store_driver) ) query_params: dict[str, Any] = field(factory=dict) - process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( + process_query_output: Callable[[list[BaseVectorStoreDriver.Entry]], Sequence[TextArtifact]] = field( default=Factory(lambda: lambda es: [e.to_artifact() for e in es]), ) def run(self, context: RagContext) -> Sequence[TextArtifact]: query_params = utils.dict_merge(self.query_params, self.get_context_param(context, "query_params")) - return self.process_query_output_fn(self.vector_store_driver.query(context.query, **query_params)) + return self.process_query_output(self.vector_store_driver.query(context.query, **query_params)) diff --git a/griptape/engines/summary/prompt_summary_engine.py b/griptape/engines/summary/prompt_summary_engine.py index 3cc3dd470..29b7e97af 100644 --- a/griptape/engines/summary/prompt_summary_engine.py +++ b/griptape/engines/summary/prompt_summary_engine.py @@ -20,8 +20,8 @@ class PromptSummaryEngine(BaseSummaryEngine): chunk_joiner: str = field(default="\n\n", kw_only=True) max_token_multiplier: float = field(default=0.5, kw_only=True) - system_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) - user_template_generator: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) + generate_system_template: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True) + generate_user_template: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True) prompt_driver: BasePromptDriver = field( default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True ) @@ -67,12 +67,12 @@ def summarize_artifacts_rec( artifacts_text = self.chunk_joiner.join([a.to_text() for a in artifacts]) - system_prompt = self.system_template_generator.render( + system_prompt = self.generate_system_template.render( summary=summary, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) - user_prompt = self.user_template_generator.render(text=artifacts_text) + user_prompt = self.generate_user_template.render(text=artifacts_text) if ( self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt) @@ -94,7 +94,7 @@ def summarize_artifacts_rec( else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.user_template_generator.render(text=chunks[0].value) + partial_text = self.generate_user_template.render(text=chunks[0].value) return self.summarize_artifacts_rec( chunks[1:], diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index bbca5f83b..a7eaf3ab1 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -18,15 +18,15 @@ class EventListener(Generic[T]): """An event listener that listens for events and handles them. Attributes: - handler: The handler function that will be called when an event is published. - The handler function should accept an event and return either the event or a dictionary. - If the handler returns None, the event will not be published. + on_event: The on_event function that will be called when an event is published. + The on_event function should accept an event and return either the event or a dictionary. + If the on_event returns None, the event will not be published. event_types: A list of event types that the event listener should listen for. If not provided, the event listener will listen for all event types. event_listener_driver: The driver that will be used to publish events. """ - handler: Optional[Callable[[T], Optional[BaseEvent | dict]]] = field(default=None) + on_event: Optional[Callable[[T], Optional[BaseEvent | dict]]] = field(default=None) event_types: Optional[list[type[T]]] = field(default=None, kw_only=True) event_listener_driver: Optional[BaseEventListenerDriver] = field(default=None, kw_only=True) @@ -47,8 +47,8 @@ def publish_event(self, event: T, *, flush: bool = False) -> None: if event_types is None or any(isinstance(event, event_type) for event_type in event_types): handled_event = event - if self.handler is not None: - handled_event = self.handler(event) + if self.on_event is not None: + handled_event = self.on_event(event) if self.event_listener_driver is not None and handled_event is not None: self.event_listener_driver.publish_event(handled_event) diff --git a/griptape/loaders/csv_loader.py b/griptape/loaders/csv_loader.py index 4487d7aec..c7e3a139e 100644 --- a/griptape/loaders/csv_loader.py +++ b/griptape/loaders/csv_loader.py @@ -14,7 +14,7 @@ class CsvLoader(BaseFileLoader[ListArtifact[TextArtifact]]): delimiter: str = field(default=",", kw_only=True) encoding: str = field(default="utf-8", kw_only=True) - formatter_fn: Callable[[dict], str] = field( + format_row: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) @@ -22,5 +22,5 @@ def parse(self, data: bytes) -> ListArtifact[TextArtifact]: reader = csv.DictReader(StringIO(data.decode(self.encoding)), delimiter=self.delimiter) return ListArtifact( - [TextArtifact(self.formatter_fn(row), meta={"row_num": row_num}) for row_num, row in enumerate(reader)] + [TextArtifact(self.format_row(row), meta={"row_num": row_num}) for row_num, row in enumerate(reader)] ) diff --git a/griptape/loaders/sql_loader.py b/griptape/loaders/sql_loader.py index 0c6e8bdf9..e63f7af81 100644 --- a/griptape/loaders/sql_loader.py +++ b/griptape/loaders/sql_loader.py @@ -12,7 +12,7 @@ @define class SqlLoader(BaseLoader[str, list[BaseSqlDriver.RowResult], ListArtifact[TextArtifact]]): sql_driver: BaseSqlDriver = field(kw_only=True) - formatter_fn: Callable[[dict], str] = field( + format_row: Callable[[dict], str] = field( default=lambda value: "\n".join(f"{key}: {val}" for key, val in value.items()), kw_only=True ) @@ -20,4 +20,4 @@ def fetch(self, source: str) -> list[BaseSqlDriver.RowResult]: return self.sql_driver.execute_query(source) or [] def parse(self, data: list[BaseSqlDriver.RowResult]) -> ListArtifact[TextArtifact]: - return ListArtifact([TextArtifact(self.formatter_fn(row.cells)) for row in data]) + return ListArtifact([TextArtifact(self.format_row(row.cells)) for row in data]) diff --git a/griptape/memory/structure/summary_conversation_memory.py b/griptape/memory/structure/summary_conversation_memory.py index 055057d34..5a1c5363d 100644 --- a/griptape/memory/structure/summary_conversation_memory.py +++ b/griptape/memory/structure/summary_conversation_memory.py @@ -23,8 +23,8 @@ class SummaryConversationMemory(ConversationMemory): ) summary: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) summary_index: int = field(default=0, kw_only=True, metadata={"serializable": True}) - summary_template_generator: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) - summarize_conversation_template_generator: J2 = field( + summary_get_template: J2 = field(default=Factory(lambda: J2("memory/conversation/summary.j2")), kw_only=True) + summarize_conversation_get_template: J2 = field( default=Factory(lambda: J2("memory/conversation/summarize_conversation.j2")), kw_only=True, ) @@ -32,7 +32,7 @@ class SummaryConversationMemory(ConversationMemory): def to_prompt_stack(self, last_n: Optional[int] = None) -> PromptStack: stack = PromptStack() if self.summary: - stack.add_user_message(self.summary_template_generator.render(summary=self.summary)) + stack.add_user_message(self.summary_get_template.render(summary=self.summary)) for r in self.unsummarized_runs(last_n): stack.add_user_message(r.input) @@ -66,7 +66,7 @@ def try_add_run(self, run: Run) -> None: def summarize_runs(self, previous_summary: str | None, runs: list[Run]) -> str | None: try: if len(runs) > 0: - summary = self.summarize_conversation_template_generator.render(summary=previous_summary, runs=runs) + summary = self.summarize_conversation_get_template.render(summary=previous_summary, runs=runs) return self.prompt_driver.run( prompt_stack=PromptStack(messages=[Message(summary, role=Message.USER_ROLE)]), ).to_text() diff --git a/griptape/mixins/futures_executor_mixin.py b/griptape/mixins/futures_executor_mixin.py index 8c309d9b7..5f3eb5324 100644 --- a/griptape/mixins/futures_executor_mixin.py +++ b/griptape/mixins/futures_executor_mixin.py @@ -10,12 +10,12 @@ @define(slots=False, kw_only=True) class FuturesExecutorMixin(ABC): - futures_executor_fn: Callable[[], futures.Executor] = field( + create_futures_executor: Callable[[], futures.Executor] = field( default=Factory(lambda: lambda: futures.ThreadPoolExecutor()), ) futures_executor: Optional[futures.Executor] = field( - default=Factory(lambda self: self.futures_executor_fn(), takes_self=True) + default=Factory(lambda self: self.create_futures_executor(), takes_self=True) ) def __del__(self) -> None: diff --git a/griptape/rules/json_schema_rule.py b/griptape/rules/json_schema_rule.py index 1bd418464..ce41f26db 100644 --- a/griptape/rules/json_schema_rule.py +++ b/griptape/rules/json_schema_rule.py @@ -2,7 +2,7 @@ import json -from attrs import define, field +from attrs import Factory, define, field from griptape.rules import BaseRule from griptape.utils import J2 @@ -11,7 +11,7 @@ @define(frozen=True) class JsonSchemaRule(BaseRule): value: dict = field() - template_generator: J2 = field(default=J2("rules/json_schema.j2")) + generate_template: J2 = field(default=Factory(lambda: J2("rules/json_schema.j2"))) def to_text(self) -> str: - return self.template_generator.render(json_schema=json.dumps(self.value)) + return self.generate_template.render(json_schema=json.dumps(self.value)) diff --git a/griptape/tasks/code_execution_task.py b/griptape/tasks/code_execution_task.py index 390d08e91..5ce311be7 100644 --- a/griptape/tasks/code_execution_task.py +++ b/griptape/tasks/code_execution_task.py @@ -12,7 +12,7 @@ @define class CodeExecutionTask(BaseTextInputTask): - run_fn: Callable[[CodeExecutionTask], BaseArtifact] = field(kw_only=True) + on_run: Callable[[CodeExecutionTask], BaseArtifact] = field(kw_only=True) def try_run(self) -> BaseArtifact: - return self.run_fn(self) + return self.on_run(self) diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 98eb9e309..598eacf57 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -25,7 +25,7 @@ class PromptTask(RuleMixin, BaseTask): default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True ) generate_system_template: Callable[[PromptTask], str] = field( - default=Factory(lambda self: self.default_system_template_generator, takes_self=True), + default=Factory(lambda self: self.default_generate_system_template, takes_self=True), kw_only=True, ) _input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field( @@ -79,7 +79,7 @@ def prompt_stack(self) -> PromptStack: return stack - def default_system_template_generator(self, _: PromptTask) -> str: + def default_generate_system_template(self, _: PromptTask) -> str: return J2("tasks/prompt_task/system.j2").render( rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), ) diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index 325400f5c..07b762167 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -49,7 +49,7 @@ def preprocess(self, structure: Structure) -> ToolTask: return self - def default_system_template_generator(self, _: PromptTask) -> str: + def default_generate_system_template(self, _: PromptTask) -> str: return J2("tasks/tool_task/system.j2").render( rulesets=J2("rulesets/rulesets.j2").render(rulesets=self.rulesets), action_schema=utils.minify_json(json.dumps(self.tool.schema())), diff --git a/griptape/tasks/toolkit_task.py b/griptape/tasks/toolkit_task.py index 98bbff9a1..088ccd52d 100644 --- a/griptape/tasks/toolkit_task.py +++ b/griptape/tasks/toolkit_task.py @@ -32,11 +32,11 @@ class ToolkitTask(PromptTask, ActionsSubtaskOriginMixin): task_memory: Optional[TaskMemory] = field(default=None, kw_only=True) subtasks: list[ActionsSubtask] = field(factory=list) generate_assistant_subtask_template: Callable[[ActionsSubtask], str] = field( - default=Factory(lambda self: self.default_assistant_subtask_template_generator, takes_self=True), + default=Factory(lambda self: self.default_generate_assistant_subtask_template, takes_self=True), kw_only=True, ) generate_user_subtask_template: Callable[[ActionsSubtask], str] = field( - default=Factory(lambda self: self.default_user_subtask_template_generator, takes_self=True), + default=Factory(lambda self: self.default_generate_user_subtask_template, takes_self=True), kw_only=True, ) response_stop_sequence: str = field(default=RESPONSE_STOP_SEQUENCE, kw_only=True) @@ -127,7 +127,7 @@ def preprocess(self, structure: Structure) -> ToolkitTask: return self - def default_system_template_generator(self, _: PromptTask) -> str: + def default_generate_system_template(self, _: PromptTask) -> str: schema = self.actions_schema().json_schema("Actions Schema") schema["minItems"] = 1 # The `schema` library doesn't support `minItems` so we must add it manually. @@ -140,13 +140,13 @@ def default_system_template_generator(self, _: PromptTask) -> str: stop_sequence=self.response_stop_sequence, ) - def default_assistant_subtask_template_generator(self, subtask: ActionsSubtask) -> str: + def default_generate_assistant_subtask_template(self, subtask: ActionsSubtask) -> str: return J2("tasks/toolkit_task/assistant_subtask.j2").render( stop_sequence=self.response_stop_sequence, subtask=subtask, ) - def default_user_subtask_template_generator(self, subtask: ActionsSubtask) -> str: + def default_generate_user_subtask_template(self, subtask: ActionsSubtask) -> str: return J2("tasks/toolkit_task/user_subtask.j2").render( stop_sequence=self.response_stop_sequence, subtask=subtask, diff --git a/griptape/tools/vector_store/tool.py b/griptape/tools/vector_store/tool.py index 71902b1c7..eee854f6d 100644 --- a/griptape/tools/vector_store/tool.py +++ b/griptape/tools/vector_store/tool.py @@ -21,7 +21,7 @@ class VectorStoreTool(BaseTool): description: LLM-friendly vector DB description. vector_store_driver: `BaseVectorStoreDriver`. query_params: Optional dictionary of vector store driver query parameters. - process_query_output_fn: Optional lambda for processing vector store driver query output `Entry`s. + process_query_output: Optional lambda for processing vector store driver query output `Entry`s. """ DEFAULT_TOP_N = 5 @@ -29,7 +29,7 @@ class VectorStoreTool(BaseTool): description: str = field() vector_store_driver: BaseVectorStoreDriver = field() query_params: dict[str, Any] = field(factory=dict) - process_query_output_fn: Callable[[list[BaseVectorStoreDriver.Entry]], BaseArtifact] = field( + process_query_output: Callable[[list[BaseVectorStoreDriver.Entry]], BaseArtifact] = field( default=Factory(lambda: lambda es: ListArtifact([e.to_artifact() for e in es])), ) @@ -50,6 +50,6 @@ def search(self, params: dict) -> BaseArtifact: query = params["values"]["query"] try: - return self.process_query_output_fn(self.vector_store_driver.query(query, **self.query_params)) + return self.process_query_output(self.vector_store_driver.query(query, **self.query_params)) except Exception as e: return ErrorArtifact(f"error querying vector store: {e}") diff --git a/griptape/utils/chat.py b/griptape/utils/chat.py index 802fd809d..8bbf38cd7 100644 --- a/griptape/utils/chat.py +++ b/griptape/utils/chat.py @@ -25,8 +25,8 @@ class Chat: intro_text: Text to display when the chat starts. prompt_prefix: Prefix for the user's input. response_prefix: Prefix for the assistant's response. - input_fn: Function to get the user's input. - output_fn: Function to output text. Takes a `text` argument for the text to output. + handle_input: Function to get the user's input. + handle_output: Function to output text. Takes a `text` argument for the text to output. Also takes a `stream` argument which will be set to True when streaming Prompt Tasks are present. """ @@ -40,19 +40,19 @@ class ChatPrompt(Prompt): intro_text: Optional[str] = field(default=None, kw_only=True) prompt_prefix: str = field(default="User: ", kw_only=True) response_prefix: str = field(default="Assistant: ", kw_only=True) - input_fn: Callable[[str], str] = field( - default=Factory(lambda self: self.default_input_fn, takes_self=True), kw_only=True + handle_input: Callable[[str], str] = field( + default=Factory(lambda self: self.default_handle_input, takes_self=True), kw_only=True ) - output_fn: Callable[..., None] = field( - default=Factory(lambda self: self.default_output_fn, takes_self=True), + handle_output: Callable[..., None] = field( + default=Factory(lambda self: self.default_handle_output, takes_self=True), kw_only=True, ) logger_level: int = field(default=logging.ERROR, kw_only=True) - def default_input_fn(self, prompt_prefix: str) -> str: + def default_handle_input(self, prompt_prefix: str) -> str: return Chat.ChatPrompt.ask(prompt_prefix) - def default_output_fn(self, text: str, *, stream: bool = False) -> None: + def default_handle_output(self, text: str, *, stream: bool = False) -> None: if stream: rprint(text, end="", flush=True) else: @@ -66,26 +66,26 @@ def start(self) -> None: logging.getLogger(Defaults.logging_config.logger_name).setLevel(self.logger_level) if self.intro_text: - self.output_fn(self.intro_text) + self.handle_output(self.intro_text) has_streaming_tasks = self._has_streaming_tasks() while True: - question = self.input_fn(self.prompt_prefix) + question = self.handle_input(self.prompt_prefix) if question.lower() in self.exit_keywords: - self.output_fn(self.exiting_text) + self.handle_output(self.exiting_text) break if has_streaming_tasks: - self.output_fn(self.processing_text) + self.handle_output(self.processing_text) stream = Stream(self.structure).run(question) first_chunk = next(stream) - self.output_fn(self.response_prefix + first_chunk.value, stream=True) + self.handle_output(self.response_prefix + first_chunk.value, stream=True) for chunk in stream: - self.output_fn(chunk.value, stream=True) + self.handle_output(chunk.value, stream=True) else: - self.output_fn(self.processing_text) - self.output_fn(f"{self.response_prefix}{self.structure.run(question).output_task.output.to_text()}") + self.handle_output(self.processing_text) + self.handle_output(f"{self.response_prefix}{self.structure.run(question).output_task.output.to_text()}") # Restore the original logger level logging.getLogger(Defaults.logging_config.logger_name).setLevel(old_logger_level) diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index f722db33d..af8c65b3b 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -83,7 +83,7 @@ def event_handler(event: BaseEvent) -> None: self._event_queue.put(event) stream_event_listener = EventListener( - handler=event_handler, + on_event=event_handler, event_types=[BaseChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) EventBus.add_event_listener(stream_event_listener) diff --git a/tests/integration/rules/test_rule.py b/tests/integration/rules/test_rule.py index a62263c57..91c427653 100644 --- a/tests/integration/rules/test_rule.py +++ b/tests/integration/rules/test_rule.py @@ -5,7 +5,7 @@ class TestRule: @pytest.fixture( - autouse=True, params=StructureTester.RULE_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.prompt_driver_id_fn + autouse=True, params=StructureTester.RULE_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.generate_prompt_driver_id ) def structure_tester(self, request): from griptape.rules import Rule diff --git a/tests/integration/tasks/test_csv_extraction_task.py b/tests/integration/tasks/test_csv_extraction_task.py index db58b9615..3e2186fae 100644 --- a/tests/integration/tasks/test_csv_extraction_task.py +++ b/tests/integration/tasks/test_csv_extraction_task.py @@ -7,7 +7,7 @@ class TestCsvExtractionTask: @pytest.fixture( autouse=True, params=StructureTester.CSV_EXTRACTION_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.engines import CsvExtractionEngine diff --git a/tests/integration/tasks/test_json_extraction_task.py b/tests/integration/tasks/test_json_extraction_task.py index 115f805da..e13fa7aa5 100644 --- a/tests/integration/tasks/test_json_extraction_task.py +++ b/tests/integration/tasks/test_json_extraction_task.py @@ -7,7 +7,7 @@ class TestJsonExtractionTask: @pytest.fixture( autouse=True, params=StructureTester.JSON_EXTRACTION_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from schema import Schema diff --git a/tests/integration/tasks/test_prompt_task.py b/tests/integration/tasks/test_prompt_task.py index 1d223b4ca..95106a9a0 100644 --- a/tests/integration/tasks/test_prompt_task.py +++ b/tests/integration/tasks/test_prompt_task.py @@ -5,7 +5,9 @@ class TestPromptTask: @pytest.fixture( - autouse=True, params=StructureTester.PROMPT_TASK_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.prompt_driver_id_fn + autouse=True, + params=StructureTester.PROMPT_TASK_CAPABLE_PROMPT_DRIVERS, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tasks/test_rag_task.py b/tests/integration/tasks/test_rag_task.py index ce3a9140d..255e608f3 100644 --- a/tests/integration/tasks/test_rag_task.py +++ b/tests/integration/tasks/test_rag_task.py @@ -9,7 +9,7 @@ class TestRagTask: @pytest.fixture( autouse=True, params=StructureTester.TEXT_SUMMARY_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.artifacts import TextArtifact diff --git a/tests/integration/tasks/test_text_summary_task.py b/tests/integration/tasks/test_text_summary_task.py index ff6597ba0..811ec39f6 100644 --- a/tests/integration/tasks/test_text_summary_task.py +++ b/tests/integration/tasks/test_text_summary_task.py @@ -7,7 +7,7 @@ class TestTextSummaryTask: @pytest.fixture( autouse=True, params=StructureTester.TEXT_SUMMARY_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.engines.summary.prompt_summary_engine import PromptSummaryEngine diff --git a/tests/integration/tasks/test_tool_task.py b/tests/integration/tasks/test_tool_task.py index 426dde995..712e23d26 100644 --- a/tests/integration/tasks/test_tool_task.py +++ b/tests/integration/tasks/test_tool_task.py @@ -5,7 +5,9 @@ class TestToolTask: @pytest.fixture( - autouse=True, params=StructureTester.TOOL_TASK_CAPABLE_PROMPT_DRIVERS, ids=StructureTester.prompt_driver_id_fn + autouse=True, + params=StructureTester.TOOL_TASK_CAPABLE_PROMPT_DRIVERS, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tasks/test_toolkit_task.py b/tests/integration/tasks/test_toolkit_task.py index 50b4f2a97..7593c5391 100644 --- a/tests/integration/tasks/test_toolkit_task.py +++ b/tests/integration/tasks/test_toolkit_task.py @@ -7,7 +7,7 @@ class TestToolkitTask: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): import os diff --git a/tests/integration/tools/test_calculator_tool.py b/tests/integration/tools/test_calculator_tool.py index c209a9a2c..634b84803 100644 --- a/tests/integration/tools/test_calculator_tool.py +++ b/tests/integration/tools/test_calculator_tool.py @@ -7,7 +7,7 @@ class TestCalculator: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tools/test_file_manager_tool.py b/tests/integration/tools/test_file_manager_tool.py index 4b5299175..ce6b331c2 100644 --- a/tests/integration/tools/test_file_manager_tool.py +++ b/tests/integration/tools/test_file_manager_tool.py @@ -7,7 +7,7 @@ class TestFileManager: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tools/test_google_docs_tool.py b/tests/integration/tools/test_google_docs_tool.py index 7c8828dd3..e977e5523 100644 --- a/tests/integration/tools/test_google_docs_tool.py +++ b/tests/integration/tools/test_google_docs_tool.py @@ -9,7 +9,7 @@ class TestGoogleDocsTool: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/integration/tools/test_google_drive_tool.py b/tests/integration/tools/test_google_drive_tool.py index 7fd8b9047..fdd9fde89 100644 --- a/tests/integration/tools/test_google_drive_tool.py +++ b/tests/integration/tools/test_google_drive_tool.py @@ -9,7 +9,7 @@ class TestGoogleDriveTool: @pytest.fixture( autouse=True, params=StructureTester.TOOLKIT_TASK_CAPABLE_PROMPT_DRIVERS, - ids=StructureTester.prompt_driver_id_fn, + ids=StructureTester.generate_prompt_driver_id, ) def structure_tester(self, request): from griptape.structures import Agent diff --git a/tests/mocks/mock_event_listener_driver.py b/tests/mocks/mock_event_listener_driver.py index e56d35e90..1a17d5e69 100644 --- a/tests/mocks/mock_event_listener_driver.py +++ b/tests/mocks/mock_event_listener_driver.py @@ -9,13 +9,13 @@ @define class MockEventListenerDriver(BaseEventListenerDriver): - try_publish_event_payload_fn: Optional[Callable[[dict], None]] = field(default=None, kw_only=True) - try_publish_event_payload_batch_fn: Optional[Callable[[list[dict]], None]] = field(default=None, kw_only=True) + on_event_payload_publish: Optional[Callable[[dict], None]] = field(default=None, kw_only=True) + on_event_payload_batch_publish: Optional[Callable[[list[dict]], None]] = field(default=None, kw_only=True) def try_publish_event_payload(self, event_payload: dict) -> None: - if self.try_publish_event_payload_fn is not None: - self.try_publish_event_payload_fn(event_payload) + if self.on_event_payload_publish is not None: + self.on_event_payload_publish(event_payload) def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: - if self.try_publish_event_payload_batch_fn is not None: - self.try_publish_event_payload_batch_fn(event_payload_batch) + if self.on_event_payload_batch_publish is not None: + self.on_event_payload_batch_publish(event_payload_batch) diff --git a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py index 29aecfdf9..36d4618b8 100644 --- a/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/audio_transcription/test_base_audio_transcription_driver.py @@ -14,7 +14,7 @@ def driver(self): def test_run_publish_events(self, driver, mock_config): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(on_event=mock_handler)) driver.run( AudioArtifact( diff --git a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py index 9b7390d9c..36c8f3711 100644 --- a/tests/unit/drivers/event_listener/test_base_event_listener_driver.py +++ b/tests/unit/drivers/event_listener/test_base_event_listener_driver.py @@ -57,7 +57,7 @@ def test__safe_publish_event_payload(self): mock_fn = MagicMock() driver = MockEventListenerDriver( batched=False, - try_publish_event_payload_fn=mock_fn, + on_event_payload_publish=mock_fn, ) mock_event_payload = MockEvent().to_dict() @@ -69,7 +69,7 @@ def test__safe_publish_event_payload_batch(self): mock_fn = MagicMock() driver = MockEventListenerDriver( batched=True, - try_publish_event_payload_batch_fn=mock_fn, + on_event_payload_batch_publish=mock_fn, ) mock_event_payloads = [MockEvent().to_dict() for _ in range(0, 3)] @@ -81,7 +81,7 @@ def test__safe_publish_event_payload_error(self): mock_fn = MagicMock() driver = MockEventListenerDriver( batched=False, - try_publish_event_payload_fn=mock_fn, + on_event_payload_publish=mock_fn, max_attempts=2, max_retry_delay=0.1, min_retry_delay=0.1, @@ -98,7 +98,7 @@ def test__safe_publish_event_payload_batch_error(self): mock_fn = MagicMock() driver = MockEventListenerDriver( batched=True, - try_publish_event_payload_batch_fn=mock_fn, + on_event_payload_batch_publish=mock_fn, max_attempts=2, max_retry_delay=0.1, min_retry_delay=0.1, diff --git a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py index 96b615a58..0545f6c83 100644 --- a/tests/unit/drivers/image_generation/test_base_image_generation_driver.py +++ b/tests/unit/drivers/image_generation/test_base_image_generation_driver.py @@ -15,7 +15,7 @@ def driver(self): def test_run_text_to_image_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(on_event=mock_handler)) driver.run_text_to_image( ["foo", "bar"], @@ -31,7 +31,7 @@ def test_run_text_to_image_publish_events(self, driver): def test_run_image_variation_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(on_event=mock_handler)) driver.run_image_variation( ["foo", "bar"], @@ -53,7 +53,7 @@ def test_run_image_variation_publish_events(self, driver): def test_run_image_image_inpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(on_event=mock_handler)) driver.run_image_inpainting( ["foo", "bar"], @@ -81,7 +81,7 @@ def test_run_image_image_inpainting_publish_events(self, driver): def test_run_image_image_outpainting_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(on_event=mock_handler)) driver.run_image_outpainting( ["foo", "bar"], diff --git a/tests/unit/drivers/image_query/test_base_image_query_driver.py b/tests/unit/drivers/image_query/test_base_image_query_driver.py index a77fb268e..652ee11c5 100644 --- a/tests/unit/drivers/image_query/test_base_image_query_driver.py +++ b/tests/unit/drivers/image_query/test_base_image_query_driver.py @@ -13,7 +13,7 @@ def driver(self): def test_query_publishes_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(on_event=mock_handler)) driver.query("foo", []) diff --git a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py index ac607afc3..e3c99f402 100644 --- a/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py +++ b/tests/unit/drivers/prompt/test_hugging_face_pipeline_prompt_driver.py @@ -10,11 +10,11 @@ def mock_pipeline(self, mocker): return mocker.patch("transformers.pipeline") @pytest.fixture(autouse=True) - def mock_generator(self, mock_pipeline): - mock_generator = mock_pipeline.return_value - mock_generator.task = "text-generation" - mock_generator.return_value = [{"generated_text": [{"content": "model-output"}]}] - return mock_generator + def mock_provider(self, mock_pipeline): + mock_provider = mock_pipeline.return_value + mock_provider.task = "text-generation" + mock_provider.return_value = [{"generated_text": [{"content": "model-output"}]}] + return mock_provider @pytest.fixture(autouse=True) def mock_autotokenizer(self, mocker): @@ -70,10 +70,10 @@ def test_try_stream(self, prompt_stack): assert e.value.args[0] == "streaming is not supported" @pytest.mark.parametrize("choices", [[], [1, 2]]) - def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_generator, prompt_stack): + def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_provider, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) - mock_generator.return_value = choices + mock_provider.return_value = choices # When with pytest.raises(Exception) as e: @@ -82,10 +82,10 @@ def test_try_run_throws_when_multiple_choices_returned(self, choices, mock_gener # Then assert e.value.args[0] == "completion with more than one choice is not supported yet" - def test_try_run_throws_when_non_list(self, mock_generator, prompt_stack): + def test_try_run_throws_when_non_list(self, mock_provider, prompt_stack): # Given driver = HuggingFacePipelinePromptDriver(model="foo", max_tokens=42) - mock_generator.return_value = {} + mock_provider.return_value = {} # When with pytest.raises(Exception) as e: diff --git a/tests/unit/drivers/sql/test_snowflake_sql_driver.py b/tests/unit/drivers/sql/test_snowflake_sql_driver.py index a758bb3a2..b13efbad3 100644 --- a/tests/unit/drivers/sql/test_snowflake_sql_driver.py +++ b/tests/unit/drivers/sql/test_snowflake_sql_driver.py @@ -63,32 +63,32 @@ def driver(self, mock_snowflake_engine, mock_snowflake_connection): def get_connection(): return mock_snowflake_connection - return SnowflakeSqlDriver(connection_func=get_connection, engine=mock_snowflake_engine) + return SnowflakeSqlDriver(get_connection=get_connection, engine=mock_snowflake_engine) - def test_connection_function_wrong_return_type(self): + def test_get_connectiontion_wrong_return_type(self): def get_connection() -> Any: return object with pytest.raises(ValueError): - SnowflakeSqlDriver(connection_func=get_connection) + SnowflakeSqlDriver(get_connection=get_connection) def test_connection_validation_no_schema(self, mock_snowflake_connection_no_schema): def get_connection(): return mock_snowflake_connection_no_schema with pytest.raises(ValueError): - SnowflakeSqlDriver(connection_func=get_connection) + SnowflakeSqlDriver(get_connection=get_connection) def test_connection_validation_no_database(self, mock_snowflake_connection_no_database): def get_connection(): return mock_snowflake_connection_no_database with pytest.raises(ValueError): - SnowflakeSqlDriver(connection_func=get_connection) + SnowflakeSqlDriver(get_connection=get_connection) def test_engine_url_validation_wrong_engine(self, mock_snowflake_connection): with pytest.raises(ValueError): - SnowflakeSqlDriver(connection_func=mock_snowflake_connection, engine=create_engine("sqlite:///:memory:")) + SnowflakeSqlDriver(get_connection=mock_snowflake_connection, engine=create_engine("sqlite:///:memory:")) def test_execute_query(self, driver): assert driver.execute_query("query") == [ diff --git a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py index 2dd68e24e..4be4caf77 100644 --- a/tests/unit/drivers/structure_run/test_local_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_local_structure_run_driver.py @@ -9,7 +9,7 @@ class TestLocalStructureRunDriver: def test_run(self): pipeline = Pipeline() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: Agent()) + driver = LocalStructureRunDriver(create_structure=lambda: Agent()) task = StructureRunTask(driver=driver) @@ -22,7 +22,7 @@ def test_run_with_env(self, mock_config): mock_config.drivers_config.prompt_driver = MockPromptDriver(mock_output=lambda _: os.environ["KEY"]) agent = Agent() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent, env={"KEY": "value"}) + driver = LocalStructureRunDriver(create_structure=lambda: agent, env={"KEY": "value"}) task = StructureRunTask(driver=driver) pipeline.add_task(task) diff --git a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py index ab448c7c1..099fbc1ae 100644 --- a/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py +++ b/tests/unit/drivers/text_to_speech/test_base_audio_transcription_driver.py @@ -13,7 +13,7 @@ def driver(self): def test_text_to_audio_publish_events(self, driver): mock_handler = Mock() - EventBus.add_event_listener(EventListener(handler=mock_handler)) + EventBus.add_event_listener(EventListener(on_event=mock_handler)) driver.run_text_to_audio( ["foo", "bar"], diff --git a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py index 430f67ef9..e1d65457d 100644 --- a/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_footnote_prompt_response_rag_module.py @@ -15,7 +15,7 @@ def test_run(self, module): assert module.run(RagContext(query="test")).value == "mock output" def test_prompt(self, module): - system_message = module.default_system_template_generator( + system_message = module.default_generate_system_template( RagContext(query="test", before_query=["*RULESET*", "*META*"]), artifacts=[ TextArtifact("*TEXT SEGMENT 1*", reference=Reference(title="source 1")), diff --git a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py index cc8d35f0e..bf7f23fab 100644 --- a/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py +++ b/tests/unit/engines/rag/modules/generation/test_prompt_response_rag_module.py @@ -20,7 +20,7 @@ def test_run(self, module): assert module.run(RagContext(query="test")).value == "mock output" def test_prompt(self, module): - system_message = module.default_system_template_generator( + system_message = module.default_generate_system_template( RagContext(query="test"), artifacts=[TextArtifact("*TEXT SEGMENT 1*"), TextArtifact("*TEXT SEGMENT 2*")], ) diff --git a/tests/unit/events/test_event_bus.py b/tests/unit/events/test_event_bus.py index 165914756..7afe0b466 100644 --- a/tests/unit/events/test_event_bus.py +++ b/tests/unit/events/test_event_bus.py @@ -18,11 +18,11 @@ def test_add_event_listeners_same(self): assert len(EventBus.event_listeners) == 1 def test_add_event_listeners(self): - EventBus.add_event_listeners([EventListener(handler=lambda e: e), EventListener()]) + EventBus.add_event_listeners([EventListener(on_event=lambda e: e), EventListener()]) assert len(EventBus.event_listeners) == 2 def test_remove_event_listeners(self): - listeners = [EventListener(handler=lambda e: e), EventListener()] + listeners = [EventListener(on_event=lambda e: e), EventListener()] EventBus.add_event_listeners(listeners) EventBus.remove_event_listeners(listeners) assert len(EventBus.event_listeners) == 0 @@ -33,7 +33,7 @@ def test_add_event_listener_same(self): assert len(EventBus.event_listeners) == 1 def test_add_event_listener(self): - EventBus.add_event_listener(EventListener(handler=lambda e: e)) + EventBus.add_event_listener(EventListener(on_event=lambda e: e)) EventBus.add_event_listener(EventListener()) assert len(EventBus.event_listeners) == 2 @@ -52,7 +52,7 @@ def test_publish_event(self): # Given mock_handler = Mock() mock_handler.return_value = None - EventBus.add_event_listeners([EventListener(handler=mock_handler)]) + EventBus.add_event_listeners([EventListener(on_event=mock_handler)]) mock_event = MockEvent() # When diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index 8d0877f87..d26107bc6 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -42,7 +42,7 @@ def test_untyped_listeners(self, pipeline, mock_config): event_handler_1 = Mock() event_handler_2 = Mock() - EventBus.add_event_listeners([EventListener(handler=event_handler_1), EventListener(handler=event_handler_2)]) + EventBus.add_event_listeners([EventListener(on_event=event_handler_1), EventListener(on_event=event_handler_2)]) # can't mock subtask events, so must manually call pipeline.tasks[0].subtasks[0].before_run() diff --git a/tests/unit/mixins/test_futures_executor_mixin.py b/tests/unit/mixins/test_futures_executor_mixin.py index 3be336687..437903fe3 100644 --- a/tests/unit/mixins/test_futures_executor_mixin.py +++ b/tests/unit/mixins/test_futures_executor_mixin.py @@ -7,4 +7,4 @@ class TestFuturesExecutorMixin: def test_futures_executor(self): executor = futures.ThreadPoolExecutor() - assert MockFuturesExecutor(futures_executor_fn=lambda: executor).futures_executor == executor + assert MockFuturesExecutor(create_futures_executor=lambda: executor).futures_executor == executor diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index f86c6330a..6452ad5e4 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -18,14 +18,14 @@ def fn(task): time.sleep(2) return TextArtifact("done") - return CodeExecutionTask(run_fn=fn) + return CodeExecutionTask(on_run=fn) @pytest.fixture() def error_artifact_task(self): def fn(task): return ErrorArtifact("error") - return CodeExecutionTask(run_fn=fn) + return CodeExecutionTask(on_run=fn) def test_init(self): pipeline = Pipeline(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index d3fb17906..9cd34291f 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -17,14 +17,14 @@ def fn(task): time.sleep(2) return TextArtifact("done") - return CodeExecutionTask(run_fn=fn) + return CodeExecutionTask(on_run=fn) @pytest.fixture() def error_artifact_task(self): def fn(task): return ErrorArtifact("error") - return CodeExecutionTask(run_fn=fn) + return CodeExecutionTask(on_run=fn) def test_init(self): workflow = Workflow(rulesets=[Ruleset("TestRuleset", [Rule("test")])]) diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index 85addfe3d..4c5697621 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -14,11 +14,11 @@ class TestBaseTask: @pytest.fixture() def task(self): - EventBus.add_event_listeners([EventListener(handler=Mock())]) + EventBus.add_event_listeners([EventListener(on_event=Mock())]) agent = Agent( tools=[MockTool()], ) - EventBus.add_event_listeners([EventListener(handler=Mock())]) + EventBus.add_event_listeners([EventListener(on_event=Mock())]) agent.add_task(MockTask("foobar", max_meta_memory_entries=2)) @@ -114,7 +114,7 @@ def test_children_property_no_structure(self, task): def test_run_publish_events(self, task): task.run() - assert EventBus.event_listeners[0].handler.call_count == 2 + assert EventBus.event_listeners[0].on_event.call_count == 2 def test_add_parent(self, task): agent = Agent() diff --git a/tests/unit/tasks/test_code_execution_task.py b/tests/unit/tasks/test_code_execution_task.py index 8e69f53a3..436e8ba87 100644 --- a/tests/unit/tasks/test_code_execution_task.py +++ b/tests/unit/tasks/test_code_execution_task.py @@ -21,7 +21,7 @@ def deliberate_exception(task: CodeExecutionTask) -> BaseArtifact: class TestCodeExecutionTask: def test_hello_world_fn(self): - task = CodeExecutionTask(run_fn=hello_world) + task = CodeExecutionTask(on_run=hello_world) assert task.try_run().value == "Hello World!" @@ -29,13 +29,13 @@ def test_hello_world_fn(self): # Overriding the input because we are implementing the task not the Pipeline def test_noop_fn(self): pipeline = Pipeline() - task = CodeExecutionTask("No Op", run_fn=non_outputting) + task = CodeExecutionTask("No Op", on_run=non_outputting) pipeline.add_task(task) temp = task.try_run() assert temp.value == "No Op" def test_error_fn(self): - task = CodeExecutionTask(run_fn=deliberate_exception) + task = CodeExecutionTask(on_run=deliberate_exception) with pytest.raises(ValueError): task.try_run() diff --git a/tests/unit/tasks/test_structure_run_task.py b/tests/unit/tasks/test_structure_run_task.py index 2973c4a05..1df8ca8bf 100644 --- a/tests/unit/tasks/test_structure_run_task.py +++ b/tests/unit/tasks/test_structure_run_task.py @@ -10,7 +10,7 @@ def test_run_single_input(self, mock_config): agent = Agent() mock_config.drivers_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) + driver = LocalStructureRunDriver(create_structure=lambda: agent) task = StructureRunTask(driver=driver) @@ -23,7 +23,7 @@ def test_run_multiple_inputs(self, mock_config): agent = Agent() mock_config.drivers_config.prompt_driver = MockPromptDriver(mock_output="pipeline mock output") pipeline = Pipeline() - driver = LocalStructureRunDriver(structure_factory_fn=lambda: agent) + driver = LocalStructureRunDriver(create_structure=lambda: agent) task = StructureRunTask(input=["foo", "bar", "baz"], driver=driver) diff --git a/tests/unit/tools/test_structure_run_tool.py b/tests/unit/tools/test_structure_run_tool.py index f62cdeea7..8b581103e 100644 --- a/tests/unit/tools/test_structure_run_tool.py +++ b/tests/unit/tools/test_structure_run_tool.py @@ -10,9 +10,7 @@ class TestStructureRunTool: def client(self): agent = Agent() - return StructureRunTool( - description="foo bar", driver=LocalStructureRunDriver(structure_factory_fn=lambda: agent) - ) + return StructureRunTool(description="foo bar", driver=LocalStructureRunDriver(create_structure=lambda: agent)) def test_run_structure(self, client): assert client.run_structure({"values": {"args": "foo bar"}}).value == "mock output" diff --git a/tests/unit/tools/test_vector_store_tool.py b/tests/unit/tools/test_vector_store_tool.py index 30596f09f..a8896c757 100644 --- a/tests/unit/tools/test_vector_store_tool.py +++ b/tests/unit/tools/test_vector_store_tool.py @@ -23,12 +23,12 @@ def test_search_with_namespace(self): assert len(tool1.search({"values": {"query": "test"}})) == 2 assert len(tool2.search({"values": {"query": "test"}})) == 0 - def test_custom_process_query_output_fn(self): + def test_custom_process_query_output(self): driver = LocalVectorStoreDriver(embedding_driver=MockEmbeddingDriver()) tool1 = VectorStoreTool( description="Test", vector_store_driver=driver, - process_query_output_fn=lambda es: ListArtifact([e.vector for e in es]), + process_query_output=lambda es: ListArtifact([e.vector for e in es]), query_params={"include_vectors": True}, ) diff --git a/tests/unit/utils/test_chat.py b/tests/unit/utils/test_chat.py index 5a7b4e069..4cb43e05e 100644 --- a/tests/unit/utils/test_chat.py +++ b/tests/unit/utils/test_chat.py @@ -21,8 +21,8 @@ def test_init(self): intro_text="hello...", prompt_prefix="Question: ", response_prefix="Answer: ", - input_fn=input, - output_fn=logging.info, + handle_input=input, + handle_output=logging.info, logger_level=logging.INFO, ) assert chat.structure == agent @@ -31,8 +31,8 @@ def test_init(self): assert chat.intro_text == "hello..." assert chat.prompt_prefix == "Question: " assert chat.response_prefix == "Answer: " - assert callable(chat.input_fn) - assert callable(chat.output_fn) + assert callable(chat.handle_input) + assert callable(chat.handle_output) assert chat.logger_level == logging.INFO @patch("builtins.input", side_effect=["exit"]) @@ -57,16 +57,16 @@ def test_chat_prompt(self): @pytest.mark.parametrize("stream", [True, False]) @patch("builtins.input", side_effect=["foo", "exit"]) def test_start(self, mock_input, stream): - mock_output_fn = Mock() + mock_handle_output = Mock() agent = Agent(conversation_memory=ConversationMemory(), stream=stream) - chat = Chat(agent, intro_text="foo", output_fn=mock_output_fn) + chat = Chat(agent, intro_text="foo", handle_output=mock_handle_output) chat.start() mock_input.assert_has_calls([call(), call()]) if stream: - mock_output_fn.assert_has_calls( + mock_handle_output.assert_has_calls( [ call("foo"), call("Thinking..."), @@ -76,7 +76,7 @@ def test_start(self, mock_input, stream): ] ) else: - mock_output_fn.assert_has_calls( + mock_handle_output.assert_has_calls( [ call("foo"), call("Thinking..."), diff --git a/tests/utils/structure_tester.py b/tests/utils/structure_tester.py index c943525b6..a34871013 100644 --- a/tests/utils/structure_tester.py +++ b/tests/utils/structure_tester.py @@ -224,7 +224,7 @@ class TesterPromptDriverOption: structure: Structure = field() @classmethod - def prompt_driver_id_fn(cls, prompt_driver) -> str: + def generate_prompt_driver_id(cls, prompt_driver) -> str: return f"{prompt_driver.__class__.__name__}-{prompt_driver.model}" def verify_structure_output(self, structure) -> dict: