From f553816242b750283af6854421c329e0bec7cc9d Mon Sep 17 00:00:00 2001
From: Aliaksandr Kuzmik
Date: Thu, 9 Jan 2025 12:14:15 +0100
Subject: [PATCH] Implement batching for feedback scores, update and add new
 tests

---
 .../batching/base_batcher.py                  |  13 +-
 .../batching/batch_manager.py                 |   3 +
 .../batching/batch_manager_constuctors.py     |  17 ++
 .../message_processing/batching/batchers.py   |  70 ++++++++
 .../message_processing/message_processors.py  |   4 +-
 .../src/opik/message_processing/messages.py   |  17 +-
 sdks/python/tests/e2e/test_feedback_scores.py |  15 +-
 .../batching/test_feedback_scores_batcher.py  | 155 ++++++++++++++++++
 8 files changed, 282 insertions(+), 12 deletions(-)
 create mode 100644 sdks/python/tests/unit/message_processing/batching/test_feedback_scores_batcher.py

diff --git a/sdks/python/src/opik/message_processing/batching/base_batcher.py b/sdks/python/src/opik/message_processing/batching/base_batcher.py
index 77926e50ec..6d6bde2b90 100644
--- a/sdks/python/src/opik/message_processing/batching/base_batcher.py
+++ b/sdks/python/src/opik/message_processing/batching/base_batcher.py
@@ -21,12 +21,6 @@ def __init__(
         self._last_time_flush_callback_called: float = time.time()
         self._lock = threading.RLock()
 
-    def add(self, message: messages.BaseMessage) -> None:
-        with self._lock:
-            self._accumulated_messages.append(message)
-            if len(self._accumulated_messages) == self._max_batch_size:
-                self.flush()
-
     def flush(self) -> None:
         with self._lock:
             if len(self._accumulated_messages) > 0:
@@ -47,3 +41,10 @@ def is_empty(self) -> bool:
 
     @abc.abstractmethod
     def _create_batch_from_accumulated_messages(self) -> messages.BaseMessage: ...
+
+    @abc.abstractmethod
+    def add(self, message: messages.BaseMessage) -> None:
+        with self._lock:
+            self._accumulated_messages.append(message)
+            if len(self._accumulated_messages) == self._max_batch_size:
+                self.flush()
diff --git a/sdks/python/src/opik/message_processing/batching/batch_manager.py b/sdks/python/src/opik/message_processing/batching/batch_manager.py
index cb5f489a14..81e46dc325 100644
--- a/sdks/python/src/opik/message_processing/batching/batch_manager.py
+++ b/sdks/python/src/opik/message_processing/batching/batch_manager.py
@@ -23,6 +23,9 @@ def stop(self) -> None:
         self._flushing_thread.close()
 
     def message_supports_batching(self, message: messages.BaseMessage) -> bool:
+        if hasattr(message, "supports_batching"):
+            return message.supports_batching
+
         return message.__class__ in self._message_to_batcher_mapping
 
     def process_message(self, message: messages.BaseMessage) -> None:
diff --git a/sdks/python/src/opik/message_processing/batching/batch_manager_constuctors.py b/sdks/python/src/opik/message_processing/batching/batch_manager_constuctors.py
index ce345285f5..9b9c5b58f1 100644
--- a/sdks/python/src/opik/message_processing/batching/batch_manager_constuctors.py
+++ b/sdks/python/src/opik/message_processing/batching/batch_manager_constuctors.py
@@ -10,6 +10,9 @@
 CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS = 1.0
 CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE = 1000
 
+FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS = 1.0
+FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_MAX_BATCH_SIZE = 1000
+
 
 def create_batch_manager(message_queue: queue.Queue) -> batch_manager.BatchManager:
     create_span_message_batcher_ = batchers.CreateSpanMessageBatcher(
@@ -24,11 +27,25 @@ def create_batch_manager(message_queue: queue.Queue) -> batch_manager.BatchManag
         flush_callback=message_queue.put,
     )
 
+    add_span_feedback_scores_batch_message_batcher = batchers.AddSpanFeedbackScoresBatchMessageBatcher(
+        flush_interval_seconds=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
+        max_batch_size=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_MAX_BATCH_SIZE,
+        flush_callback=message_queue.put,
+    )
+
+    add_trace_feedback_scores_batch_message_batcher = batchers.AddTraceFeedbackScoresBatchMessageBatcher(
+        flush_interval_seconds=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
+        max_batch_size=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_MAX_BATCH_SIZE,
+        flush_callback=message_queue.put,
+    )
+
     message_to_batcher_mapping: Dict[
         Type[messages.BaseMessage], base_batcher.BaseBatcher
     ] = {
         messages.CreateSpanMessage: create_span_message_batcher_,
         messages.CreateTraceMessage: create_trace_message_batcher_,
+        messages.AddSpanFeedbackScoresBatchMessage: add_span_feedback_scores_batch_message_batcher,
+        messages.AddTraceFeedbackScoresBatchMessage: add_trace_feedback_scores_batch_message_batcher,
     }
 
     batch_manager_ = batch_manager.BatchManager(
diff --git a/sdks/python/src/opik/message_processing/batching/batchers.py b/sdks/python/src/opik/message_processing/batching/batchers.py
index 56bc4b1875..156adcfefc 100644
--- a/sdks/python/src/opik/message_processing/batching/batchers.py
+++ b/sdks/python/src/opik/message_processing/batching/batchers.py
@@ -1,3 +1,5 @@
+from typing import Union
+
 from . import base_batcher
 from .. import messages
 
@@ -8,9 +10,77 @@ def _create_batch_from_accumulated_messages(
     ) -> messages.CreateSpansBatchMessage:
         return messages.CreateSpansBatchMessage(batch=self._accumulated_messages)  # type: ignore
 
+    def add(self, message: messages.CreateSpanMessage) -> None:  # type: ignore
+        return super().add(message)
+
 
 class CreateTraceMessageBatcher(base_batcher.BaseBatcher):
     def _create_batch_from_accumulated_messages(
         self,
     ) -> messages.CreateTraceBatchMessage:
         return messages.CreateTraceBatchMessage(batch=self._accumulated_messages)  # type: ignore
+
+    def add(self, message: messages.CreateTraceMessage) -> None:  # type: ignore
+        return super().add(message)
+
+
+class BaseAddFeedbackScoresBatchMessageBatcher(base_batcher.BaseBatcher):
+    def _create_batch_from_accumulated_messages(  # type: ignore
+        self,
+    ) -> Union[
+        messages.AddSpanFeedbackScoresBatchMessage,
+        messages.AddTraceFeedbackScoresBatchMessage,
+    ]:
+        return super()._create_batch_from_accumulated_messages()  # type: ignore
+
+    def add(  # type: ignore
+        self,
+        message: Union[
+            messages.AddSpanFeedbackScoresBatchMessage,
+            messages.AddTraceFeedbackScoresBatchMessage,
+        ],
+    ) -> None:
+        with self._lock:
+            new_messages = message.batch
+            n_new_messages = len(new_messages)
+            n_accumulated_messages = len(self._accumulated_messages)
+
+            if n_new_messages + n_accumulated_messages >= self._max_batch_size:
+                free_space_in_accumulator = (
+                    self._max_batch_size - n_accumulated_messages
+                )
+
+                messages_that_fit_in_batch = new_messages[:free_space_in_accumulator]
+                messages_that_dont_fit_in_batch = new_messages[
+                    free_space_in_accumulator:
+                ]
+
+                self._accumulated_messages += messages_that_fit_in_batch
+                new_messages = messages_that_dont_fit_in_batch
+                self.flush()
+
+            self._accumulated_messages += new_messages
+
+
+class AddSpanFeedbackScoresBatchMessageBatcher(
+    BaseAddFeedbackScoresBatchMessageBatcher
+):
+    def _create_batch_from_accumulated_messages(
+        self,
+    ) -> messages.AddSpanFeedbackScoresBatchMessage:  # type: ignore
+        return messages.AddSpanFeedbackScoresBatchMessage(
+            batch=self._accumulated_messages,  # type: ignore
+            supports_batching=False,
+        )
+
+
+class AddTraceFeedbackScoresBatchMessageBatcher(
+    BaseAddFeedbackScoresBatchMessageBatcher
+):
+    def _create_batch_from_accumulated_messages(
+        self,
+    ) -> messages.AddTraceFeedbackScoresBatchMessage:  # type: ignore
+        return messages.AddTraceFeedbackScoresBatchMessage(
+            batch=self._accumulated_messages,  # type: ignore
+            supports_batching=False,
+        )
diff --git a/sdks/python/src/opik/message_processing/message_processors.py b/sdks/python/src/opik/message_processing/message_processors.py
index f3c757124b..df857934f0 100644
--- a/sdks/python/src/opik/message_processing/message_processors.py
+++ b/sdks/python/src/opik/message_processing/message_processors.py
@@ -137,7 +137,7 @@ def _process_add_span_feedback_scores_batch_message(
             for score_message in message.batch
         ]
 
-        LOGGER.debug("Batch of spans feedbacks scores request: %s", scores)
+        LOGGER.debug("Add span feedback scores request of size: %d", len(scores))
 
         self._rest_client.spans.score_batch_of_spans(
             scores=scores,
@@ -152,7 +152,7 @@ def _process_add_trace_feedback_scores_batch_message(
             for score_message in message.batch
         ]
 
-        LOGGER.debug("Batch of traces feedbacks scores request: %s", scores)
+        LOGGER.debug("Add trace feedback scores request of size: %d", len(scores))
 
         self._rest_client.traces.score_batch_of_traces(
             scores=scores,
diff --git a/sdks/python/src/opik/message_processing/messages.py b/sdks/python/src/opik/message_processing/messages.py
index 940d91513c..d8072a6c58 100644
--- a/sdks/python/src/opik/message_processing/messages.py
+++ b/sdks/python/src/opik/message_processing/messages.py
@@ -108,13 +108,24 @@ class FeedbackScoreMessage(BaseMessage):
 
 
 @dataclasses.dataclass
-class AddTraceFeedbackScoresBatchMessage(BaseMessage):
+class AddFeedbackScoresBatchMessage(BaseMessage):
     batch: List[FeedbackScoreMessage]
+    supports_batching: bool = True
+
+    def as_payload_dict(self) -> Dict[str, Any]:
+        data = super().as_payload_dict()
+        data.pop("supports_batching")
+        return data
 
 
 @dataclasses.dataclass
-class AddSpanFeedbackScoresBatchMessage(BaseMessage):
-    batch: List[FeedbackScoreMessage]
+class AddTraceFeedbackScoresBatchMessage(AddFeedbackScoresBatchMessage):
+    pass
+
+
+@dataclasses.dataclass
+class AddSpanFeedbackScoresBatchMessage(AddFeedbackScoresBatchMessage):
+    pass
 
 
 @dataclasses.dataclass
diff --git a/sdks/python/tests/e2e/test_feedback_scores.py b/sdks/python/tests/e2e/test_feedback_scores.py
index eb8b058231..58845412ea 100644
--- a/sdks/python/tests/e2e/test_feedback_scores.py
+++ b/sdks/python/tests/e2e/test_feedback_scores.py
@@ -168,7 +168,13 @@ def f_outer():
             "value": 0.75,
             "category_name": "trace-score-category",
             "reason": "trace-score-reason",
-        }
+        },
+        {
+            "name": "trace-feedback-score-2",
+            "value": 0.5,
+            "category_name": "trace-score-category-2",
+            "reason": "trace-score-reason-2",
+        },
     ]
 )
 
@@ -193,6 +199,13 @@ def f_outer():
             "category_name": "trace-score-category",
             "reason": "trace-score-reason",
         },
+        {
+            "id": ID_STORAGE["f_outer-trace-id"],
+            "name": "trace-feedback-score-2",
+            "value": 0.5,
+            "category_name": "trace-score-category-2",
+            "reason": "trace-score-reason-2",
+        },
     ]
 
     verifiers.verify_trace(
diff --git a/sdks/python/tests/unit/message_processing/batching/test_feedback_scores_batcher.py b/sdks/python/tests/unit/message_processing/batching/test_feedback_scores_batcher.py
new file mode 100644
index 0000000000..1043b76f72
--- /dev/null
+++ b/sdks/python/tests/unit/message_processing/batching/test_feedback_scores_batcher.py
@@ -0,0 +1,155 @@
+import mock
+import time
+
+import pytest
+
+from opik.message_processing.batching import batchers
+from opik.message_processing import messages
+
+NOT_USED = None
+
+
+@pytest.mark.parametrize(
+    "message_batcher_class, batch_message_class",
+    [
+        (
+            batchers.AddSpanFeedbackScoresBatchMessageBatcher,
+            messages.AddSpanFeedbackScoresBatchMessage,
+        ),
+        (
+            batchers.AddTraceFeedbackScoresBatchMessageBatcher,
+            messages.AddTraceFeedbackScoresBatchMessage,
+        ),
+    ],
+)
+def test_add_feedback_scores_batch_message_batcher__exactly_max_batch_size_reached__batch_is_flushed(
+    message_batcher_class,
+    batch_message_class,
+):
+    flush_callback = mock.Mock()
+
+    MAX_BATCH_SIZE = 5
+
+    batcher = message_batcher_class(
+        max_batch_size=MAX_BATCH_SIZE,
+        flush_callback=flush_callback,
+        flush_interval_seconds=NOT_USED,
+    )
+
+    assert batcher.is_empty()
+    add_feedback_score_batch_messages = [
+        messages.AddSpanFeedbackScoresBatchMessage(batch=[1, 2]),
+        messages.AddSpanFeedbackScoresBatchMessage(batch=[3, 4, 5]),
+    ]  # batcher doesn't care about the content
+
+    for feedback_scores_batch in add_feedback_score_batch_messages:
+        batcher.add(feedback_scores_batch)
+    assert batcher.is_empty()
+
+    flush_callback.assert_called_once_with(
+        batch_message_class(batch=[1, 2, 3, 4, 5], supports_batching=False)
+    )
+
+
+@pytest.mark.parametrize(
+    "message_batcher_class, batch_message_class",
+    [
+        (
+            batchers.AddSpanFeedbackScoresBatchMessageBatcher,
+            messages.AddSpanFeedbackScoresBatchMessage,
+        ),
+        (
+            batchers.AddTraceFeedbackScoresBatchMessageBatcher,
+            messages.AddTraceFeedbackScoresBatchMessage,
+        ),
+    ],
+)
+def test_add_feedback_scores_batch_message_batcher__more_than_max_batch_size_items_added__one_batch_flushed__some_data_remains_in_batcher(
+    message_batcher_class,
+    batch_message_class,
+):
+    flush_callback = mock.Mock()
+
+    MAX_BATCH_SIZE = 5
+
+    batcher = message_batcher_class(
+        max_batch_size=MAX_BATCH_SIZE,
+        flush_callback=flush_callback,
+        flush_interval_seconds=NOT_USED,
+    )
+
+    assert batcher.is_empty()
+    add_feedback_score_batch_messages = [
+        messages.AddSpanFeedbackScoresBatchMessage(batch=[1, 2]),
+        messages.AddSpanFeedbackScoresBatchMessage(batch=[3, 4, 5, 6]),
+        messages.AddSpanFeedbackScoresBatchMessage(batch=[7, 8]),
+    ]  # batcher doesn't care about the content
+
+    for feedback_scores_batch in add_feedback_score_batch_messages:
+        batcher.add(feedback_scores_batch)
+
+    assert not batcher.is_empty()
+    flush_callback.assert_called_once_with(
+        batch_message_class(batch=[1, 2, 3, 4, 5], supports_batching=False)
+    )
+    flush_callback.reset_mock()
+
+    batcher.flush()
+    flush_callback.assert_called_once_with(
+        batch_message_class(batch=[6, 7, 8], supports_batching=False)
+    )
+    assert batcher.is_empty()
+
+
+@pytest.mark.parametrize(
+    "message_batcher_class",
+    [
+        batchers.AddSpanFeedbackScoresBatchMessageBatcher,
+        batchers.AddTraceFeedbackScoresBatchMessageBatcher,
+    ],
+)
+def test_add_feedback_scores_batch_message_batcher__batcher_doesnt_have_items__flush_is_called__flush_callback_NOT_called(
+    message_batcher_class,
+):
+    flush_callback = mock.Mock()
+
+    MAX_BATCH_SIZE = 5
+
+    batcher = message_batcher_class(
+        max_batch_size=MAX_BATCH_SIZE,
+        flush_callback=flush_callback,
+        flush_interval_seconds=NOT_USED,
+    )
+
+    assert batcher.is_empty()
+    batcher.flush()
+    flush_callback.assert_not_called()
+
+
+@pytest.mark.parametrize(
+    "message_batcher_class",
+    [
+        batchers.AddSpanFeedbackScoresBatchMessageBatcher,
+        batchers.AddTraceFeedbackScoresBatchMessageBatcher,
+    ],
+)
+def test_add_feedback_scores_batch_message_batcher__flush_interval_passed__ready_to_flush_returns_True(
+    message_batcher_class,
+):
+    flush_callback = mock.Mock()
+
+    MAX_BATCH_SIZE = 5
+    FLUSH_INTERVAL = 0.1
+
+    batcher = message_batcher_class(
+        max_batch_size=MAX_BATCH_SIZE,
+        flush_callback=flush_callback,
+        flush_interval_seconds=FLUSH_INTERVAL,
+    )
+    assert not batcher.is_ready_to_flush()
+    time.sleep(0.1)
+    assert batcher.is_ready_to_flush()
+    batcher.flush()
+    assert not batcher.is_ready_to_flush()
+    time.sleep(0.1)
+    assert batcher.is_ready_to_flush()
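
Note on the core of this change: BaseAddFeedbackScoresBatchMessageBatcher.add
receives messages that already carry a batch, fills the accumulator up to
max_batch_size, flushes, and carries the remainder over to the next flush.
A minimal standalone sketch of that accumulate-and-split behaviour follows;
the class name, the callback wiring, and the integers standing in for
FeedbackScoreMessage objects are illustrative assumptions, not SDK code:

# sketch_feedback_batcher.py -- illustrative only, not part of this patch
from typing import Callable, List


class FeedbackScoresBatcherSketch:
    def __init__(
        self, max_batch_size: int, flush_callback: Callable[[List[int]], None]
    ) -> None:
        self._max_batch_size = max_batch_size
        self._flush_callback = flush_callback
        self._accumulated: List[int] = []

    def add(self, new_items: List[int]) -> None:
        # Fill the accumulator up to max_batch_size, flush, keep the remainder.
        if len(new_items) + len(self._accumulated) >= self._max_batch_size:
            free_space = self._max_batch_size - len(self._accumulated)
            self._accumulated += new_items[:free_space]
            new_items = new_items[free_space:]
            self.flush()
        self._accumulated += new_items

    def flush(self) -> None:
        # Hand the accumulated items to the callback only if there are any.
        if self._accumulated:
            self._flush_callback(self._accumulated)
            self._accumulated = []


flushed: List[List[int]] = []
batcher = FeedbackScoresBatcherSketch(max_batch_size=5, flush_callback=flushed.append)
batcher.add([1, 2])        # accumulates [1, 2]
batcher.add([3, 4, 5, 6])  # fills to [1, 2, 3, 4, 5], flushes, keeps [6]
batcher.flush()            # flushes the remainder [6]
assert flushed == [[1, 2, 3, 4, 5], [6]]

Like the patch, the sketch splits an incoming batch at most once per add()
call, so an incoming batch much larger than max_batch_size can leave more
than max_batch_size items behind for a later size- or time-triggered flush.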