Skip to content

Commit

Permalink
[CM-8549] added error handling for chains and made raise_exception Fa…
Browse files Browse the repository at this point in the history
…lse by def… (#90)

* added error handling for chains and made raise_exception False by default

* fixed linting errors

* added unit tests

* fixed linting

* fixed review and updated unit tests

* fixed lint

* update unit tests
  • Loading branch information
jynx10 authored Nov 15, 2023
1 parent 0731a02 commit 69dbd7c
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 20 deletions.
7 changes: 7 additions & 0 deletions src/comet_llm/chains/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
experiment_api,
experiment_info,
llm_result,
logging_messages,
)
from ..types import JSONEncodable
from . import chain, state
Expand Down Expand Up @@ -72,6 +73,7 @@ def start_chain(
state.set_global_chain(global_chain)


@exceptions.filter(allow_raising=config.raising_enabled(), summary=app.SUMMARY)
def end_chain(
outputs: Dict[str, JSONEncodable],
metadata: Optional[Dict[str, JSONEncodable]] = None,
Expand All @@ -89,6 +91,11 @@ def end_chain(
Returns: LLMResult
"""
global_chain = state.get_global_chain()
if global_chain is None:
raise exceptions.CometLLMException(
logging_messages.GLOBAL_CHAIN_NOT_INITIALIZED % "`end_chain`"
)

global_chain.set_outputs(outputs=outputs, metadata=metadata)
return log_chain(global_chain)

Expand Down
28 changes: 24 additions & 4 deletions src/comet_llm/chains/span.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@
# LICENSE file in the root directory of this package.
# *******************************************************

import logging
from typing import TYPE_CHECKING, Dict, List, Optional

from .. import datetimes
from comet_llm import logging as comet_logging

from .. import config, datetimes, exceptions, logging_messages
from ..types import JSONEncodable
from . import deepmerge, state

if TYPE_CHECKING:
from . import chain

LOGGER = logging.getLogger(__name__)


class Span:
"""
Expand Down Expand Up @@ -50,6 +55,7 @@ def __init__(
self._metadata = metadata if metadata is not None else {}
self._outputs: Optional[Dict[str, JSONEncodable]] = None
self._context: Optional[List[int]] = None
self._chain: Optional["chain.Chain"] = None

self._id = state.get_new_id()
self._name = name if name is not None else "unnamed"
Expand All @@ -76,21 +82,35 @@ def name(self) -> str: # pragma: no cover
def __enter__(self) -> "Span":
chain = state.get_global_chain()

if chain is None:
chain_not_initialized_exception = exceptions.CometLLMException(
logging_messages.GLOBAL_CHAIN_NOT_INITIALIZED % "`Span`"
)
if config.raising_enabled():
raise chain_not_initialized_exception

comet_logging.log_once_at_level(
LOGGER, logging.ERROR, str(chain_not_initialized_exception)
)

return self

self.__api__start__(chain)
return self

def __api__start__(self, chain: "chain.Chain") -> None:
self._connect_to_chain(chain)

self._timer.start()
self._chain.context.add(self.id)
self._chain.context.add(self.id) # type: ignore

def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore
self.__api__end__()

def __api__end__(self) -> None:
self._timer.stop()
self._chain.context.pop()
if self._chain is not None:
self._timer.stop()
self._chain.context.pop()

def set_outputs(
self,
Expand Down
8 changes: 2 additions & 6 deletions src/comet_llm/chains/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# LICENSE file in the root directory of this package.
# *******************************************************

import inspect
import threading
from typing import TYPE_CHECKING, Dict, Optional

from .. import exceptions
from .. import app, config, exceptions
from . import thread_context_registry

if TYPE_CHECKING: # pragma: no cover
Expand All @@ -34,11 +35,6 @@ def chain_exists(self) -> bool:
@property
def chain(self) -> "chain.Chain":
result: "chain.Chain" = self._thread_context_registry.get("global-chain")
if result is None:
raise exceptions.CometLLMException(
"Global chain is not initialized for this thread. Initialize it with `comet_llm.start_chain(...)`"
)

return result

@chain.setter
Expand Down
2 changes: 1 addition & 1 deletion src/comet_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _extend_comet_ml_config() -> None:
CONFIG_MAP_EXTENSION = {
"comet.disable": {"type": int, "default": 0},
"comet.logging.console": {"type": str, "default": "INFO"},
"comet.raise_exceptions_on_error": {"type": int, "default": 1},
"comet.raise_exceptions_on_error": {"type": int, "default": 0},
}

comet_ml_config.CONFIG_MAP.update(CONFIG_MAP_EXTENSION)
Expand Down
2 changes: 2 additions & 0 deletions src/comet_llm/logging_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@
)

INVALID_TIMESTAMP = "Invalid timestamp: %s. Timestamp must be in seconds if specified."

GLOBAL_CHAIN_NOT_INITIALIZED = "Global chain is not initialized for this thread. Initialize it with `comet_llm.start_chain(...)` if you wish to use %s"
1 change: 1 addition & 0 deletions tests/unit/chains/test_chains_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,4 @@ def test_end_chain__happyflow():
)

assert result == llm_result.LLMResult(id="experiment-id", project_url="project-url")

58 changes: 58 additions & 0 deletions tests/unit/chains/test_span.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from testix import saveargument

from comet_llm.chains import span
from comet_llm.exceptions import CometLLMException


@pytest.fixture(autouse=True)
def mock_imports(patch_module):
patch_module(span, "state")
patch_module(span, "datetimes")
patch_module(span, "convert")
patch_module(span, "comet_logging")
patch_module(span, "config")
patch_module(span, "LOGGER", "logger")


def _construct(
Expand Down Expand Up @@ -181,3 +185,57 @@ def test_set_output__new_metadata_is_not_None__existing_metadata_is_merged_with_
"existing-key": "existing-value",
"new-key": "new-value",
}


def test_span__no_chain_started_raising_exceptions_disabled__wont_connect_to_chain():
with Scenario() as s:
s.state.get_new_id() >> "example_id"
timer = box.Box(duration=None, start_timestamp=None, end_timestamp=None)

s.datetimes.Timer() >> timer

s.state.get_global_chain() >> None

s.config.raising_enabled() >> False

s.comet_logging.log_once_at_level(
"logger", 40, "Global chain is not initialized for this thread. Initialize it with `comet_llm.start_chain(...)` if you wish to use `Span`"
)

with span.Span(
category="llm-call",
inputs={"input": "input"},
) as tested_span:
tested_span.set_outputs({"outputs": "outputs"})

assert tested_span.as_dict() == {
"id": "example_id",
"category": "llm-call",
"name": "unnamed",
"inputs": {"input": "input"},
"outputs": {"outputs": "outputs"},
"duration": None,
"start_timestamp": None,
"end_timestamp": None,
"parent_ids": None,
"metadata": {},
}


def test_span__no_chain_started_raising_exceptions_enabled__exception_raised():
with Scenario() as s:
s.state.get_new_id() >> "example_id"
timer = Fake("timer")

s.datetimes.Timer() >> timer

s.state.get_global_chain() >> None

s.config.raising_enabled() >> True

with pytest.raises(CometLLMException):
with span.Span(
category="llm-call",
inputs={"input": "input"},
) as tested_span:
tested_span.set_outputs({"outputs": "outputs"})
9 changes: 0 additions & 9 deletions tests/unit/chains/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,6 @@ def test_new_id__happyflow():
assert tested.new_id() == 2


def test_chain_property_chain_was_not_set__exception_raised():
tested = _construct()

with Scenario() as s:
s.registry.get("global-chain") >> None
with pytest.raises(exceptions.CometLLMException):
tested.chain


def test_chain_exists__chain_was_not_set__returned_False():
tested = _construct()

Expand Down

0 comments on commit 69dbd7c

Please sign in to comment.