Skip to content

Commit

Permalink
New streaming events, EventListener listen on parent types (#1266)
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo authored Oct 21, 2024
1 parent e6ffc90 commit dab865a
Show file tree
Hide file tree
Showing 22 changed files with 334 additions and 60 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Chat.input_fn` for customizing the input to the Chat utility.
- `GriptapeCloudFileManagerDriver` for managing files on Griptape Cloud.
- `BaseFileManagerDriver.load_artifact()` & `BaseFileManagerDriver.save_artifact()` for loading & saving artifacts as files.
- Events `BaseChunkEvent`, `TextChunkEvent`, `ActionChunkEvent`.

### Changed

- **BREAKING**: Removed `BaseEventListener.publish_event` `flush` argument. Use `BaseEventListener.flush_events()` instead.
- **BREAKING**: Renamed parameter `driver` on `EventListener` to `event_listener_driver`.
- **BREAKING**: Changed default value of parameter `handler` on `EventListener` to `None`.
- **BREAKING**: Updated `EventListener.handler` return value behavior.
- **BREAKING**: Removed `CompletionChunkEvent`.
- If `EventListener.handler` returns `None`, the event will not be published to the `event_listener_driver`.
- If `EventListener.handler` is None, the event will be published to the `event_listener_driver` as-is.
- Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`.
Expand All @@ -42,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Chat.output_fn`'s now takes an optional kwarg parameter, `stream`.
- Implemented `SerializableMixin` in `Structure`, `BaseTask`, `BaseTool`, and `TaskMemory`
- `@activity` decorated functions can now accept kwargs that are defined in the activity schema.
- `EventListener.event_types` will now listen on child types of any provided type.

### Fixed

Expand Down
40 changes: 40 additions & 0 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,46 @@ This document provides instructions for migrating your codebase to accommodate b

## 0.33.X to 0.34.X

### Removed `CompletionChunkEvent`

`CompletionChunkEvent` has been removed. There is now `BaseChunkEvent` with children `TextChunkEvent` and `ActionChunkEvent`. `BaseChunkEvent` can replace `completion_chunk_event.token` by doing `str(base_chunk_event)`.

#### Before

```python
def handler_fn_stream(event: CompletionChunkEvent) -> None:
print(f"CompletionChunkEvent: {event.to_json()}")

def handler_fn_stream_text(event: CompletionChunkEvent) -> None:
# This prints out Tool actions with no easy way
# to filter them out
print(event.token, end="", flush=True)

EventListener(handler=handler_fn_stream, event_types=[CompletionChunkEvent])
EventListener(handler=handler_fn_stream_text, event_types=[CompletionChunkEvent])
```

#### After

```python
def handler_fn_stream(event: BaseChunkEvent) -> None:
print(str(e), end="", flush=True)
# print out each child event type
if isinstance(event, TextChunkEvent):
print(f"TextChunkEvent: {event.to_json()}")
if isinstance(event, ActionChunkEvent):
print(f"ActionChunkEvent: {event.to_json()}")


def handler_fn_stream_text(event: TextChunkEvent) -> None:
# This will only be text coming from the
# prompt driver, not Tool actions
print(event.token, end="", flush=True)

EventListener(handler=handler_fn_stream, event_types=[BaseChunkEvent])
EventListener(handler=handler_fn_stream_text, event_types=[TextChunkEvent])
```

### `EventListener.handler` behavior, `driver` parameter rename

Returning `None` from the `handler` function now causes the event to not be published to the `EventListenerDriver`.
Expand Down
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ test: test/unit test/integration
test/unit: ## Run unit tests.
@poetry run pytest -n auto tests/unit

.PHONY: test/unit/%
test/unit/%: ## Run specific unit tests.
@poetry run pytest -n auto tests/unit -k $*

.PHONY: test/unit/coverage
test/unit/coverage:
@poetry run pytest -n auto --cov=griptape tests/unit
Expand Down
12 changes: 9 additions & 3 deletions docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,20 @@ The `EventListener` will automatically be added and removed from the [EventBus](

## Streaming

You can use the [CompletionChunkEvent](../../reference/griptape/events/completion_chunk_event.md) to stream the completion results from Prompt Drivers.
You can use the [BaseChunkEvent](../../reference/griptape/events/base_chunk_event.md) to stream the completion results from Prompt Drivers.

```python
--8<-- "docs/griptape-framework/misc/src/events_3.py"
```

You can also use the [Stream](../../reference/griptape/utils/stream.md) utility to automatically wrap
[CompletionChunkEvent](../../reference/griptape/events/completion_chunk_event.md)s in a Python iterator.
You can also use the [TextChunkEvent](../../reference/griptape/events/text_chunk_event.md) and [ActionChunkEvent](../../reference/griptape/events/action_chunk_event.md) to further differentiate the different types of chunks for more customized output.

```python
--8<-- "docs/griptape-framework/misc/src/events_chunk_stream.py"
```

If you want Griptape to handle the chunk events for you, use the [Stream](../../reference/griptape/utils/stream.md) utility to automatically wrap
[BaseChunkEvent](../../reference/griptape/events/base_chunk_event.md)s in a Python iterator.

```python
--8<-- "docs/griptape-framework/misc/src/events_4.py"
Expand Down
10 changes: 4 additions & 6 deletions docs/griptape-framework/misc/src/events_3.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from typing import cast

from griptape.drivers import OpenAiChatPromptDriver
from griptape.events import CompletionChunkEvent, EventBus, EventListener
from griptape.events import BaseChunkEvent, EventBus, EventListener
from griptape.structures import Pipeline
from griptape.tasks import ToolkitTask
from griptape.tools import PromptSummaryTool, WebScraperTool

EventBus.add_event_listeners(
[
EventListener(
lambda e: print(cast(CompletionChunkEvent, e).token, end="", flush=True),
event_types=[CompletionChunkEvent],
)
lambda e: print(str(e), end="", flush=True),
event_types=[BaseChunkEvent],
),
]
)

