Skip to content

Commit

Permalink
[CM-8803] update open ai integration with open ai sdk version 1 (#93)
Browse files Browse the repository at this point in the history
* Update result parser

* Add deprecation logic

* Update result parsing, add types

* Update types

* Add deprecation warning

* Fix problem when patcher didn't catch deprecation exception

* Remove deprecation warning, make some renamings in parsers, add unit tests for old openai version to a separate module

* Fix lint errors
  • Loading branch information
alexkuzmik authored Nov 9, 2023
1 parent 3877fe0 commit 3cb5dd7
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 7 deletions.
44 changes: 43 additions & 1 deletion src/comet_llm/autologgers/openai/chat_completion_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,28 @@
# *******************************************************

import inspect
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union

import comet_llm.logging

from . import metadata

if TYPE_CHECKING:
from openai import Stream
from openai.openai_object import OpenAIObject
from openai.types.chat.chat_completion import ChatCompletion

Inputs = Dict[str, Any]
Outputs = Dict[str, Any]
Metadata = Dict[str, Any]

CreateCallResult = Union[
"ChatCompletion", "Stream", "OpenAIObject", Iterable["OpenAIObject"]
]

LOGGER = logging.getLogger(__file__)


def create_arguments_supported(kwargs: Dict[str, Any]) -> bool:
if "messages" not in kwargs:
Expand All @@ -43,7 +56,16 @@ def parse_create_arguments(kwargs: Dict[str, Any]) -> Tuple[Inputs, Metadata]:
return inputs, metadata


def parse_create_result(
def parse_create_result(result: CreateCallResult) -> Tuple[Outputs, Metadata]:
openai_version = metadata.openai_version()

if openai_version is not None and openai_version.startswith("0."):
return _v0_x_x__parse_create_result(result)

return _v1_x_x__parse_create_result(result)


def _v0_x_x__parse_create_result(
result: Union["OpenAIObject", Iterable["OpenAIObject"]]
) -> Tuple[Outputs, Metadata]:
if inspect.isgenerator(result):
Expand All @@ -60,3 +82,23 @@ def parse_create_result(
metadata["output_model"] = metadata.pop("model")

return outputs, metadata


def _v1_x_x__parse_create_result(
result: Union["ChatCompletion", "Stream"]
) -> Tuple[Outputs, Metadata]:
stream_mode = not hasattr(result, "model_dump")
if stream_mode:
choices = "Generation is not logged when using stream mode"
metadata = {}
else:
result_dict = result.model_dump()
choices: List[Dict[str, Any]] = result_dict.pop("choices") # type: ignore
metadata = result_dict

outputs = {"choices": choices}

if "model" in metadata:
metadata["output_model"] = metadata.pop("model")

return outputs, metadata
27 changes: 27 additions & 0 deletions src/comet_llm/autologgers/openai/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this package.
# *******************************************************

import functools
from typing import Optional


@functools.lru_cache(maxsize=1)
def openai_version() -> Optional[str]:
try:
import openai

version: str = openai.__version__
return version
except Exception:
return None
16 changes: 16 additions & 0 deletions src/comet_llm/autologgers/openai/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,19 @@ def patch(registry: "registry.Registry") -> None:
registry.register_after_exception(
"openai", "ChatCompletion.create", hooks.after_exception_chat_completion_create
)

registry.register_before(
"openai.resources.chat.completions",
"Completions.create",
hooks.before_chat_completion_create,
)
registry.register_after(
"openai.resources.chat.completions",
"Completions.create",
hooks.after_chat_completion_create,
)
registry.register_after_exception(
"openai.resources.chat.completions",
"Completions.create",
hooks.after_exception_chat_completion_create,
)
2 changes: 1 addition & 1 deletion src/comet_llm/import_hooks/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _get_object(module: ModuleType, callable_path: str) -> Any:
for part in callable_path:
try:
current_object = getattr(current_object, part)
except AttributeError:
except Exception:
return None

return current_object
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/autologgers/openai/test_chat_completion_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def test_parse_create_arguments__only_messages_presented():
)


def test_parse_create_result__input_is_openai_object__input_parsed_successfully():
def test_parse_create_result__input_is_ChatCompletion__input_parsed_successfully():
create_result = Fake("create_result")
with Scenario() as s:
s.create_result.to_dict() >> {
s.create_result.model_dump() >> {
"choices": "the-choices",
"some-key": "some-value",
}
Expand All @@ -75,10 +75,10 @@ def test_parse_create_result__input_is_openai_object__input_parsed_successfully(
assert metadata == {"some-key": "some-value"}


def test_parse_create_result__input_is_openai_object__input_parsed_successfully__model_key_renamed_to_output_model():
def test_parse_create_result__input_is_ChatCompletion__input_parsed_successfully__model_key_renamed_to_output_model():
create_result = Fake("create_result")
with Scenario() as s:
s.create_result.to_dict() >> {
s.create_result.model_dump() >> {
"choices": "the-choices",
"some-key": "some-value",
"model": "the-model",
Expand All @@ -90,7 +90,7 @@ def test_parse_create_result__input_is_openai_object__input_parsed_successfully_
assert metadata == {"some-key": "some-value", "output_model": "the-model"}


def test_parse_create_result__input_is_generator_object__input_parsed_with_hardcoded_values_used():
def test_parse_create_result__input_is_Stream__input_parsed_with_hardcoded_values_used():
create_result = (x for x in [])

outputs, metadata = chat_completion_parsers.parse_create_result(create_result)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import box
import pytest
from testix import *

from comet_llm.autologgers.openai import chat_completion_parsers


@pytest.fixture(autouse=True)
def mock_imports(patch_module):
patch_module(chat_completion_parsers, "metadata")

def test_parse_create_result__input_is_openai_object__input_parsed_successfully():
create_result = Fake("create_result")
with Scenario() as s:
s.metadata.openai_version() >> "0.99.99"
s.create_result.to_dict() >> {
"choices": "the-choices",
"some-key": "some-value",
}

outputs, metadata = chat_completion_parsers.parse_create_result(create_result)

assert outputs == {"choices": "the-choices"}
assert metadata == {"some-key": "some-value"}


def test_parse_create_result__input_is_openai_object__input_parsed_successfully__model_key_renamed_to_output_model():
create_result = Fake("create_result")
with Scenario() as s:
s.metadata.openai_version() >> "0.99.99"
s.create_result.to_dict() >> {
"choices": "the-choices",
"some-key": "some-value",
"model": "the-model",
}

outputs, metadata = chat_completion_parsers.parse_create_result(create_result)

assert outputs == {"choices": "the-choices"}
assert metadata == {"some-key": "some-value", "output_model": "the-model"}


def test_parse_create_result__input_is_generator_object__input_parsed_with_hardcoded_values_used():
create_result = (x for x in [])

with Scenario() as s:
s.metadata.openai_version() >> "0.99.99"
outputs, metadata = chat_completion_parsers.parse_create_result(create_result)

assert outputs == {"choices": "Generation is not logged when using stream mode"}
assert metadata == {}

0 comments on commit 3cb5dd7

Please sign in to comment.