Merge pull request #347 from ag2ai/o1-update
OpenAI o1 support
davorrunje authored Jan 15, 2025
2 parents 5f20079 + fe61a34 commit ca25482
Showing 7 changed files with 278 additions and 45 deletions.
15 changes: 0 additions & 15 deletions .github/workflows/docs-test.yml
@@ -35,24 +35,9 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: "3.9"
-      # - name: Install packages and dependencies for all tests
-      #   run: |
-      #     uv pip install --system pytest-cov>=5
       - name: Install packages
         run: |
           uv pip install --system -e ".[test,docs]"
           uv pip list
-      # - name: Install packages and dependencies for Documentation
-      #   run: |
-      #     uv pip install --system pydoc-markdown pyyaml termcolor nbclient
-      #     # Pin databind packages as version 4.5.0 is not compatible with pydoc-markdown.
-      #     uv pip install --system databind.core==4.4.2 databind.json==4.4.2
-      # Force reinstall specific versions to fix typing-extensions import error in CI
-      # - name: Force install specific versions of typing-extensions and pydantic
-      #   run: |
-      #     uv pip uninstall --system -y typing_extensions typing-extensions || true
-      #     uv pip install --system --force-reinstall "typing-extensions==4.7.1"
-      #     uv pip install --system --force-reinstall "pydantic<2.0"
       - name: Run documentation tests
         run: |
           bash scripts/test.sh test/website/test_process_api_reference.py test/website/test_process_notebooks.py -m "not openai"
13 changes: 13 additions & 0 deletions autogen/exception_utils.py
@@ -52,3 +52,16 @@ class UndefinedNextAgent(Exception): # noqa: N818
     def __init__(self, message: str = "The provided agents list does not overlap with agents in the group."):
         self.message = message
         super().__init__(self.message)
+
+
+class ModelToolNotSupportedError(Exception):
+    """Exception raised when attempting to use tools with models that do not support them."""
+
+    def __init__(
+        self,
+        model: str,
+    ):
+        self.message = f"Tools are not supported with {model} models. Refer to the documentation at https://platform.openai.com/docs/guides/reasoning#limitations"
+        super().__init__(self.message)
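
A quick sketch of how the new exception surfaces to callers, assuming an OpenAIWrapper-style client; the config values and the create() call below are illustrative, not part of this diff:

from autogen import OpenAIWrapper
from autogen.exception_utils import ModelToolNotSupportedError

# Placeholder key: the tools check fires before any API call is made.
client = OpenAIWrapper(config_list=[{"model": "o1", "api_key": "sk-placeholder"}])

try:
    client.create(
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        tools=[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {}}}}],
    )
except ModelToolNotSupportedError as e:
    print(e.message)  # "Tools are not supported with o1 models. ..."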
75 changes: 66 additions & 9 deletions autogen/oai/client.py
@@ -10,19 +10,20 @@
 import logging
 import sys
 import uuid
+import warnings
 from typing import Any, Callable, Optional, Protocol, Union

 from pydantic import BaseModel, schema_json_of

-from autogen.cache import Cache
-from autogen.io.base import IOStream
-from autogen.logger.logger_utils import get_current_ts
-from autogen.oai.client_utils import FormatterProtocol, logging_formatter
-from autogen.oai.openai_utils import OAI_PRICE1K, get_key, is_valid_api_key
-from autogen.runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
-from autogen.token_count_utils import count_token
+from ..cache import Cache
+from ..exception_utils import ModelToolNotSupportedError
+from ..io.base import IOStream
+from ..logger.logger_utils import get_current_ts
+from ..messages.client_messages import StreamMessage, UsageSummaryMessage
+from ..runtime_logging import log_chat_completion, log_new_client, log_new_wrapper, logging_enabled
+from ..token_count_utils import count_token
+from .client_utils import FormatterProtocol, logging_formatter
+from .openai_utils import OAI_PRICE1K, get_key, is_valid_api_key

 TOOL_ENABLED = False
 try:
@@ -302,8 +303,11 @@ def _create_or_parse(*args, **kwargs):
         completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions  # type: ignore [attr-defined]
         create_or_parse = completions.create

+        # TODO: generalize this detection once o3 models are released
+        is_o1 = "model" in params and params["model"].startswith("o1")
+
         # If streaming is enabled and has messages, then iterate over the chunks of the response.
-        if params.get("stream", False) and "messages" in params:
+        if params.get("stream", False) and "messages" in params and not is_o1:
             response_contents = [""] * params.get("n", 1)
             finish_reasons = [""] * params.get("n", 1)
             completion_tokens = 0
@@ -410,11 +414,64 @@ def _create_or_parse(*args, **kwargs):
         else:
             # If streaming is not enabled, send a regular chat completion request
             params = params.copy()
+            if is_o1:
+                # Warn that the model does not support streaming
+                if params.get("stream", False):
+                    warnings.warn(
+                        f"The {params.get('model')} model does not support streaming. The stream will be set to False."
+                    )
+                if params.get("tools", False):
+                    raise ModelToolNotSupportedError(params.get("model"))
+                self._process_reasoning_model_params(params)
+                params["stream"] = False
             response = create_or_parse(**params)
+            # Restore the original 'system' role on messages that were rewritten as user messages for o1
+            if is_o1:
+                for msg in params["messages"]:
+                    if msg["role"] == "user" and msg["content"].startswith("System message: "):
+                        msg["role"] = "system"
+                        msg["content"] = msg["content"][len("System message: ") :]

         return response

+    def _process_reasoning_model_params(self, params: dict[str, Any]) -> None:
+        """Adjust parameters for reasoning models (o1, o3, ...).
+
+        See https://platform.openai.com/docs/guides/reasoning#limitations
+        """
+        # Unsupported parameters
+        unsupported_params = [
+            "temperature",
+            "frequency_penalty",
+            "presence_penalty",
+            "top_p",
+            "logprobs",
+            "top_logprobs",
+            "logit_bias",
+        ]
+        model_name = params.get("model")
+        for param in unsupported_params:
+            if param in params:
+                warnings.warn(f"`{param}` is not supported with {model_name} model and will be ignored.")
+                params.pop(param)
+        # Replace max_tokens with max_completion_tokens as reasoning tokens are now factored in
+        # and max_tokens isn't valid
+        if "max_tokens" in params:
+            params["max_completion_tokens"] = params.pop("max_tokens")
+
+        # TODO: when o1-mini and o1-preview point to newer models (e.g. 2024-12-...), remove them from this list but leave the 2024-09-12 dated versions
+        system_not_allowed = model_name in ("o1-mini", "o1-preview", "o1-mini-2024-09-12", "o1-preview-2024-09-12")
+
+        if "messages" in params and system_not_allowed:
+            # o1-mini (2024-09-12) and o1-preview (2024-09-12) don't support role='system' messages, only 'user' and 'assistant',
+            # so replace the system messages with user messages prepended with "System message: "
+            for msg in params["messages"]:
+                if msg["role"] == "system":
+                    msg["role"] = "user"
+                    msg["content"] = f"System message: {msg['content']}"

     def cost(self, response: Union[ChatCompletion, Completion]) -> float:
         """Calculate the cost of the response."""
         model = response.model
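
To make the parameter rewriting above concrete, here is a self-contained sketch of how a typical request dict is transformed for o1-mini. This is a standalone mirror of _process_reasoning_model_params for illustration, not the library code itself:

import warnings

def process_reasoning_model_params(params: dict) -> None:
    # Mirrors the logic above: drop unsupported sampling parameters,
    # rename max_tokens, and rewrite system messages for dated o1 models.
    unsupported = ["temperature", "frequency_penalty", "presence_penalty",
                   "top_p", "logprobs", "top_logprobs", "logit_bias"]
    model = params.get("model")
    for p in unsupported:
        if p in params:
            warnings.warn(f"`{p}` is not supported with {model} model and will be ignored.")
            params.pop(p)
    if "max_tokens" in params:
        params["max_completion_tokens"] = params.pop("max_tokens")
    if model in ("o1-mini", "o1-preview", "o1-mini-2024-09-12", "o1-preview-2024-09-12"):
        for msg in params.get("messages", []):
            if msg["role"] == "system":
                msg["role"] = "user"
                msg["content"] = f"System message: {msg['content']}"

params = {
    "model": "o1-mini",
    "temperature": 0.7,  # dropped with a warning
    "max_tokens": 512,   # renamed to max_completion_tokens
    "messages": [
        {"role": "system", "content": "Be terse."},
        {"role": "user", "content": "Hi"},
    ],
}
process_reasoning_model_params(params)
assert params == {
    "model": "o1-mini",
    "max_completion_tokens": 512,
    "messages": [
        {"role": "user", "content": "System message: Be terse."},
        {"role": "user", "content": "Hi"},
    ],
}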
7 changes: 7 additions & 0 deletions autogen/oai/openai_utils.py
@@ -31,6 +31,13 @@
 DEFAULT_AZURE_API_VERSION = "2024-02-01"
 OAI_PRICE1K = {
     # https://openai.com/api/pricing/
+    # o1
+    "o1-preview-2024-09-12": (0.015, 0.060),
+    "o1-preview": (0.015, 0.060),
+    "o1-mini-2024-09-12": (0.003, 0.012),
+    "o1-mini": (0.003, 0.012),
+    "o1": (0.015, 0.060),
+    "o1-2024-12-17": (0.015, 0.060),
     # gpt-4o
     "gpt-4o": (0.005, 0.015),
     "gpt-4o-2024-05-13": (0.005, 0.015),
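
The tuples in OAI_PRICE1K are (input, output) USD prices per 1,000 tokens. A minimal sketch of how a response cost falls out of the table — the wrapper's cost() method is the real path; the helper below is illustrative and uses the values from the table above:

# Excerpt of the table above: (input, output) USD per 1K tokens.
OAI_PRICE1K = {"o1": (0.015, 0.060), "o1-mini": (0.003, 0.012)}

def estimate_cost(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    # Illustrative only; the library's cost() also normalizes model names
    # and handles models missing from the table.
    input_price, output_price = OAI_PRICE1K[model]
    return (prompt_tokens * input_price + completion_tokens * output_price) / 1000

print(estimate_cost("o1", prompt_tokens=1_000, completion_tokens=2_000))  # 0.135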
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -62,7 +62,7 @@ dependencies = [
"python-dotenv",
"tiktoken",
# Disallowing 2.6.0 can be removed when this is fixed https://github.com/pydantic/pydantic/issues/8705
"pydantic>=1.10,<3,!=2.6.0", # could be both V1 and V2
"pydantic>=2.6.1,<3",
"docker",
"packaging",
"websockets>=14,<15",
66 changes: 50 additions & 16 deletions test/conftest.py
@@ -4,6 +4,7 @@
 #
 # Portions derived from https://github.com/microsoft/autogen are under the MIT License.
 # SPDX-License-Identifier: MIT
+import os
 from pathlib import Path
 from typing import Any, Optional

@@ -65,14 +66,21 @@ def openai_api_key(self) -> str:
         return self.llm_config["config_list"][0]["api_key"]  # type: ignore[no-any-return]


-def get_credentials(filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0) -> Credentials:
+def get_credentials(
+    filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0, fail_if_empty: bool = True
+) -> Credentials:
     """Fixture to load the LLM config."""
-    config_list = autogen.config_list_from_json(
-        OAI_CONFIG_LIST,
-        filter_dict=filter_dict,
-        file_location=KEY_LOC,
-    )
-    assert config_list, "No config list found"
+    try:
+        config_list = autogen.config_list_from_json(
+            OAI_CONFIG_LIST,
+            filter_dict=filter_dict,
+            file_location=KEY_LOC,
+        )
+    except Exception:
+        config_list = []
+
+    if fail_if_empty:
+        assert config_list, "No config list found"

     return Credentials(
         llm_config={
@@ -82,12 +90,26 @@ def get_credentials(filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0) -> Credentials:
     )


-def get_openai_credentials(filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0) -> Credentials:
-    config_list = [
-        conf
-        for conf in get_credentials(filter_dict, temperature).config_list
-        if "api_type" not in conf or conf["api_type"] == "openai"
-    ]
+def get_openai_config_list_from_env(
+    model: str, filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0
+) -> list[dict[str, Any]]:
+    if "OPENAI_API_KEY" in os.environ:
+        api_key = os.environ["OPENAI_API_KEY"]
+        return [{"api_key": api_key, "model": model, **(filter_dict or {})}]
+    return []
+
+
+def get_openai_credentials(
+    model: str, filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0
+) -> Credentials:
+    config_list = get_credentials(filter_dict, temperature, fail_if_empty=False).config_list
+
+    # Filter out non-OpenAI configs
+    config_list = [conf for conf in config_list if "api_type" not in conf or conf["api_type"] == "openai"]
+
+    # If no OpenAI config found, try to get it from the environment
+    if not config_list:
+        config_list = get_openai_config_list_from_env(model, filter_dict, temperature)
+
+    assert config_list, "No OpenAI config list found"
+
+    return Credentials(
@@ -122,17 +144,29 @@ def credentials_all() -> Credentials:

 @pytest.fixture
 def credentials_gpt_4o_mini() -> Credentials:
-    return get_openai_credentials(filter_dict={"tags": ["gpt-4o-mini"]})
+    return get_openai_credentials(model="gpt-4o-mini", filter_dict={"tags": ["gpt-4o-mini"]})


 @pytest.fixture
 def credentials_gpt_4o() -> Credentials:
-    return get_openai_credentials(filter_dict={"tags": ["gpt-4o"]})
+    return get_openai_credentials(model="gpt-4o", filter_dict={"tags": ["gpt-4o"]})


+@pytest.fixture
+def credentials_o1_mini() -> Credentials:
+    return get_openai_credentials(model="o1-mini", filter_dict={"tags": ["o1-mini"]})
+
+
+@pytest.fixture
+def credentials_o1() -> Credentials:
+    return get_openai_credentials(model="o1", filter_dict={"tags": ["o1"]})
+
+
 @pytest.fixture
 def credentials_gpt_4o_realtime() -> Credentials:
-    return get_openai_credentials(filter_dict={"tags": ["gpt-4o-realtime"]}, temperature=0.6)
+    return get_openai_credentials(
+        model="gpt-4o-realtime-preview", filter_dict={"tags": ["gpt-4o-realtime"]}, temperature=0.6
+    )


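A hedged sketch of how the new o1 fixtures would be consumed in a test; the test body is illustrative and not part of this diff:

import pytest
import autogen


@pytest.mark.openai  # the same marker CI deselects with -m "not openai"
def test_o1_minimal(credentials_o1) -> None:
    # credentials_o1 is injected from test/conftest.py above.
    client = autogen.OpenAIWrapper(config_list=credentials_o1.config_list)
    response = client.create(messages=[{"role": "user", "content": "What is 2 + 2?"}])
    assert response is not None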