Expand Down
29 changes: 29 additions & 0 deletions docs/griptape-framework/misc/src/events_chunk_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from griptape.drivers import OpenAiChatPromptDriver
from griptape.events import ActionChunkEvent, EventBus, EventListener, TextChunkEvent
from griptape.structures import Pipeline
from griptape.tasks import ToolkitTask
from griptape.tools import PromptSummaryTool, WebScraperTool

EventBus.add_event_listeners(
[
EventListener(
lambda e: print(str(e), end="", flush=True),
event_types=[TextChunkEvent],
),
EventListener(
lambda e: print(str(e), end="", flush=True),
event_types=[ActionChunkEvent],
),
]
)

pipeline = Pipeline()
pipeline.add_tasks(
ToolkitTask(
"Based on https://griptape.ai, tell me what griptape is.",
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True),
tools=[WebScraperTool(off_prompt=True), PromptSummaryTool(off_prompt=False)],
)
)

pipeline.run()
23 changes: 17 additions & 6 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,13 @@
TextMessageContent,
observable,
)
from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent
from griptape.events import (
ActionChunkEvent,
EventBus,
FinishPromptEvent,
StartPromptEvent,
TextChunkEvent,
)
from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin
from griptape.mixins.serializable_mixin import SerializableMixin

Expand Down Expand Up @@ -127,12 +133,17 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
else:
delta_contents[content.index] = [content]
if isinstance(content, TextDeltaMessageContent):
EventBus.publish_event(CompletionChunkEvent(token=content.text))
EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index))
elif isinstance(content, ActionCallDeltaMessageContent):
if content.tag is not None and content.name is not None and content.path is not None:
EventBus.publish_event(CompletionChunkEvent(token=str(content)))
elif content.partial_input is not None:
EventBus.publish_event(CompletionChunkEvent(token=content.partial_input))
EventBus.publish_event(
ActionChunkEvent(
partial_input=content.partial_input,
tag=content.tag,
name=content.name,
path=content.path,
index=content.index,
),
)

# Build a complete content from the content deltas
return self.__build_message(list(delta_contents.values()), usage)
Expand Down
8 changes: 6 additions & 2 deletions griptape/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from .finish_prompt_event import FinishPromptEvent
from .start_structure_run_event import StartStructureRunEvent
from .finish_structure_run_event import FinishStructureRunEvent
from .completion_chunk_event import CompletionChunkEvent
from .base_chunk_event import BaseChunkEvent
from .text_chunk_event import TextChunkEvent
from .action_chunk_event import ActionChunkEvent
from .event_listener import EventListener
from .start_image_generation_event import StartImageGenerationEvent
from .finish_image_generation_event import FinishImageGenerationEvent
Expand All @@ -37,7 +39,9 @@
"FinishPromptEvent",
"StartStructureRunEvent",
"FinishStructureRunEvent",
"CompletionChunkEvent",
"BaseChunkEvent",
"TextChunkEvent",
"ActionChunkEvent",
"EventListener",
"StartImageGenerationEvent",
"FinishImageGenerationEvent",
Expand Down
33 changes: 33 additions & 0 deletions griptape/events/action_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from typing import Optional

