-
Notifications
You must be signed in to change notification settings - Fork 291
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
OPIK-615 [SDK] DSPY integration (#940)
* OPIK-615 [SDK] DSPY integration * wip * wip * wip * wip * new UI representation * new UI representation fix * detect provider+model * detect span type * fix linter * add integration tests * handle changing spans order * handle api_key for dspy tests --------- Co-authored-by: Aliaksandr Kuzmik <[email protected]>
- Loading branch information
1 parent
f3b278a
commit 7537149
Showing
8 changed files
with
786 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Workflow to run DSPy tests | ||
# | ||
# Please read inputs to provide correct values. | ||
# | ||
name: SDK Lib DSPy Tests | ||
run-name: "SDK Lib DSPy Tests ${{ github.ref_name }} by @${{ github.actor }}" | ||
env: | ||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} | ||
OPENAI_ORG_ID: ${{ secrets.OPENAI_ORG_ID }} | ||
on: | ||
workflow_call: | ||
|
||
jobs: | ||
tests: | ||
name: DSPy Python ${{matrix.python_version}} | ||
runs-on: ubuntu-latest | ||
defaults: | ||
run: | ||
working-directory: sdks/python | ||
|
||
strategy: | ||
fail-fast: true | ||
matrix: | ||
python_version: ["3.9", "3.10", "3.11", "3.12"] | ||
|
||
steps: | ||
- name: Check out code | ||
uses: actions/checkout@v4 | ||
|
||
- name: Setup Python ${{matrix.python_version}} | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{matrix.python_version}} | ||
|
||
- name: Install opik | ||
run: pip install . | ||
|
||
- name: Install test tools | ||
run: | | ||
cd ./tests | ||
pip install --no-cache-dir --disable-pip-version-check -r test_requirements.txt | ||
- name: Install lib | ||
run: | | ||
cd ./tests | ||
pip install --no-cache-dir --disable-pip-version-check -r library_integration/dspy/requirements.txt | ||
- name: Run tests | ||
run: | | ||
cd ./tests/library_integration/dspy/ | ||
python -m pytest -vv . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
from contextvars import ContextVar, Token | ||
from typing import Any, Dict, Optional, Union | ||
|
||
import dspy | ||
from dspy.utils.callback import BaseCallback | ||
|
||
from opik import opik_context | ||
from opik.api_objects import helpers, span, trace | ||
from opik.api_objects.opik_client import get_client_cached | ||
from opik.decorator import error_info_collector | ||
|
||
ContextType = Union[span.SpanData, trace.TraceData] | ||
|
||
|
||
class OpikCallback(BaseCallback): | ||
def __init__( | ||
self, | ||
project_name: Optional[str] = None, | ||
): | ||
self._map_call_id_to_span_data: Dict[str, span.SpanData] = {} | ||
self._map_call_id_to_trace_data: Dict[str, trace.TraceData] = {} | ||
self._map_span_id_or_trace_id_to_token: Dict[str, Token] = {} | ||
|
||
self._current_callback_context: ContextVar[Optional[ContextType]] = ContextVar( | ||
"opik_context", default=None | ||
) | ||
|
||
self._project_name = project_name | ||
|
||
self._opik_client = get_client_cached() | ||
|
||
def on_module_start( | ||
self, | ||
call_id: str, | ||
instance: Any, | ||
inputs: Dict[str, Any], | ||
) -> None: | ||
if current_callback_context_data := self._current_callback_context.get(): | ||
if isinstance(current_callback_context_data, span.SpanData): | ||
self._attach_span_to_existing_span( | ||
call_id=call_id, | ||
current_span_data=current_callback_context_data, | ||
instance=instance, | ||
inputs=inputs, | ||
) | ||
else: | ||
self._attach_span_to_existing_trace( | ||
call_id=call_id, | ||
current_trace_data=current_callback_context_data, | ||
instance=instance, | ||
inputs=inputs, | ||
) | ||
return | ||
|
||
if current_span_data := opik_context.get_current_span_data(): | ||
self._attach_span_to_existing_span( | ||
call_id=call_id, | ||
current_span_data=current_span_data, | ||
instance=instance, | ||
inputs=inputs, | ||
) | ||
new_span_data = self._map_call_id_to_span_data[call_id] | ||
self._callback_context_set(new_span_data) | ||
return | ||
|
||
if current_trace_data := opik_context.get_current_trace_data(): | ||
self._attach_span_to_existing_trace( | ||
call_id=call_id, | ||
current_trace_data=current_trace_data, | ||
instance=instance, | ||
inputs=inputs, | ||
) | ||
new_span_data = self._map_call_id_to_span_data[call_id] | ||
self._callback_context_set(new_span_data) | ||
return | ||
|
||
self._start_trace( | ||
call_id=call_id, | ||
instance=instance, | ||
inputs=inputs, | ||
) | ||
|
||
def _attach_span_to_existing_span( | ||
self, | ||
call_id: str, | ||
current_span_data: span.SpanData, | ||
instance: Any, | ||
inputs: Dict[str, Any], | ||
) -> None: | ||
project_name = helpers.resolve_child_span_project_name( | ||
parent_project_name=current_span_data.project_name, | ||
child_project_name=self._project_name, | ||
) | ||
span_type = self._get_span_type(instance) | ||
|
||
span_data = span.SpanData( | ||
trace_id=current_span_data.trace_id, | ||
parent_span_id=current_span_data.id, | ||
name=instance.__class__.__name__, | ||
input=inputs, | ||
type=span_type, | ||
project_name=project_name, | ||
) | ||
self._map_call_id_to_span_data[call_id] = span_data | ||
|
||
def _attach_span_to_existing_trace( | ||
self, | ||
call_id: str, | ||
current_trace_data: trace.TraceData, | ||
instance: Any, | ||
inputs: Dict[str, Any], | ||
) -> None: | ||
project_name = helpers.resolve_child_span_project_name( | ||
current_trace_data.project_name, | ||
self._project_name, | ||
) | ||
span_type = self._get_span_type(instance) | ||
|
||
span_data = span.SpanData( | ||
trace_id=current_trace_data.id, | ||
parent_span_id=None, | ||
name=instance.__class__.__name__, | ||
input=inputs, | ||
type=span_type, | ||
project_name=project_name, | ||
) | ||
self._map_call_id_to_span_data[call_id] = span_data | ||
|
||
def _start_trace( | ||
self, | ||
call_id: str, | ||
instance: Any, | ||
inputs: Dict[str, Any], | ||
) -> None: | ||
trace_data = trace.TraceData( | ||
name=instance.__class__.__name__, | ||
input=inputs, | ||
metadata={"created_from": "dspy"}, | ||
project_name=self._project_name, | ||
) | ||
self._map_call_id_to_trace_data[call_id] = trace_data | ||
self._callback_context_set(trace_data) | ||
|
||
def on_module_end( | ||
self, | ||
call_id: str, | ||
outputs: Optional[Any], | ||
exception: Optional[Exception] = None, | ||
) -> None: | ||
self._end_span( | ||
call_id=call_id, | ||
exception=exception, | ||
outputs=outputs, | ||
) | ||
self._end_trace(call_id=call_id) | ||
|
||
def _end_trace(self, call_id: str) -> None: | ||
if trace_data := self._map_call_id_to_trace_data.pop(call_id, None): | ||
trace_data.init_end_time() | ||
self._opik_client.trace(**trace_data.__dict__) | ||
|
||
# remove trace data from context | ||
if token := self._map_span_id_or_trace_id_to_token.pop(trace_data.id, None): | ||
self._current_callback_context.reset(token) | ||
|
||
def _end_span( | ||
self, | ||
call_id: str, | ||
outputs: Optional[Any], | ||
exception: Optional[Exception] = None, | ||
) -> None: | ||
if span_data := self._map_call_id_to_span_data.pop(call_id, None): | ||
if exception: | ||
error_info = error_info_collector.collect(exception) | ||
span_data.update(error_info=error_info) | ||
|
||
span_data.update(output={"output": outputs}).init_end_time() | ||
self._opik_client.span(**span_data.__dict__) | ||
|
||
# remove span data from context | ||
if token := self._map_span_id_or_trace_id_to_token.pop(span_data.id, None): | ||
self._current_callback_context.reset(token) | ||
|
||
def on_lm_start( | ||
self, | ||
call_id: str, | ||
instance: Any, | ||
inputs: Dict[str, Any], | ||
) -> None: | ||
current_callback_context_data = self._current_callback_context.get() | ||
assert current_callback_context_data is not None | ||
|
||
project_name = helpers.resolve_child_span_project_name( | ||
current_callback_context_data.project_name, | ||
self._project_name, | ||
) | ||
|
||
if isinstance(current_callback_context_data, span.SpanData): | ||
trace_id = current_callback_context_data.trace_id | ||
parent_span_id = current_callback_context_data.id | ||
else: | ||
trace_id = current_callback_context_data.id | ||
parent_span_id = None | ||
|
||
provider, model = instance.model.split(r"/", 1) | ||
span_type = self._get_span_type(instance) | ||
|
||
span_data = span.SpanData( | ||
trace_id=trace_id, | ||
name=instance.__class__.__name__, | ||
parent_span_id=parent_span_id, | ||
type=span_type, | ||
input=inputs, | ||
project_name=project_name, | ||
provider=provider, | ||
model=model, | ||
) | ||
self._map_call_id_to_span_data[call_id] = span_data | ||
|
||
def on_lm_end( | ||
self, | ||
call_id: str, | ||
outputs: Optional[Dict[str, Any]], | ||
exception: Optional[Exception] = None, | ||
) -> None: | ||
self._end_span( | ||
call_id=call_id, | ||
exception=exception, | ||
outputs=outputs, | ||
) | ||
self._end_trace(call_id=call_id) | ||
|
||
def flush(self) -> None: | ||
"""Sends pending Opik data to the backend""" | ||
self._opik_client.flush() | ||
|
||
def _callback_context_set(self, value: ContextType) -> None: | ||
token = self._current_callback_context.set(value) | ||
self._map_span_id_or_trace_id_to_token[value.id] = token | ||
|
||
def _get_span_type(self, instance: Any) -> span.SpanType: | ||
if isinstance(instance, dspy.Predict): | ||
return "llm" | ||
elif isinstance(instance, dspy.LM): | ||
return "llm" | ||
return "general" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
dspy |
Oops, something went wrong.