def _completions_create(self, client, params):
    """Send a completion request, echoing streamed chat output to the terminal.

    When ``params`` enables ``stream``, contains ``messages`` and does not
    contain ``functions``, the response is consumed chunk by chunk: each
    content delta is printed (in green) as it arrives and the pieces are
    reassembled into a single ``ChatCompletion``. In every other case a
    regular, non-streaming request is sent.

    Args:
        client: Object exposing ``chat.completions.create`` and
            ``completions.create``.
        params: Keyword arguments for the ``create`` call. This dict is not
            modified.

    Returns:
        The object returned by ``create`` for non-streaming requests, or a
        ``ChatCompletion`` assembled from the streamed chunks.

    Raises:
        RuntimeError: If a streaming request yields no chunks at all, so no
            response can be assembled.
    """
    completions = client.chat.completions if "messages" in params else client.completions
    stream_chat = params.get("stream", False) and "messages" in params and "functions" not in params

    if not stream_chat:
        # Force streaming off on a copy so the caller's params stay untouched.
        request = dict(params)
        request["stream"] = False
        return completions.create(**request)

    n_choices = params.get("n", 1)
    response_contents = [""] * n_choices
    finish_reasons = [""] * n_choices
    # Incremented once per content delta, which approximates the token count.
    completion_tokens = 0
    chunk = None

    # Set the terminal text color to green
    print("\033[32m", end="")
    try:
        # Send the chat completion request and process the response in chunks
        for chunk in completions.create(**params):
            if not chunk.choices:
                continue
            for choice in chunk.choices:
                content = choice.delta.content
                # Keep a reason that was already reported; intermediate
                # chunks carry None and must not erase it.
                if choice.finish_reason is not None:
                    finish_reasons[choice.index] = choice.finish_reason
                if content is not None:
                    print(content, end="", flush=True)
                    response_contents[choice.index] += content
                    completion_tokens += 1
                else:
                    print()
    finally:
        # Reset the terminal text color even if the stream fails midway
        print("\033[0m\n")

    if chunk is None:
        raise RuntimeError("The streaming completion request returned no chunks.")

    # Prepare the final ChatCompletion object based on the accumulated data
    model = chunk.model.replace("gpt-35", "gpt-3.5")  # hack for Azure API
    prompt_tokens = count_token(params["messages"], model)
    choices = [
        Choice(
            index=index,
            finish_reason=reason,
            message=ChatCompletionMessage(role="assistant", content=text, function_call=None),
        )
        for index, (text, reason) in enumerate(zip(response_contents, finish_reasons))
    ]
    return ChatCompletion(
        id=chunk.id,
        model=chunk.model,
        created=chunk.created,
        object="chat.completion",
        choices=choices,
        usage=CompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )