OpenAI o1 model support #43

Open · wants to merge 7 commits into base: main
Changes from 2 commits
32 changes: 32 additions & 0 deletions autogen/oai/client.py
@@ -268,6 +268,8 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion:
        iostream = IOStream.get_default()

        completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions  # type: ignore [attr-defined]
        params = self.map_params(params.copy())

        # If streaming is enabled and has messages, then iterate over the chunks of the response.
        if params.get("stream", False) and "messages" in params:
            response_contents = [""] * params.get("n", 1)
@@ -408,6 +410,36 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
            return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000  # type: ignore [no-any-return]
        return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000  # type: ignore [operator]

    def map_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Map deprecated parameters to their replacements and apply o1 beta limitations."""

        # max_tokens is deprecated and replaced by max_completion_tokens as of 2024.09.12
        if "max_tokens" in params:
            params["max_completion_tokens"] = params.pop("max_tokens")
            logger.warning("OpenAI API: 'max_tokens' parameter is deprecated, converting to 'max_completion_tokens'.")

        if params["model"].startswith("o1"):
            # Beta limitations: remove streaming, convert system messages to user messages,
            # and drop parameters that have fixed values.
            # https://platform.openai.com/docs/guides/reasoning/beta-limitations
            if "stream" in params:
                if params["stream"]:
                    logger.warning("OpenAI API o1 beta limitation: streaming is not supported.")
                params.pop("stream")

            for message in params["messages"]:
                if message["role"] == "system":
                    message["role"] = "user"

            fixed_params = ["temperature", "top_p", "n", "presence_penalty", "frequency_penalty"]
            for param_name in fixed_params:
                if param_name in params:
                    logger.warning(
                        f"OpenAI API o1 beta limitation: {param_name} parameter has a fixed value, removing."
                    )
                    params.pop(param_name)

        return params

    @staticmethod
    def get_usage(response: Union[ChatCompletion, Completion]) -> Dict:
        return {
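To make the mapping concrete, here is a minimal sketch of what `map_params` does to an o1 request, assuming the behavior shown in the diff above; the request values themselves are hypothetical:

```python
# Illustrative o1 request before mapping (values are hypothetical).
params = {
    "model": "o1-preview",
    "messages": [{"role": "system", "content": "You are a helpful assistant."}],
    "max_tokens": 512,
    "stream": True,
    "temperature": 0.7,
}

# After map_params(params.copy()), the request would look like this:
# - 'max_tokens' renamed to 'max_completion_tokens'
# - 'stream' removed (not supported in the o1 beta)
# - the system message converted to a user message
# - 'temperature' removed (fixed value in the o1 beta)
expected = {
    "model": "o1-preview",
    "messages": [{"role": "user", "content": "You are a helpful assistant."}],
    "max_completion_tokens": 512,
}
```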
6 changes: 6 additions & 0 deletions autogen/oai/openai_utils.py
@@ -31,6 +31,12 @@
DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {
    # https://openai.com/api/pricing/
    # o1
    "o1-preview": (0.015, 0.06),
    "o1-preview-2024-09-12": (0.015, 0.06),
    # o1-mini
    "o1-mini": (0.003, 0.012),
    "o1-mini-2024-09-12": (0.003, 0.012),
    # gpt-4o
    "gpt-4o": (0.005, 0.015),
    "gpt-4o-2024-05-13": (0.005, 0.015),
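Each entry is an (input, output) price in dollars per 1K tokens. Combined with the `cost` formula in client.py, the arithmetic works out as in this minimal sketch (the token counts are made up):

```python
# Per-1K-token (input, output) prices, copied from the entries above.
OAI_PRICE1K = {"o1-preview": (0.015, 0.06), "o1-mini": (0.003, 0.012)}

def usage_cost(model: str, n_input_tokens: int, n_output_tokens: int) -> float:
    # Mirrors cost() in autogen/oai/client.py: (in_price * n_in + out_price * n_out) / 1000
    price_in, price_out = OAI_PRICE1K[model]
    return (price_in * n_input_tokens + price_out * n_output_tokens) / 1000

print(usage_cost("o1-preview", 1200, 800))  # 0.018 + 0.048 = 0.066 dollars
```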
38 changes: 13 additions & 25 deletions autogen/token_count_utils.py
@@ -45,6 +45,10 @@ def get_max_token_limit(model: str = "gpt-3.5-turbo-0613") -> int:
"gpt-4o-2024-08-06": 128000,
"gpt-4o-mini": 128000,
"gpt-4o-mini-2024-07-18": 128000,
"o1-preview-2024-09-12": 128000,
"o1-preview": 128000,
"o1-mini-2024-09-12": 128000,
"o1-mini": 128000,
}
return max_token_limit[model]
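As a quick sanity check on the new entries, a hypothetical usage of the updated function:

```python
from autogen.token_count_utils import get_max_token_limit

# Both o1 variants advertise a 128k context window in the table above.
assert get_max_token_limit("o1-preview") == 128000
assert get_max_token_limit("o1-mini") == 128000
```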

@@ -106,34 +110,18 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0613"):
    except KeyError:
        logger.warning(f"Model {model} not found. Using cl100k_base encoding.")
        encoding = tiktoken.get_encoding("cl100k_base")
-    if model in {
-        "gpt-3.5-turbo-0613",
-        "gpt-3.5-turbo-16k-0613",
-        "gpt-4-0314",
-        "gpt-4-32k-0314",
-        "gpt-4-0613",
-        "gpt-4-32k-0613",
-    }:
-        tokens_per_message = 3
-        tokens_per_name = 1
-    elif model == "gpt-3.5-turbo-0301":
+    if "gpt-3" in model or "gpt-4" in model or model.startswith("o1"):
        tokens_per_message = 4  # every message follows <|start|>{role/name}\n{content}<|end|>\n
-        tokens_per_name = -1  # if there's a name, the role is omitted
-    elif "gpt-3.5-turbo" in model:
-        logger.info("gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
-        return _num_token_from_messages(messages, model="gpt-3.5-turbo-0613")
-    elif "gpt-4" in model:
-        logger.info("gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
-        return _num_token_from_messages(messages, model="gpt-4-0613")
+        tokens_per_name = 1  # OpenAI guidance is 1 extra token if 'name' field is used
    elif "gemini" in model:
-        logger.info("Gemini is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
-        return _num_token_from_messages(messages, model="gpt-4-0613")
+        logger.info("Gemini is not supported in tiktoken. Returning num tokens assuming gpt-4.")
+        return _num_token_from_messages(messages, model="gpt-4")
    elif "claude" in model:
-        logger.info("Claude is not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
-        return _num_token_from_messages(messages, model="gpt-4-0613")
+        logger.info("Claude is not supported in tiktoken. Returning num tokens assuming gpt-4.")
+        return _num_token_from_messages(messages, model="gpt-4")
    elif "mistral-" in model or "mixtral-" in model:
-        logger.info("Mistral.AI models are not supported in tiktoken. Returning num tokens assuming gpt-4-0613.")
-        return _num_token_from_messages(messages, model="gpt-4-0613")
+        logger.info("Mistral.AI models are not supported in tiktoken. Returning num tokens assuming gpt-4.")
+        return _num_token_from_messages(messages, model="gpt-4")
    else:
        raise NotImplementedError(
            f"""_num_token_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
@@ -158,7 +146,7 @@ def _num_token_from_messages(messages: Union[List, Dict], model="gpt-3.5-turbo-0613"):
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
-    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    num_tokens += 2  # every reply is primed with <im_start>assistant
    return num_tokens


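As a rough illustration of the updated accounting, the sketch below applies the per-message constants to an o1-style message list. It assumes the cl100k_base fallback shown above, since tiktoken may not recognize o1 model names, so real counts can differ slightly:

```python
import tiktoken

# Rough token count for an o1-family chat request, mirroring the logic above.
encoding = tiktoken.get_encoding("cl100k_base")
messages = [{"role": "user", "content": "Solve 24 * 17 step by step."}]

tokens_per_message = 4  # per-message framing overhead
tokens_per_name = 1     # extra token when a 'name' field is present

num_tokens = 0
for message in messages:
    num_tokens += tokens_per_message
    for key, value in message.items():
        num_tokens += len(encoding.encode(value))
        if key == "name":
            num_tokens += tokens_per_name
num_tokens += 2  # reply priming, per the updated constant above
print(num_tokens)
```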
2 changes: 2 additions & 0 deletions test/agentchat/contrib/test_gpt_assistant.py
@@ -34,6 +34,8 @@
    filter_dict={
        "api_type": ["openai"],
        "model": [
            "o1-preview",
            "o1-mini",
            "gpt-4o-mini",
            "gpt-4o",
            "gpt-4-turbo",
2 changes: 2 additions & 0 deletions test/agentchat/test_conversable_agent.py
@@ -37,6 +37,8 @@
{"model": "gpt-4-32k"},
{"model": "gpt-4o"},
{"model": "gpt-4o-mini"},
{"model": "o1-preview"},
{"model": "o1-mini"},
]


2 changes: 1 addition & 1 deletion test/agentchat/test_function_call.py
@@ -238,7 +238,7 @@ def test_update_function():
    config_list_gpt4 = autogen.config_list_from_json(
        OAI_CONFIG_LIST,
        filter_dict={
-            "tags": ["gpt-4", "gpt-4-32k", "gpt-4o", "gpt-4o-mini"],
+            "tags": ["gpt-4", "gpt-4-32k", "gpt-4o", "gpt-4o-mini", "o1-preview", "o1-mini"],
        },
        file_location=KEY_LOC,
    )
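For these tag filters to select the new models, the OAI_CONFIG_LIST file needs matching entries; a hypothetical pair might look like:

```python
# Hypothetical OAI_CONFIG_LIST entries that the o1 tag filters would match.
config_list = [
    {"model": "o1-preview", "api_key": "sk-...", "tags": ["o1-preview"]},
    {"model": "o1-mini", "api_key": "sk-...", "tags": ["o1-mini"]},
]
```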