[OPIK-148] Implement batching for feedback scores, update and add new tests #1008

Open · wants to merge 2 commits into main
@@ -21,12 +21,6 @@ def __init__(
        self._last_time_flush_callback_called: float = time.time()
        self._lock = threading.RLock()

    def add(self, message: messages.BaseMessage) -> None:
        with self._lock:
            self._accumulated_messages.append(message)
            if len(self._accumulated_messages) == self._max_batch_size:
                self.flush()

    def flush(self) -> None:
        with self._lock:
            if len(self._accumulated_messages) > 0:
@@ -47,3 +41,10 @@ def is_empty(self) -> bool:

    @abc.abstractmethod
    def _create_batch_from_accumulated_messages(self) -> messages.BaseMessage: ...

    @abc.abstractmethod
    def add(self, message: messages.BaseMessage) -> None:
        with self._lock:
            self._accumulated_messages.append(message)
            if len(self._accumulated_messages) == self._max_batch_size:
                self.flush()
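Note on this hunk: the default size-triggered `add` implementation is moved below `_create_batch_from_accumulated_messages` and is now declared abstract while keeping its body, so every concrete batcher has to define `add` explicitly, either delegating back to this default or replacing it (as the feedback-score batchers further down do). A minimal sketch of the delegating case, illustrative only: the subclass name is hypothetical, and the constructor arguments mirror the ones used in the unit tests below.

```python
from opik.message_processing.batching import base_batcher
from opik.message_processing import messages


# Hypothetical subclass: keeps the default behaviour by calling the base add().
class PrintingSpanBatcher(base_batcher.BaseBatcher):
    def _create_batch_from_accumulated_messages(self) -> messages.CreateSpansBatchMessage:
        # Wrap whatever has accumulated into a single batch message.
        return messages.CreateSpansBatchMessage(batch=self._accumulated_messages)  # type: ignore

    def add(self, message: messages.BaseMessage) -> None:
        # Reuse the default: append and flush once max_batch_size is reached.
        return super().add(message)


batcher = PrintingSpanBatcher(
    max_batch_size=2,
    flush_callback=print,  # receives the CreateSpansBatchMessage on every flush
    flush_interval_seconds=1.0,
)
batcher.add("span-message-1")  # type: ignore  # contents are opaque to the batcher
batcher.add("span-message-2")  # type: ignore  # hits max_batch_size -> flush -> print(...)
```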
@@ -23,6 +23,9 @@ def stop(self) -> None:
        self._flushing_thread.close()

    def message_supports_batching(self, message: messages.BaseMessage) -> bool:
        if hasattr(message, "supports_batching"):
            return message.supports_batching

        return message.__class__ in self._message_to_batcher_mapping

    def process_message(self, message: messages.BaseMessage) -> None:
@@ -10,6 +10,9 @@
CREATE_SPANS_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS = 1.0
CREATE_SPANS_MESSAGE_BATCHER_MAX_BATCH_SIZE = 1000

FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS = 1.0
FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_MAX_BATCH_SIZE = 1000


def create_batch_manager(message_queue: queue.Queue) -> batch_manager.BatchManager:
    create_span_message_batcher_ = batchers.CreateSpanMessageBatcher(
@@ -24,11 +27,25 @@ def create_batch_manager(message_queue: queue.Queue) -> batch_manager.BatchManager:
        flush_callback=message_queue.put,
    )

    add_span_feedback_scores_batch_message_batcher = batchers.AddSpanFeedbackScoresBatchMessageBatcher(
        flush_interval_seconds=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
        max_batch_size=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_MAX_BATCH_SIZE,
        flush_callback=message_queue.put,
    )

    add_trace_feedback_scores_batch_message_batcher = batchers.AddTraceFeedbackScoresBatchMessageBatcher(
        flush_interval_seconds=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_FLUSH_INTERVAL_SECONDS,
        max_batch_size=FEEDBACK_SCORES_BATCH_MESSAGE_BATCHER_MAX_BATCH_SIZE,
        flush_callback=message_queue.put,
    )

    message_to_batcher_mapping: Dict[
        Type[messages.BaseMessage], base_batcher.BaseBatcher
    ] = {
        messages.CreateSpanMessage: create_span_message_batcher_,
        messages.CreateTraceMessage: create_trace_message_batcher_,
        messages.AddSpanFeedbackScoresBatchMessage: add_span_feedback_scores_batch_message_batcher,
        messages.AddTraceFeedbackScoresBatchMessage: add_trace_feedback_scores_batch_message_batcher,
    }

    batch_manager_ = batch_manager.BatchManager(
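The mapping above is what lets the batch manager treat the new feedback-score messages as batchable by class, while flushed batches (created with `supports_batching=False` by the batchers in the next file) take the `hasattr` branch of `message_supports_batching` and pass straight through. A small sketch of that routing, illustrative only; `BatchManager.process_message` is not shown in this diff, so only the visible `message_supports_batching` behaviour is exercised.

```python
import queue

from opik.message_processing import messages

# create_batch_manager is the factory defined in this file.
message_queue: queue.Queue = queue.Queue()
manager = create_batch_manager(message_queue)

fresh = messages.AddSpanFeedbackScoresBatchMessage(batch=[])
assert manager.message_supports_batching(fresh)  # supports_batching defaults to True

flushed = messages.AddSpanFeedbackScoresBatchMessage(batch=[], supports_batching=False)
assert not manager.message_supports_batching(flushed)  # flushed batches bypass re-batching
```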
70 changes: 70 additions & 0 deletions sdks/python/src/opik/message_processing/batching/batchers.py
@@ -1,3 +1,5 @@
from typing import Union

from . import base_batcher
from .. import messages

@@ -8,9 +10,77 @@ def _create_batch_from_accumulated_messages(
    ) -> messages.CreateSpansBatchMessage:
        return messages.CreateSpansBatchMessage(batch=self._accumulated_messages)  # type: ignore

    def add(self, message: messages.CreateSpansBatchMessage) -> None:  # type: ignore
        return super().add(message)


class CreateTraceMessageBatcher(base_batcher.BaseBatcher):
    def _create_batch_from_accumulated_messages(
        self,
    ) -> messages.CreateTraceBatchMessage:
        return messages.CreateTraceBatchMessage(batch=self._accumulated_messages)  # type: ignore

    def add(self, message: messages.CreateTraceBatchMessage) -> None:  # type: ignore
        return super().add(message)


class BaseAddFeedbackScoresBatchMessageBatcher(base_batcher.BaseBatcher):
    def _create_batch_from_accumulated_messages(  # type: ignore
        self,
    ) -> Union[
        messages.AddSpanFeedbackScoresBatchMessage,
        messages.AddTraceFeedbackScoresBatchMessage,
    ]:
        return super()._create_batch_from_accumulated_messages()  # type: ignore

    def add(  # type: ignore
        self,
        message: Union[
            messages.AddSpanFeedbackScoresBatchMessage,
            messages.AddTraceFeedbackScoresBatchMessage,
        ],
    ) -> None:
        with self._lock:
            new_messages = message.batch
            n_new_messages = len(new_messages)
            n_accumulated_messages = len(self._accumulated_messages)

            if n_new_messages + n_accumulated_messages >= self._max_batch_size:
                free_space_in_accumulator = (
                    self._max_batch_size - n_accumulated_messages
                )

                messages_that_fit_in_batch = new_messages[:free_space_in_accumulator]
                messages_that_dont_fit_in_batch = new_messages[
                    free_space_in_accumulator:
                ]

                self._accumulated_messages += messages_that_fit_in_batch
                new_messages = messages_that_dont_fit_in_batch
                self.flush()

            self._accumulated_messages += new_messages


class AddSpanFeedbackScoresBatchMessageBatcher(
    BaseAddFeedbackScoresBatchMessageBatcher
):
    def _create_batch_from_accumulated_messages(
        self,
    ) -> messages.AddSpanFeedbackScoresBatchMessage:  # type: ignore
        return messages.AddSpanFeedbackScoresBatchMessage(
            batch=self._accumulated_messages,  # type: ignore
            supports_batching=False,
        )


class AddTraceFeedbackScoresBatchMessageBatcher(
    BaseAddFeedbackScoresBatchMessageBatcher
):
    def _create_batch_from_accumulated_messages(
        self,
    ) -> messages.AddTraceFeedbackScoresBatchMessage:  # type: ignore
        return messages.AddTraceFeedbackScoresBatchMessage(
            batch=self._accumulated_messages,  # type: ignore
            supports_batching=False,
        )
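Unlike the span and trace create batchers, which accumulate one message at a time, `BaseAddFeedbackScoresBatchMessageBatcher.add` receives messages that already carry a list of feedback scores. It therefore unpacks the incoming scores, tops the accumulator up to `max_batch_size`, flushes, and keeps the remainder for the next batch; this is exactly what the new unit tests below assert. A toy sketch of that arithmetic, illustrative only, using plain lists in place of `FeedbackScoreMessage` objects:

```python
# Plain-list rehearsal of the split performed by add() above.
max_batch_size = 5
accumulated = [1, 2]        # scores already sitting in the batcher
incoming = [3, 4, 5, 6]     # message.batch of the next Add*FeedbackScoresBatchMessage

if len(incoming) + len(accumulated) >= max_batch_size:
    free_space = max_batch_size - len(accumulated)  # 3 slots left
    accumulated += incoming[:free_space]            # [1, 2, 3, 4, 5]
    print("flushed:", accumulated)                  # flush() sends a full batch ...
    accumulated = []                                # ... and empties the accumulator
    incoming = incoming[free_space:]                # [6] is carried over

accumulated += incoming
print("remaining in batcher:", accumulated)         # [6]
```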
@@ -137,7 +137,7 @@ def _process_add_span_feedback_scores_batch_message(
            for score_message in message.batch
        ]

        LOGGER.debug("Batch of spans feedbacks scores request: %s", scores)
        LOGGER.debug("Add spans feedbacks scores request of size: %d", len(scores))

        self._rest_client.spans.score_batch_of_spans(
            scores=scores,
@@ -152,7 +152,7 @@ def _process_add_trace_feedback_scores_batch_message(
            for score_message in message.batch
        ]

        LOGGER.debug("Batch of traces feedbacks scores request: %s", scores)
        LOGGER.debug("Add traces feedbacks scores request: %d", len(scores))

        self._rest_client.traces.score_batch_of_traces(
            scores=scores,
17 changes: 14 additions & 3 deletions sdks/python/src/opik/message_processing/messages.py
@@ -108,13 +108,24 @@ class FeedbackScoreMessage(BaseMessage):


@dataclasses.dataclass
class AddTraceFeedbackScoresBatchMessage(BaseMessage):
class AddFeedbackScoresBatchMessage(BaseMessage):
    batch: List[FeedbackScoreMessage]
    supports_batching: bool = True

    def as_payload_dict(self) -> Dict[str, Any]:
        data = super().as_payload_dict()
        data.pop("supports_batching")
        return data


@dataclasses.dataclass
class AddSpanFeedbackScoresBatchMessage(BaseMessage):
    batch: List[FeedbackScoreMessage]
class AddTraceFeedbackScoresBatchMessage(AddFeedbackScoresBatchMessage):
    pass


@dataclasses.dataclass
class AddSpanFeedbackScoresBatchMessage(AddFeedbackScoresBatchMessage):
    pass


@dataclasses.dataclass
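The new `AddFeedbackScoresBatchMessage` base is what carries the `supports_batching` flag used throughout this PR: messages created by the public API keep the default `True` and are routed into the batchers, while the batchers emit their flushed batches with `supports_batching=False` so they are sent directly to the REST client instead of being re-batched, and `as_payload_dict` strips the flag before the payload is built. A small sketch of that flow, illustrative only; it assumes `BaseMessage.as_payload_dict` returns the dataclass fields as a dict (only the `pop` is visible in this diff).

```python
from opik.message_processing import messages

# An empty batch keeps the sketch self-contained; real messages carry FeedbackScoreMessage items.
fresh = messages.AddSpanFeedbackScoresBatchMessage(batch=[])
print(fresh.supports_batching)    # True -> picked up by the batch manager

flushed = messages.AddSpanFeedbackScoresBatchMessage(batch=[], supports_batching=False)
print(flushed.supports_batching)  # False -> processed directly by the message sender

payload = flushed.as_payload_dict()
assert "supports_batching" not in payload  # the flag never reaches the backend payload
```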
15 changes: 14 additions & 1 deletion sdks/python/tests/e2e/test_feedback_scores.py
@@ -168,7 +168,13 @@ def f_outer():
                "value": 0.75,
                "category_name": "trace-score-category",
                "reason": "trace-score-reason",
            }
            },
            {
                "name": "trace-feedback-score-2",
                "value": 0.5,
                "category_name": "trace-score-category-2",
                "reason": "trace-score-reason-2",
            },
        ]
    )

@@ -193,6 +199,13 @@ def f_outer():
            "category_name": "trace-score-category",
            "reason": "trace-score-reason",
        },
        {
            "id": ID_STORAGE["f_outer-trace-id"],
            "name": "trace-feedback-score-2",
            "value": 0.5,
            "category_name": "trace-score-category-2",
            "reason": "trace-score-reason-2",
        },
    ]
@@ -0,0 +1,155 @@
import mock
import time

import pytest

from opik.message_processing.batching import batchers
from opik.message_processing import messages

NOT_USED = None


@pytest.mark.parametrize(
    "message_batcher_class, batch_message_class",
    [
        (
            batchers.AddSpanFeedbackScoresBatchMessageBatcher,
            messages.AddSpanFeedbackScoresBatchMessage,
        ),
        (
            batchers.AddTraceFeedbackScoresBatchMessageBatcher,
            messages.AddTraceFeedbackScoresBatchMessage,
        ),
    ],
)
def test_add_feedback_scores_batch_message_batcher__exactly_max_batch_size_reached__batch_is_flushed(
    message_batcher_class,
    batch_message_class,
):
    flush_callback = mock.Mock()

    MAX_BATCH_SIZE = 5

    batcher = message_batcher_class(
        max_batch_size=MAX_BATCH_SIZE,
        flush_callback=flush_callback,
        flush_interval_seconds=NOT_USED,
    )

    assert batcher.is_empty()
    add_feedback_score_batch_messages = [
        messages.AddSpanFeedbackScoresBatchMessage(batch=[1, 2]),
        messages.AddSpanFeedbackScoresBatchMessage(batch=[3, 4, 5]),
    ]  # batcher doesn't care about the content

    for feedback_scores_batch in add_feedback_score_batch_messages:
        batcher.add(feedback_scores_batch)
    assert batcher.is_empty()

    flush_callback.assert_called_once_with(
        batch_message_class(batch=[1, 2, 3, 4, 5], supports_batching=False)
    )


@pytest.mark.parametrize(
    "message_batcher_class,batch_message_class",
    [
        (
            batchers.AddSpanFeedbackScoresBatchMessageBatcher,
            messages.AddSpanFeedbackScoresBatchMessage,
        ),
        (
            batchers.AddTraceFeedbackScoresBatchMessageBatcher,
            messages.AddTraceFeedbackScoresBatchMessage,
        ),
    ],
)
def test_add_feedback_scores_batch_message_batcher__more_than_max_batch_size_items_added__one_batch_flushed__some_data_remains_in_batcher(
    message_batcher_class,
    batch_message_class,
):
    flush_callback = mock.Mock()

    MAX_BATCH_SIZE = 5

    batcher = message_batcher_class(
        max_batch_size=MAX_BATCH_SIZE,
        flush_callback=flush_callback,
        flush_interval_seconds=NOT_USED,
    )

    assert batcher.is_empty()
    add_feedback_score_batch_messages = [
        messages.AddSpanFeedbackScoresBatchMessage(batch=[1, 2]),
        messages.AddSpanFeedbackScoresBatchMessage(batch=[3, 4, 5, 6]),
        messages.AddSpanFeedbackScoresBatchMessage(batch=[7, 8]),
    ]  # batcher doesn't care about the content

    for feedback_scores_batch in add_feedback_score_batch_messages:
        batcher.add(feedback_scores_batch)

    assert not batcher.is_empty()
    flush_callback.assert_called_once_with(
        batch_message_class(batch=[1, 2, 3, 4, 5], supports_batching=False)
    )
    flush_callback.reset_mock()

    batcher.flush()
    flush_callback.assert_called_once_with(
        batch_message_class(batch=[6, 7, 8], supports_batching=False)
    )
    assert batcher.is_empty()


@pytest.mark.parametrize(
    "message_batcher_class",
    [
        batchers.AddSpanFeedbackScoresBatchMessageBatcher,
        batchers.AddTraceFeedbackScoresBatchMessageBatcher,
    ],
)
def test_add_feedback_scores_batch_message_batcher__batcher_doesnt_have_items__flush_is_called__flush_callback_NOT_called(
    message_batcher_class,
):
    flush_callback = mock.Mock()

    MAX_BATCH_SIZE = 5

    batcher = message_batcher_class(
        max_batch_size=MAX_BATCH_SIZE,
        flush_callback=flush_callback,
        flush_interval_seconds=NOT_USED,
    )

    assert batcher.is_empty()
    batcher.flush()
    flush_callback.assert_not_called()


@pytest.mark.parametrize(
    "message_batcher_class",
    [
        batchers.AddSpanFeedbackScoresBatchMessageBatcher,
        batchers.AddTraceFeedbackScoresBatchMessageBatcher,
    ],
)
def test_add_feedback_scores_batch_message_batcher__ready_to_flush_returns_True__is_flush_interval_passed(
    message_batcher_class,
):
    flush_callback = mock.Mock()

    MAX_BATCH_SIZE = 5
    FLUSH_INTERVAL = 0.1

    batcher = message_batcher_class(
        max_batch_size=MAX_BATCH_SIZE,
        flush_callback=flush_callback,
        flush_interval_seconds=FLUSH_INTERVAL,
    )
    assert not batcher.is_ready_to_flush()
    time.sleep(0.1)
    assert batcher.is_ready_to_flush()
    batcher.flush()
    assert not batcher.is_ready_to_flush()
    time.sleep(0.1)
    assert batcher.is_ready_to_flush()