[CM-8803] Update OpenAI integration with OpenAI SDK version 1 #93

Merged
44 changes: 43 additions & 1 deletion src/comet_llm/autologgers/openai/chat_completion_parsers.py
@@ -13,15 +13,28 @@
# *******************************************************

import inspect
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union

import comet_llm.logging

from . import metadata

if TYPE_CHECKING:
    from openai import Stream
    from openai.openai_object import OpenAIObject
    from openai.types.chat.chat_completion import ChatCompletion

Inputs = Dict[str, Any]
Outputs = Dict[str, Any]
Metadata = Dict[str, Any]

CreateCallResult = Union[
    "ChatCompletion", "Stream", "OpenAIObject", Iterable["OpenAIObject"]
]

LOGGER = logging.getLogger(__file__)


def create_arguments_supported(kwargs: Dict[str, Any]) -> bool:
    if "messages" not in kwargs:
@@ -43,7 +56,16 @@ def parse_create_arguments(kwargs: Dict[str, Any]) -> Tuple[Inputs, Metadata]:
    return inputs, metadata


-def parse_create_result(
+def parse_create_result(result: CreateCallResult) -> Tuple[Outputs, Metadata]:
    openai_version = metadata.openai_version()

    if openai_version is not None and openai_version.startswith("0."):
        return _v0_x_x__parse_create_result(result)

    return _v1_x_x__parse_create_result(result)


def _v0_x_x__parse_create_result(
    result: Union["OpenAIObject", Iterable["OpenAIObject"]]
) -> Tuple[Outputs, Metadata]:
    if inspect.isgenerator(result):
@@ -60,3 +82,23 @@ def parse_create_result(
        metadata["output_model"] = metadata.pop("model")

    return outputs, metadata


def _v1_x_x__parse_create_result(
    result: Union["ChatCompletion", "Stream"]
) -> Tuple[Outputs, Metadata]:
    stream_mode = not hasattr(result, "model_dump")
    if stream_mode:
        choices = "Generation is not logged when using stream mode"
        metadata = {}
    else:
        result_dict = result.model_dump()
        choices: List[Dict[str, Any]] = result_dict.pop("choices")  # type: ignore
        metadata = result_dict

    outputs = {"choices": choices}

    if "model" in metadata:
        metadata["output_model"] = metadata.pop("model")

    return outputs, metadata
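
The new parse_create_result dispatches on the installed SDK version: the pre-1.0 branch keeps reading responses via to_dict(), while the v1 branch treats ChatCompletion as a pydantic model and calls model_dump(); streaming results, which lack model_dump, are logged with a placeholder message instead. A minimal sketch of the v1 branch, using a hypothetical stand-in object rather than a real API response:

from typing import Any, Dict

class FakeChatCompletion:
    """Hypothetical stand-in for openai.types.chat.ChatCompletion."""

    def model_dump(self) -> Dict[str, Any]:
        return {"choices": ["..."], "model": "gpt-4", "id": "chatcmpl-123"}

# Mirrors _v1_x_x__parse_create_result: "choices" becomes the outputs,
# everything else becomes metadata, with "model" renamed to "output_model".
result_dict = FakeChatCompletion().model_dump()
outputs = {"choices": result_dict.pop("choices")}
metadata = result_dict
if "model" in metadata:
    metadata["output_model"] = metadata.pop("model")

print(outputs)   # {'choices': ['...']}
print(metadata)  # {'id': 'chatcmpl-123', 'output_model': 'gpt-4'}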
27 changes: 27 additions & 0 deletions src/comet_llm/autologgers/openai/metadata.py
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this package.
# *******************************************************

import functools
from typing import Optional


@functools.lru_cache(maxsize=1)
def openai_version() -> Optional[str]:
    try:
        import openai

        version: str = openai.__version__
        return version
    except Exception:
        return None
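
openai_version() memoizes its result with functools.lru_cache(maxsize=1), so the import and attribute lookup run at most once per process, and any failure (openai not installed, broken import) is cached as None. A quick usage sketch, assuming comet_llm is importable:

from comet_llm.autologgers.openai.metadata import openai_version

version = openai_version()  # e.g. "1.3.5", or None when openai is absent
legacy_sdk = version is not None and version.startswith("0.")  # the check parse_create_result relies on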
16 changes: 16 additions & 0 deletions src/comet_llm/autologgers/openai/patcher.py
@@ -30,3 +30,19 @@ def patch(registry: "registry.Registry") -> None:
    registry.register_after_exception(
        "openai", "ChatCompletion.create", hooks.after_exception_chat_completion_create
    )

    registry.register_before(
        "openai.resources.chat.completions",
        "Completions.create",
        hooks.before_chat_completion_create,
    )
    registry.register_after(
        "openai.resources.chat.completions",
        "Completions.create",
        hooks.after_chat_completion_create,
    )
    registry.register_after_exception(
        "openai.resources.chat.completions",
        "Completions.create",
        hooks.after_exception_chat_completion_create,
    )
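
The original registrations against openai.ChatCompletion.create still cover openai<1.0; the new block targets openai.resources.chat.completions.Completions.create, the method behind the v1 client interface. A sketch of a call that the new hooks would intercept, assuming openai>=1.0 and an API key in the environment:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
# client.chat.completions.create ultimately invokes
# openai.resources.chat.completions.Completions.create,
# which the before/after/after_exception hooks above wrap.
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello"}],
)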
2 changes: 1 addition & 1 deletion src/comet_llm/import_hooks/patcher.py
@@ -30,7 +30,7 @@ def _get_object(module: ModuleType, callable_path: str) -> Any:
    for part in callable_path:
        try:
            current_object = getattr(current_object, part)
-        except AttributeError:
+        except Exception:
            return None

    return current_object
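
Broadening except AttributeError to except Exception is presumably motivated by the v1 SDK: accessing removed attributes such as openai.ChatCompletion.create no longer raises AttributeError but an SDK-specific error (APIRemovedInV1, an OpenAIError subclass), which the old clause would have let propagate. A sketch of the failure mode, assuming openai>=1.0:

import openai

try:
    getattr(getattr(openai, "ChatCompletion"), "create")  # raises APIRemovedInV1 in v1.x
except AttributeError:
    print("not reached: APIRemovedInV1 is not an AttributeError")
except Exception:
    print("reached: _get_object now returns None instead of crashing the patcher")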
10 changes: 5 additions & 5 deletions tests/unit/autologgers/openai/test_chat_completion_parsers.py
@@ -61,10 +61,10 @@ def test_parse_create_arguments__only_messages_presented():
    )


-def test_parse_create_result__input_is_openai_object__input_parsed_successfully():
+def test_parse_create_result__input_is_ChatCompletion__input_parsed_successfully():
    create_result = Fake("create_result")
    with Scenario() as s:
-        s.create_result.to_dict() >> {
+        s.create_result.model_dump() >> {
            "choices": "the-choices",
            "some-key": "some-value",
        }
@@ -75,10 +75,10 @@ def test_parse_create_result__input_is_openai_object__input_parsed_successfully(
    assert metadata == {"some-key": "some-value"}


-def test_parse_create_result__input_is_openai_object__input_parsed_successfully__model_key_renamed_to_output_model():
+def test_parse_create_result__input_is_ChatCompletion__input_parsed_successfully__model_key_renamed_to_output_model():
    create_result = Fake("create_result")
    with Scenario() as s:
-        s.create_result.to_dict() >> {
+        s.create_result.model_dump() >> {
            "choices": "the-choices",
            "some-key": "some-value",
            "model": "the-model",
@@ -90,7 +90,7 @@ def test_parse_create_result__input_is_openai_object__input_parsed_successfully_
    assert metadata == {"some-key": "some-value", "output_model": "the-model"}


-def test_parse_create_result__input_is_generator_object__input_parsed_with_hardcoded_values_used():
+def test_parse_create_result__input_is_Stream__input_parsed_with_hardcoded_values_used():
    create_result = (x for x in [])

    outputs, metadata = chat_completion_parsers.parse_create_result(create_result)
@@ -0,0 +1,51 @@
import box
import pytest
from testix import *

from comet_llm.autologgers.openai import chat_completion_parsers


@pytest.fixture(autouse=True)
def mock_imports(patch_module):
    patch_module(chat_completion_parsers, "metadata")

def test_parse_create_result__input_is_openai_object__input_parsed_successfully():
    create_result = Fake("create_result")
    with Scenario() as s:
        s.metadata.openai_version() >> "0.99.99"
        s.create_result.to_dict() >> {
            "choices": "the-choices",
            "some-key": "some-value",
        }

        outputs, metadata = chat_completion_parsers.parse_create_result(create_result)

    assert outputs == {"choices": "the-choices"}
    assert metadata == {"some-key": "some-value"}


def test_parse_create_result__input_is_openai_object__input_parsed_successfully__model_key_renamed_to_output_model():
    create_result = Fake("create_result")
    with Scenario() as s:
        s.metadata.openai_version() >> "0.99.99"
        s.create_result.to_dict() >> {
            "choices": "the-choices",
            "some-key": "some-value",
            "model": "the-model",
        }

        outputs, metadata = chat_completion_parsers.parse_create_result(create_result)

    assert outputs == {"choices": "the-choices"}
    assert metadata == {"some-key": "some-value", "output_model": "the-model"}


def test_parse_create_result__input_is_generator_object__input_parsed_with_hardcoded_values_used():
    create_result = (x for x in [])

    with Scenario() as s:
        s.metadata.openai_version() >> "0.99.99"
        outputs, metadata = chat_completion_parsers.parse_create_result(create_result)

    assert outputs == {"choices": "Generation is not logged when using stream mode"}
    assert metadata == {}
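
This new test module pins the parser to the legacy branch: the autouse fixture patches the metadata import inside chat_completion_parsers, and each Scenario stubs metadata.openai_version() to return "0.99.99", so parse_create_result takes the to_dict() path regardless of which SDK is installed in the test environment. The same idea expressed with plain pytest monkeypatching instead of testix, as a hypothetical sketch that is not part of the PR:

def test_v0_path_with_monkeypatch(monkeypatch):
    from comet_llm.autologgers.openai import chat_completion_parsers

    # Force the version check down the 0.x branch.
    monkeypatch.setattr(
        chat_completion_parsers.metadata, "openai_version", lambda: "0.99.99"
    )

    class FakeOpenAIObject:
        def to_dict(self):
            return {"choices": "the-choices", "some-key": "some-value"}

    outputs, metadata = chat_completion_parsers.parse_create_result(FakeOpenAIObject())

    assert outputs == {"choices": "the-choices"}
    assert metadata == {"some-key": "some-value"}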