from attrs import define, field

from griptape.events.base_chunk_event import BaseChunkEvent


@define
class ActionChunkEvent(BaseChunkEvent):
partial_input: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
tag: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

def __str__(self) -> str:
parts = []

if self.name:
parts.append(self.name)
if self.path:
parts.append(f".{self.path}")
if self.tag:
parts.append(f" ({self.tag})")

if self.partial_input:
if parts:
parts.append(f"\n{self.partial_input}")
else:
parts.append(self.partial_input)

return "".join(parts)
13 changes: 13 additions & 0 deletions griptape/events/base_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from abc import abstractmethod

from attrs import define, field

from griptape.events.base_event import BaseEvent


@define
class BaseChunkEvent(BaseEvent):
index: int = field(default=0, metadata={"serializable": True})

@abstractmethod
def __str__(self) -> str: ...
8 changes: 0 additions & 8 deletions griptape/events/completion_chunk_event.py

This file was deleted.

2 changes: 1 addition & 1 deletion griptape/events/event_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002
def publish_event(self, event: T, *, flush: bool = False) -> None:
event_types = self.event_types

if event_types is None or type(event) in event_types:
if event_types is None or any(isinstance(event, event_type) for event_type in event_types):
handled_event = event
if self.handler is not None:
handled_event = self.handler(event)
Expand Down
11 changes: 11 additions & 0 deletions griptape/events/text_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from attrs import define, field

from griptape.events.base_chunk_event import BaseChunkEvent


@define
class TextChunkEvent(BaseChunkEvent):
token: str = field(kw_only=True, metadata={"serializable": True})

def __str__(self) -> str:
return self.token
28 changes: 24 additions & 4 deletions griptape/utils/stream.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from __future__ import annotations

import json
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING

from attrs import Attribute, Factory, define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent
from griptape.events import (
ActionChunkEvent,
BaseChunkEvent,
EventBus,
EventListener,
FinishPromptEvent,
FinishStructureRunEvent,
TextChunkEvent,
)

if TYPE_CHECKING:
from collections.abc import Iterator
Expand All @@ -18,7 +27,7 @@

@define
class Stream:
"""A wrapper for Structures that converts `CompletionChunkEvent`s into an iterator of TextArtifacts.
"""A wrapper for Structures that converts `BaseChunkEvent`s into an iterator of TextArtifacts.
It achieves this by running the Structure in a separate thread, listening for events from the Structure,
and yielding those events.
Expand Down Expand Up @@ -48,14 +57,25 @@ def run(self, *args) -> Iterator[TextArtifact]:
t = Thread(target=self._run_structure, args=args)
t.start()

action_str = ""
while True:
event = self._event_queue.get()
if isinstance(event, FinishStructureRunEvent):
break
elif isinstance(event, FinishPromptEvent):
yield TextArtifact(value="\n")
elif isinstance(event, CompletionChunkEvent):
elif isinstance(event, TextChunkEvent):
yield TextArtifact(value=event.token)
elif isinstance(event, ActionChunkEvent):
if event.tag is not None and event.name is not None and event.path is not None:
yield TextArtifact(f"{event.name}.{event.tag} ({event.path})")
if event.partial_input is not None:
action_str += event.partial_input
try:
yield TextArtifact(json.dumps(json.loads(action_str), indent=2))
action_str = ""
except Exception:
pass
t.join()

def _run_structure(self, *args) -> None:
Expand All @@ -64,7 +84,7 @@ def event_handler(event: BaseEvent) -> None:

stream_event_listener = EventListener(
handler=event_handler,
event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent],
event_types=[BaseChunkEvent, FinishPromptEvent, FinishStructureRunEvent],
)
EventBus.add_event_listener(stream_event_listener)

Expand Down
11 changes: 11 additions & 0 deletions tests/mocks/mock_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from attrs import define, field

from griptape.events.base_chunk_event import BaseChunkEvent


@define
class MockChunkEvent(BaseChunkEvent):
token: str = field(kw_only=True, metadata={"serializable": True})

def __str__(self) -> str:
return "mock " + self.token
Loading

0 comments on commit dab865a

Please sign in to comment.