Implement batching for feedback scores, update and add new tests
alexkuzmik committed Jan 9, 2025
1 parent f4dadf7 commit f553816
Showing 8 changed files with 282 additions and 12 deletions.
13 changes: 7 additions & 6 deletions sdks/python/src/opik/message_processing/batching/base_batcher.py
@@ -21,12 +21,6 @@ def __init__(
self._last_time_flush_callback_called: float = time.time()
self._lock = threading.RLock()

def add(self, message: messages.BaseMessage) -> None:
with self._lock:
self._accumulated_messages.append(message)
if len(self._accumulated_messages) == self._max_batch_size:
self.flush()

def flush(self) -> None:
with self._lock:
if len(self._accumulated_messages) > 0:
@@ -47,3 +41,10 @@ def is_empty(self) -> bool:

@abc.abstractmethod
def _create_batch_from_accumulated_messages(self) -> messages.BaseMessage: ...

@abc.abstractmethod
def add(self, message: messages.BaseMessage) -> None:
with self._lock:
self._accumulated_messages.append(message)
if len(self._accumulated_messages) == self._max_batch_size:
self.flush()
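
For reference, the moved `add` keeps its body even though it is now decorated with `@abc.abstractmethod`: Python permits an abstract method to carry a default implementation that overriding subclasses can reuse via `super()`. A minimal standalone sketch of that pattern (illustrative names, not the SDK's classes):

```python
import abc


class Base(abc.ABC):
    @abc.abstractmethod
    def add(self, item: int) -> None:
        # Default behavior, reachable through super() from subclasses.
        print("default add:", item)


class Child(Base):
    def add(self, item: int) -> None:
        super().add(item)  # reuse the base body


Child().add(1)   # prints "default add: 1"
# Base() itself would raise TypeError: abstract methods must be overridden.
```

This is what the `return super().add(message)` overrides in `batchers.py` below rely on, while the feedback-score batcher replaces the body entirely.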
@@ -23,6 +23,9 @@ def stop(self) -> None:
self._flushing_thread.close()

def message_supports_batching(self, message: messages.BaseMessage) -> bool:
if hasattr(message, "supports_batching"):
return message.supports_batching

return message.__class__ in self._message_to_batcher_mapping

def process_message(self, message: messages.BaseMessage) -> None:
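
The attribute check above is what keeps flushed batches from being batched again: the feedback-score batchers further down emit their combined messages with `supports_batching=False`, so on the next pass through the queue those messages are sent directly instead of re-entering a batcher. A self-contained sketch of that round trip (hypothetical `ScoreBatch` class; the real routing lives in the batch manager):

```python
import dataclasses
from typing import Any, List


@dataclasses.dataclass
class ScoreBatch:
    batch: List[Any]
    supports_batching: bool = True


BATCHABLE_TYPES = {ScoreBatch}


def message_supports_batching(message: Any) -> bool:
    # An explicit supports_batching attribute wins over the class-based mapping.
    if hasattr(message, "supports_batching"):
        return message.supports_batching
    return message.__class__ in BATCHABLE_TYPES


incoming = ScoreBatch(batch=[1, 2])  # freshly produced message: gets batched
assert message_supports_batching(incoming)

flushed = ScoreBatch(batch=[1, 2], supports_batching=False)  # emitted by a batcher
assert not message_supports_batching(flushed)  # sent directly, no re-batching loop
```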
@@ -10,6 +10,9 @@
CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS = 1.0
CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE = 1000

FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS = 1.0
FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_MAX_BATCH_SIZE = 1000


def create_batch_manager(message_queue: queue.Queue) -> batch_manager.BatchManager:
create_span_message_batcher_ = batchers.CreateSpanMessageBatcher(
@@ -24,11 +27,25 @@ def create_batch_manager(message_queue: queue.Queue) -> batch_manager.BatchManager:
flush_callback=message_queue.put,
)

add_span_feedback_scores_batch_message_batcher = batchers.AddSpanFeedbackScoresBatchMessageBatcher(
flush_interval_seconds=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
max_batch_size=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_MAX_BATCH_SIZE,
flush_callback=message_queue.put,
)

add_trace_feedback_scores_batch_message_batcher = batchers.AddTraceFeedbackScoresBatchMessageBatcher(
flush_interval_seconds=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
max_batch_size=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_MAX_BATCH_SIZE,
flush_callback=message_queue.put,
)

message_to_batcher_mapping: Dict[
Type[messages.BaseMessage], base_batcher.BaseBatcher
] = {
messages.CreateSpanMessage: create_span_message_batcher_,
messages.CreateTraceMessage: create_trace_message_batcher_,
messages.AddSpanFeedbackScoresBatchMessage: add_span_feedback_scores_batch_message_batcher,
messages.AddTraceFeedbackScoresBatchMessage: add_trace_feedback_scores_batch_message_batcher,
}

batch_manager_ = batch_manager.BatchManager(
70 changes: 70 additions & 0 deletions sdks/python/src/opik/message_processing/batching/batchers.py
@@ -1,3 +1,5 @@
from typing import Union

from . import base_batcher
from .. import messages

@@ -8,9 +10,77 @@ def _create_batch_from_accumulated_messages(
) -> messages.CreateSpansBatchMessage:
return messages.CreateSpansBatchMessage(batch=self._accumulated_messages) # type: ignore

def add(self, message: messages.CreateSpansBatchMessage) -> None: # type: ignore
return super().add(message)


class CreateTraceMessageBatcher(base_batcher.BaseBatcher):
def _create_batch_from_accumulated_messages(
self,
) -> messages.CreateTraceBatchMessage:
return messages.CreateTraceBatchMessage(batch=self._accumulated_messages) # type: ignore

def add(self, message: messages.CreateTraceBatchMessage) -> None: # type: ignore
return super().add(message)


class BaseAddFeedbackScoresBatchMessageBatcher(base_batcher.BaseBatcher):
def _create_batch_from_accumulated_messages( # type: ignore
self,
) -> Union[
messages.AddSpanFeedbackScoresBatchMessage,
messages.AddTraceFeedbackScoresBatchMessage,
]:
return super()._create_batch_from_accumulated_messages() # type: ignore

def add( # type: ignore
self,
message: Union[
messages.AddSpanFeedbackScoresBatchMessage,
messages.AddTraceFeedbackScoresBatchMessage,
],
) -> None:
with self._lock:
new_messages = message.batch
n_new_messages = len(new_messages)
n_accumulated_messages = len(self._accumulated_messages)

if n_new_messages + n_accumulated_messages >= self._max_batch_size:
free_space_in_accumulator = (
self._max_batch_size - n_accumulated_messages
)

messages_that_fit_in_batch = new_messages[:free_space_in_accumulator]
messages_that_dont_fit_in_batch = new_messages[
free_space_in_accumulator:
]

self._accumulated_messages += messages_that_fit_in_batch
new_messages = messages_that_dont_fit_in_batch
self.flush()

self._accumulated_messages += new_messages


class AddSpanFeedbackScoresBatchMessageBatcher(
BaseAddFeedbackScoresBatchMessageBatcher
):
def _create_batch_from_accumulated_messages(
self,
) -> messages.AddSpanFeedbackScoresBatchMessage: # type: ignore
return messages.AddSpanFeedbackScoresBatchMessage(
batch=self._accumulated_messages, # type: ignore
supports_batching=False,
)


class AddTraceFeedbackScoresBatchMessageBatcher(
BaseAddFeedbackScoresBatchMessageBatcher
):
def _create_batch_from_accumulated_messages(
self,
) -> messages.AddTraceFeedbackScoresBatchMessage: # type: ignore
return messages.AddTraceFeedbackScoresBatchMessage(
batch=self._accumulated_messages, # type: ignore
supports_batching=False,
)
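
The interesting part of the new file is the overflow handling in `BaseAddFeedbackScoresBatchMessageBatcher.add`: unlike `CreateSpanMessageBatcher` and `CreateTraceMessageBatcher`, which accumulate one message per call, each feedback-scores message already carries a list of scores, so a single `add` can overshoot `max_batch_size`. The same split-and-flush logic condensed onto plain lists (hypothetical helper, not the SDK API):

```python
from typing import Any, Callable, List


def add_with_overflow(
    accumulated: List[Any],
    new: List[Any],
    max_batch_size: int,
    flush: Callable[[List[Any]], None],
) -> List[Any]:
    """Return the new accumulator after topping up and possibly flushing."""
    if len(accumulated) + len(new) >= max_batch_size:
        free = max_batch_size - len(accumulated)
        flush(accumulated + new[:free])  # emit exactly max_batch_size items
        return new[free:]                # overflow waits for the next flush
    return accumulated + new


acc = add_with_overflow([1, 2], [3, 4, 5, 6], 5, print)  # prints [1, 2, 3, 4, 5]
assert acc == [6]
```

Note that, as in the method above, a remainder larger than `max_batch_size` is simply carried over rather than split again; there is at most one flush per `add` call.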
4 changes: 2 additions & 2 deletions sdks/python/src/opik/message_processing/message_processors.py
@@ -137,7 +137,7 @@ def _process_add_span_feedback_scores_batch_message(
for score_message in message.batch
]

LOGGER.debug("Batch of spans feedbacks scores request: %s", scores)
LOGGER.debug("Add spans feedbacks scores request of size: %d", len(scores))

self._rest_client.spans.score_batch_of_spans(
scores=scores,
@@ -152,7 +152,7 @@ def _process_add_trace_feedback_scores_batch_message(
for score_message in message.batch
]

LOGGER.debug("Batch of traces feedbacks scores request: %s", scores)
LOGGER.debug("Add traces feedbacks scores request: %d", len(scores))

self._rest_client.traces.score_batch_of_traces(
scores=scores,
17 changes: 14 additions & 3 deletions sdks/python/src/opik/message_processing/messages.py
@@ -108,13 +108,24 @@ class FeedbackScoreMessage(BaseMessage):


@dataclasses.dataclass
class AddTraceFeedbackScoresBatchMessage(BaseMessage):
class AddFeedbackScoresBatchMessage(BaseMessage):
batch: List[FeedbackScoreMessage]
supports_batching: bool = True

def as_payload_dict(self) -> Dict[str, Any]:
data = super().as_payload_dict()
data.pop("supports_batching")
return data


@dataclasses.dataclass
class AddSpanFeedbackScoresBatchMessage(BaseMessage):
batch: List[FeedbackScoreMessage]
class AddTraceFeedbackScoresBatchMessage(AddFeedbackScoresBatchMessage):
pass


@dataclasses.dataclass
class AddSpanFeedbackScoresBatchMessage(AddFeedbackScoresBatchMessage):
pass


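
Hoisting the flag and the `as_payload_dict` override into the shared `AddFeedbackScoresBatchMessage` base keeps both concrete messages trivial while guaranteeing the routing flag never reaches the REST payload. A quick self-contained check of that behavior (assuming, since the base class isn't shown in this diff, that `BaseMessage.as_payload_dict` serializes the dataclass fields):

```python
import dataclasses
from typing import Any, Dict, List


@dataclasses.dataclass
class BaseMessage:
    # Assumed behavior of the real BaseMessage: dump dataclass fields to a dict.
    def as_payload_dict(self) -> Dict[str, Any]:
        return {f.name: getattr(self, f.name) for f in dataclasses.fields(self)}


@dataclasses.dataclass
class AddFeedbackScoresBatchMessage(BaseMessage):
    batch: List[Any]
    supports_batching: bool = True

    def as_payload_dict(self) -> Dict[str, Any]:
        data = super().as_payload_dict()
        data.pop("supports_batching")  # internal routing flag, not payload data
        return data


print(AddFeedbackScoresBatchMessage(batch=[]).as_payload_dict())  # {'batch': []}
```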
15 changes: 14 additions & 1 deletion sdks/python/tests/e2e/test_feedback_scores.py
@@ -168,7 +168,13 @@ def f_outer():
"value": 0.75,
"category_name": "trace-score-category",
"reason": "trace-score-reason",
}
},
{
"name": "trace-feedback-score-2",
"value": 0.5,
"category_name": "trace-score-category-2",
"reason": "trace-score-reason-2",
},
]
)

@@ -193,6 +199,13 @@ def f_outer():
"category_name": "trace-score-category",
"reason": "trace-score-reason",
},
{
"id": ID_STORAGE["f_outer-trace-id"],
"name": "trace-feedback-score-2",
"value": 0.5,
"category_name": "trace-score-category-2",
"reason": "trace-score-reason-2",
},
]

verifiers.verify_trace(
@@ -0,0 +1,155 @@
import mock
import time

import pytest

from opik.message_processing.batching import batchers
from opik.message_processing import messages

NOT_USED = None


@pytest.mark.parametrize(
"message_batcher_class, batch_message_class",
[
(
batchers.AddSpanFeedbackScoresBatchMessageBatcher,
messages.AddSpanFeedbackScoresBatchMessage,
),
(
batchers.AddTraceFeedbackScoresBatchMessageBatcher,
messages.AddTraceFeedbackScoresBatchMessage,
),
],
)
def test_add_feedback_scores_batch_message_batcher__exactly_max_batch_size_reached__batch_is_flushed(
message_batcher_class,
batch_message_class,
):
flush_callback = mock.Mock()

MAX_BATCH_SIZE = 5

batcher = message_batcher_class(
max_batch_size=MAX_BATCH_SIZE,
flush_callback=flush_callback,
flush_interval_seconds=NOT_USED,
)

assert batcher.is_empty()
add_feedback_score_batch_messages = [
messages.AddSpanFeedbackScoresBatchMessage(batch=[1, 2]),
messages.AddSpanFeedbackScoresBatchMessage(batch=[3, 4, 5]),
] # batcher doesn't care about the content

for feedback_scores_batch in add_feedback_score_batch_messages:
batcher.add(feedback_scores_batch)
assert batcher.is_empty()

flush_callback.assert_called_once_with(
batch_message_class(batch=[1, 2, 3, 4, 5], supports_batching=False)
)


@pytest.mark.parametrize(
"message_batcher_class,batch_message_class",
[
(
batchers.AddSpanFeedbackScoresBatchMessageBatcher,
messages.AddSpanFeedbackScoresBatchMessage,
),
(
batchers.AddTraceFeedbackScoresBatchMessageBatcher,
messages.AddTraceFeedbackScoresBatchMessage,
),
],
)
def test_add_feedback_scores_batch_message_batcher__more_than_max_batch_size_items_added__one_batch_flushed__some_data_remains_in_batcher(
message_batcher_class,
batch_message_class,
):
flush_callback = mock.Mock()

MAX_BATCH_SIZE = 5

batcher = message_batcher_class(
max_batch_size=MAX_BATCH_SIZE,
flush_callback=flush_callback,
flush_interval_seconds=NOT_USED,
)

assert batcher.is_empty()
add_feedback_score_batch_messages = [
messages.AddSpanFeedbackScoresBatchMessage(batch=[1, 2]),
messages.AddSpanFeedbackScoresBatchMessage(batch=[3, 4, 5, 6]),
messages.AddSpanFeedbackScoresBatchMessage(batch=[7, 8]),
] # batcher doesn't care about the content

for feedback_scores_batch in add_feedback_score_batch_messages:
batcher.add(feedback_scores_batch)

assert not batcher.is_empty()
flush_callback.assert_called_once_with(
batch_message_class(batch=[1, 2, 3, 4, 5], supports_batching=False)
)
flush_callback.reset_mock()

batcher.flush()
flush_callback.assert_called_once_with(
batch_message_class(batch=[6, 7, 8], supports_batching=False)
)
assert batcher.is_empty()


@pytest.mark.parametrize(
"message_batcher_class",
[
batchers.AddSpanFeedbackScoresBatchMessageBatcher,
batchers.AddTraceFeedbackScoresBatchMessageBatcher,
],
)
def test_add_feedback_scores_batch_message_batcher__batcher_doesnt_have_items__flush_is_called__flush_callback_NOT_called(
message_batcher_class,
):
flush_callback = mock.Mock()

MAX_BATCH_SIZE = 5

batcher = message_batcher_class(
max_batch_size=MAX_BATCH_SIZE,
flush_callback=flush_callback,
flush_interval_seconds=NOT_USED,
)

assert batcher.is_empty()
batcher.flush()
flush_callback.assert_not_called()


@pytest.mark.parametrize(
"message_batcher_class",
[
batchers.AddSpanFeedbackScoresBatchMessageBatcher,
batchers.AddTraceFeedbackScoresBatchMessageBatcher,
],
)
def test_add_feedback_scores_batch_message_batcher__ready_to_flush_returns_True__is_flush_interval_passed(
message_batcher_class,
):
flush_callback = mock.Mock()

MAX_BATCH_SIZE = 5
FLUSH_INTERVAL = 0.1

batcher = message_batcher_class(
max_batch_size=MAX_BATCH_SIZE,
flush_callback=flush_callback,
flush_interval_seconds=FLUSH_INTERVAL,
)
assert not batcher.is_ready_to_flush()
time.sleep(0.1)
assert batcher.is_ready_to_flush()
batcher.flush()
assert not batcher.is_ready_to_flush()
time.sleep(0.1)
assert batcher.is_ready_to_flush()
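
The final test exercises the time-based trigger: `is_ready_to_flush()` evidently compares the elapsed time since the last flush against `flush_interval_seconds` (the base batcher's `__init__` above records `self._last_time_flush_callback_called = time.time()`). A minimal reconstruction of that check, assumed from the observed behavior rather than copied from the SDK:

```python
import time


class IntervalFlushState:
    """Tracks whether flush_interval_seconds has elapsed since the last flush."""

    def __init__(self, flush_interval_seconds: float) -> None:
        self._flush_interval_seconds = flush_interval_seconds
        self._last_flush_time = time.time()

    def is_ready_to_flush(self) -> bool:
        return time.time() - self._last_flush_time >= self._flush_interval_seconds

    def flush(self) -> None:
        self._last_flush_time = time.time()  # flushing resets the timer


state = IntervalFlushState(flush_interval_seconds=0.1)
assert not state.is_ready_to_flush()
time.sleep(0.1)
assert state.is_ready_to_flush()
state.flush()
assert not state.is_ready_to_flush()
```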
