diff --git a/src/comet_llm/autologgers/openai/chat_completion_parsers.py b/src/comet_llm/autologgers/openai/chat_completion_parsers.py index cf39eef702..9c233ed577 100644 --- a/src/comet_llm/autologgers/openai/chat_completion_parsers.py +++ b/src/comet_llm/autologgers/openai/chat_completion_parsers.py @@ -13,15 +13,28 @@ # ******************************************************* import inspect +import logging from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union +import comet_llm.logging + +from . import metadata + if TYPE_CHECKING: + from openai import Stream from openai.openai_object import OpenAIObject + from openai.types.chat.chat_completion import ChatCompletion Inputs = Dict[str, Any] Outputs = Dict[str, Any] Metadata = Dict[str, Any] +CreateCallResult = Union[ + "ChatCompletion", "Stream", "OpenAIObject", Iterable["OpenAIObject"] +] + +LOGGER = logging.getLogger(__file__) + def create_arguments_supported(kwargs: Dict[str, Any]) -> bool: if "messages" not in kwargs: @@ -43,7 +56,16 @@ def parse_create_arguments(kwargs: Dict[str, Any]) -> Tuple[Inputs, Metadata]: return inputs, metadata -def parse_create_result( +def parse_create_result(result: CreateCallResult) -> Tuple[Outputs, Metadata]: + openai_version = metadata.openai_version() + + if openai_version is not None and openai_version.startswith("0."): + return _v0_x_x__parse_create_result(result) + + return _v1_x_x__parse_create_result(result) + + +def _v0_x_x__parse_create_result( result: Union["OpenAIObject", Iterable["OpenAIObject"]] ) -> Tuple[Outputs, Metadata]: if inspect.isgenerator(result): @@ -60,3 +82,23 @@ def parse_create_result( metadata["output_model"] = metadata.pop("model") return outputs, metadata + + +def _v1_x_x__parse_create_result( + result: Union["ChatCompletion", "Stream"] +) -> Tuple[Outputs, Metadata]: + stream_mode = not hasattr(result, "model_dump") + if stream_mode: + choices = "Generation is not logged when using stream mode" + metadata = {} + else: + result_dict = result.model_dump() + choices: List[Dict[str, Any]] = result_dict.pop("choices") # type: ignore + metadata = result_dict + + outputs = {"choices": choices} + + if "model" in metadata: + metadata["output_model"] = metadata.pop("model") + + return outputs, metadata diff --git a/src/comet_llm/autologgers/openai/metadata.py b/src/comet_llm/autologgers/openai/metadata.py new file mode 100644 index 0000000000..e8971e413f --- /dev/null +++ b/src/comet_llm/autologgers/openai/metadata.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this package. +# ******************************************************* + +import functools +from typing import Optional + + +@functools.lru_cache(maxsize=1) +def openai_version() -> Optional[str]: + try: + import openai + + version: str = openai.__version__ + return version + except Exception: + return None diff --git a/src/comet_llm/autologgers/openai/patcher.py b/src/comet_llm/autologgers/openai/patcher.py index ade8a8ec42..f3870b3777 100644 --- a/src/comet_llm/autologgers/openai/patcher.py +++ b/src/comet_llm/autologgers/openai/patcher.py @@ -30,3 +30,19 @@ def patch(registry: "registry.Registry") -> None: registry.register_after_exception( "openai", "ChatCompletion.create", hooks.after_exception_chat_completion_create ) + + registry.register_before( + "openai.resources.chat.completions", + "Completions.create", + hooks.before_chat_completion_create, + ) + registry.register_after( + "openai.resources.chat.completions", + "Completions.create", + hooks.after_chat_completion_create, + ) + registry.register_after_exception( + "openai.resources.chat.completions", + "Completions.create", + hooks.after_exception_chat_completion_create, + ) diff --git a/src/comet_llm/import_hooks/patcher.py b/src/comet_llm/import_hooks/patcher.py index 19cdd44d1e..4b13135d02 100644 --- a/src/comet_llm/import_hooks/patcher.py +++ b/src/comet_llm/import_hooks/patcher.py @@ -30,7 +30,7 @@ def _get_object(module: ModuleType, callable_path: str) -> Any: for part in callable_path: try: current_object = getattr(current_object, part) - except AttributeError: + except Exception: return None return current_object diff --git a/tests/unit/autologgers/openai/test_chat_completion_parsers.py b/tests/unit/autologgers/openai/test_chat_completion_parsers.py index 8ca3beb90d..5fa7e4e8d9 100644 --- a/tests/unit/autologgers/openai/test_chat_completion_parsers.py +++ b/tests/unit/autologgers/openai/test_chat_completion_parsers.py @@ -61,10 +61,10 @@ def test_parse_create_arguments__only_messages_presented(): ) -def test_parse_create_result__input_is_openai_object__input_parsed_successfully(): +def test_parse_create_result__input_is_ChatCompletion__input_parsed_successfully(): create_result = Fake("create_result") with Scenario() as s: - s.create_result.to_dict() >> { + s.create_result.model_dump() >> { "choices": "the-choices", "some-key": "some-value", } @@ -75,10 +75,10 @@ def test_parse_create_result__input_is_openai_object__input_parsed_successfully( assert metadata == {"some-key": "some-value"} -def test_parse_create_result__input_is_openai_object__input_parsed_successfully__model_key_renamed_to_output_model(): +def test_parse_create_result__input_is_ChatCompletion__input_parsed_successfully__model_key_renamed_to_output_model(): create_result = Fake("create_result") with Scenario() as s: - s.create_result.to_dict() >> { + s.create_result.model_dump() >> { "choices": "the-choices", "some-key": "some-value", "model": "the-model", @@ -90,7 +90,7 @@ def test_parse_create_result__input_is_openai_object__input_parsed_successfully_ assert metadata == {"some-key": "some-value", "output_model": "the-model"} -def test_parse_create_result__input_is_generator_object__input_parsed_with_hardcoded_values_used(): +def test_parse_create_result__input_is_Stream__input_parsed_with_hardcoded_values_used(): create_result = (x for x in []) outputs, metadata = chat_completion_parsers.parse_create_result(create_result) diff --git a/tests/unit/autologgers/openai/test_chat_completion_parsers_openai_v0.py b/tests/unit/autologgers/openai/test_chat_completion_parsers_openai_v0.py new file mode 100644 index 0000000000..ab81c61ac9 --- /dev/null +++ b/tests/unit/autologgers/openai/test_chat_completion_parsers_openai_v0.py @@ -0,0 +1,51 @@ +import box +import pytest +from testix import * + +from comet_llm.autologgers.openai import chat_completion_parsers + + +@pytest.fixture(autouse=True) +def mock_imports(patch_module): + patch_module(chat_completion_parsers, "metadata") + +def test_parse_create_result__input_is_openai_object__input_parsed_successfully(): + create_result = Fake("create_result") + with Scenario() as s: + s.metadata.openai_version() >> "0.99.99" + s.create_result.to_dict() >> { + "choices": "the-choices", + "some-key": "some-value", + } + + outputs, metadata = chat_completion_parsers.parse_create_result(create_result) + + assert outputs == {"choices": "the-choices"} + assert metadata == {"some-key": "some-value"} + + +def test_parse_create_result__input_is_openai_object__input_parsed_successfully__model_key_renamed_to_output_model(): + create_result = Fake("create_result") + with Scenario() as s: + s.metadata.openai_version() >> "0.99.99" + s.create_result.to_dict() >> { + "choices": "the-choices", + "some-key": "some-value", + "model": "the-model", + } + + outputs, metadata = chat_completion_parsers.parse_create_result(create_result) + + assert outputs == {"choices": "the-choices"} + assert metadata == {"some-key": "some-value", "output_model": "the-model"} + + +def test_parse_create_result__input_is_generator_object__input_parsed_with_hardcoded_values_used(): + create_result = (x for x in []) + + with Scenario() as s: + s.metadata.openai_version() >> "0.99.99" + outputs, metadata = chat_completion_parsers.parse_create_result(create_result) + + assert outputs == {"choices": "Generation is not logged when using stream mode"} + assert metadata == {}