Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenAI Structured Output #1241

Merged
merged 1 commit into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `LocalRulesetDriver` for loading a `Ruleset` from a local `.json` file.
- `GriptapeCloudRulesetDriver` for loading a `Ruleset` resource from Griptape Cloud.
- Parameter `alias` on `GriptapeCloudConversationMemoryDriver` for fetching a Thread by alias.
- Basic support for OpenAi Structured Output via `OpenAiChatPromptDriver.response_format` parameter.

### Changed
- **BREAKING**: Renamed parameters on several classes to `client`:
Expand Down Expand Up @@ -61,6 +62,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- **BREAKING**: `CsvExtractionEngine.column_names` is now required.
- **BREAKING**: Renamed `RuleMixin.all_rulesets` to `RuleMixin.rulesets`.
- **BREAKING**: Renamed `GriptapeCloudKnowledgeBaseVectorStoreDriver` to `GriptapeCloudVectorStoreDriver`.
- **BREAKING**: `OpenAiChatPromptDriver.response_format` is now a `dict` instead of a `str`.
- `MarkdownifyWebScraperDriver.DEFAULT_EXCLUDE_TAGS` now includes media/blob-like HTML tags
- `StructureRunTask` now inherits from `PromptTask`.
- Several places where API clients are initialized are now lazy loaded.
Expand Down
18 changes: 18 additions & 0 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,24 @@ from griptape.drivers.griptape_cloud_vector_store_driver import GriptapeCloudVec
driver = GriptapeCloudVectorStoreDriver(...)
```

### `OpenAiChatPromptDriver.response_format` is now a `dict` instead of a `str`.

`OpenAiChatPromptDriver.response_format` is now structured in the same format that the `openai` SDK accepts.

#### Before
```python
driver = OpenAiChatPromptDriver(
response_format="json_object"
)
```

#### After
```python
driver = OpenAiChatPromptDriver(
response_format={"type": "json_object"}
)
```

## 0.31.X to 0.32.X

### Removed `DataframeLoader`
Expand Down
17 changes: 12 additions & 5 deletions docs/griptape-framework/drivers/src/prompt_drivers_3.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
import os

import schema

from griptape.drivers import OpenAiChatPromptDriver
from griptape.rules import Rule
from griptape.structures import Agent

agent = Agent(
prompt_driver=OpenAiChatPromptDriver(
api_key=os.environ["OPENAI_API_KEY"],
model="gpt-4o-2024-08-06",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we include the temperature/seed from past example (if still relevant)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah sure, I added them back

temperature=0.1,
model="gpt-4o",
response_format="json_object",
seed=42,
response_format={
"type": "json_schema",
"json_schema": {
"strict": True,
"name": "Output",
"schema": schema.Schema({"css_code": str, "relevant_emojies": [str]}).json_schema("Output Schema"),
},
},
),
input="You will be provided with a description of a mood, and your task is to generate the CSS code for a color that matches it. Description: {{ args[0] }}",
rules=[Rule(value='Write your output in json with a single key called "css_code".')],
input="You will be provided with a description of a mood, and your task is to generate the CSS color code for a color that matches it. Description: {{ args[0] }}",
)

agent.run("Blue sky at dusk.")
15 changes: 9 additions & 6 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Optional

import openai
from attrs import Factory, define, field
Expand Down Expand Up @@ -62,7 +62,7 @@ class OpenAiChatPromptDriver(BasePromptDriver):
kw_only=True,
)
user: str = field(default="", kw_only=True, metadata={"serializable": True})
response_format: Optional[Literal["json_object"]] = field(
response_format: Optional[dict] = field(
default=None,
kw_only=True,
metadata={"serializable": True},
Expand Down Expand Up @@ -145,10 +145,13 @@ def _base_params(self, prompt_stack: PromptStack) -> dict:
**({"stream_options": {"include_usage": True}} if self.stream else {}),
}

if self.response_format == "json_object":
params["response_format"] = {"type": "json_object"}
# JSON mode still requires a system message instructing the LLM to output JSON.
prompt_stack.add_system_message("Provide your response as a valid JSON object.")
if self.response_format is not None:
if self.response_format == {"type": "json_object"}:
params["response_format"] = self.response_format
# JSON mode still requires a system message instructing the LLM to output JSON.
prompt_stack.add_system_message("Provide your response as a valid JSON object.")
else:
params["response_format"] = self.response_format

messages = self.__to_openai_messages(prompt_stack.messages)

Expand Down
52 changes: 50 additions & 2 deletions tests/unit/drivers/prompt/test_openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest.mock import Mock

import pytest
import schema

from griptape.artifacts import ActionArtifact, ImageArtifact, ListArtifact, TextArtifact
from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction
Expand Down Expand Up @@ -343,10 +344,12 @@ def test_try_run(self, mock_chat_completion_create, prompt_stack, messages, use_
assert message.value[1].value.path == "test"
assert message.value[1].value.input == {"foo": "bar"}

def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack, messages):
def test_try_run_response_format_json_object(self, mock_chat_completion_create, prompt_stack, messages):
# Given
driver = OpenAiChatPromptDriver(
model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL, response_format="json_object", use_native_tools=False
model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
response_format={"type": "json_object"},
use_native_tools=False,
)

# When
Expand All @@ -365,6 +368,51 @@ def test_try_run_response_format(self, mock_chat_completion_create, prompt_stack
assert message.usage.input_tokens == 5
assert message.usage.output_tokens == 10

def test_try_run_response_format_json_schema(self, mock_chat_completion_create, prompt_stack, messages):
    """Verify that a `json_schema` response_format dict is forwarded verbatim to the OpenAI client."""
    # Given: a driver configured with a strict json_schema response_format built
    # from a `schema.Schema`, rendered via `.json_schema(...)` into draft-07 JSON Schema.
    driver = OpenAiChatPromptDriver(
    model=OpenAiTokenizer.DEFAULT_OPENAI_GPT_3_CHAT_MODEL,
    response_format={
    "type": "json_schema",
    "json_schema": {
    "strict": True,
    "name": "OutputSchema",
    "schema": schema.Schema({"test": str}).json_schema("Output Schema"),
    },
    },
    use_native_tools=False,
    )

    # When: the driver is run against the prompt stack fixture.
    message = driver.try_run(prompt_stack)

    # Then: the openai client receives the response_format unchanged — unlike the
    # json_object mode, no extra system message is injected for json_schema.
    # The expanded "schema" dict below is what `schema.Schema.json_schema("Output Schema")`
    # produces (draft-07, additionalProperties disallowed, "test" required).
    mock_chat_completion_create.assert_called_once_with(
    model=driver.model,
    temperature=driver.temperature,
    user=driver.user,
    messages=[*messages],
    seed=driver.seed,
    response_format={
    "json_schema": {
    "schema": {
    "$id": "Output Schema",
    "$schema": "http://json-schema.org/draft-07/schema#",
    "additionalProperties": False,
    "properties": {"test": {"type": "string"}},
    "required": ["test"],
    "type": "object",
    },
    "name": "OutputSchema",
    "strict": True,
    },
    "type": "json_schema",
    },
    )
    # Mocked completion returns fixed output text and token usage.
    assert message.value[0].value == "model-output"
    assert message.usage.input_tokens == 5
    assert message.usage.output_tokens == 10

@pytest.mark.parametrize("use_native_tools", [True, False])
def test_try_stream_run(self, mock_chat_completion_stream_create, prompt_stack, messages, use_native_tools):
# Given
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/structure_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def verify_structure_output(self, structure) -> dict:
model="gpt-4o",
azure_deployment=os.environ["AZURE_OPENAI_4_DEPLOYMENT_ID"],
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT_1"],
response_format="json_object",
response_format={"type": "json_object"},
)
output_schema = Schema(
{
Expand Down
Loading