fix: return discarded messages as a list of message indices #75

Merged
41 changes: 32 additions & 9 deletions aidial_assistant/application/assistant_application.py
@@ -26,7 +26,7 @@
CommandConstructor,
CommandDict,
)
-from aidial_assistant.chain.history import History
+from aidial_assistant.chain.history import History, ScopedMessage
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.commands.run_tool import RunTool
@@ -109,6 +109,25 @@ def _construct_tool(name: str, description: str) -> ChatCompletionToolParam:
)


+def _create_history(
+    messages: list[ScopedMessage], plugins: list[PluginInfo]
+) -> History:
+    plugin_descriptions = {
+        plugin.info.ai_plugin.name_for_model: plugin.info.open_api.info.description
+        or plugin.info.ai_plugin.description_for_human
+        for plugin in plugins
+    }
+    return History(
+        assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
+            addons=plugin_descriptions
+        ),
+        best_effort_template=MAIN_BEST_EFFORT_TEMPLATE.build(
+            addons=plugin_descriptions
+        ),
+        scoped_messages=messages,
+    )


class AssistantApplication(ChatCompletion):
def __init__(
self, config_dir: Path, tools_supporting_deployments: set[str]
@@ -204,21 +223,25 @@ def create_command(addon: PluginInfo):
or addon.info.ai_plugin.description_for_human
for addon in addons
}
+        scoped_messages = parse_history(request.messages)
history = History(
assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
addons=addon_descriptions
),
best_effort_template=MAIN_BEST_EFFORT_TEMPLATE.build(
addons=addon_descriptions
),
-            scoped_messages=parse_history(request.messages),
+            scoped_messages=scoped_messages,
)
-        discarded_messages: int | None = None
+        discarded_user_messages: set[int] | None = None
if request.max_prompt_tokens is not None:
-            original_size = history.user_message_count
-            history = await history.truncate(request.max_prompt_tokens, model)
-            truncated_size = history.user_message_count
-            discarded_messages = original_size - truncated_size
+            history, discarded_messages = await history.truncate(
+                model, request.max_prompt_tokens
+            )
+            discarded_user_messages = set(
+                scoped_messages[index].user_index
+                for index in discarded_messages
+            )
# TODO: else compare the history size to the max prompt tokens of the underlying model

choice = response.create_single_choice()
@@ -243,8 +266,8 @@ def create_command(addon: PluginInfo):
model.total_prompt_tokens, model.total_completion_tokens
)

-        if discarded_messages is not None:
-            response.set_discarded_messages(discarded_messages)
+        if discarded_user_messages is not None:
+            response.set_discarded_messages(list(discarded_user_messages))

@staticmethod
async def _run_native_tools_chat(
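With this change, truncation no longer yields a bare count: `truncate` returns the surviving history together with the indices of the discarded scoped messages, and the application folds those through `ScopedMessage.user_index` into the set of original request messages that were dropped. A minimal sketch of that mapping, with made-up values (the simplified `ScopedMessage` and the five-message history below are assumptions for illustration, not part of the diff):

```python
from pydantic import BaseModel


class ScopedMessage(BaseModel):
    # Simplified stand-in for aidial_assistant's ScopedMessage: only the
    # field relevant to the mapping is kept.
    user_index: int  # index of the originating message in the client request


# Suppose parse_history() expanded 4 request messages into 5 scoped
# messages, where scoped messages 1 and 2 both came from request message 1
# (e.g. an assistant reply plus its addon traffic).
scoped_messages = [
    ScopedMessage(user_index=0),
    ScopedMessage(user_index=1),
    ScopedMessage(user_index=1),
    ScopedMessage(user_index=2),
    ScopedMessage(user_index=3),
]

# truncate() reported scoped indices 1 and 2 as discarded.
discarded_messages = [1, 2]

# Deduplicate through user_index, as the application code above does.
discarded_user_messages = set(
    scoped_messages[index].user_index for index in discarded_messages
)
assert discarded_user_messages == {1}
```

The deduplication matters because one request message can expand into several scoped messages, and the client should see each original message reported at most once.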
136 changes: 73 additions & 63 deletions aidial_assistant/chain/history.py
@@ -1,7 +1,11 @@
from enum import Enum
+from typing import Tuple, cast

from jinja2 import Template
-from openai.types.chat import ChatCompletionMessageParam
+from openai.types.chat import (
+    ChatCompletionMessageParam,
+    ChatCompletionSystemMessageParam,
+)
from pydantic import BaseModel

from aidial_assistant.chain.command_result import (
@@ -26,6 +30,7 @@ class MessageScope(str, Enum):
class ScopedMessage(BaseModel):
scope: MessageScope = MessageScope.USER
message: ChatCompletionMessageParam
+    user_index: int


class History:
@@ -40,35 +45,32 @@ def __init__(
)
self.best_effort_template = best_effort_template
self.scoped_messages = scoped_messages
-        self._user_message_count = sum(
-            1
-            for message in scoped_messages
-            if message.scope == MessageScope.USER
-        )

def to_protocol_messages(self) -> list[ChatCompletionMessageParam]:
messages: list[ChatCompletionMessageParam] = []
-        for index, scoped_message in enumerate(self.scoped_messages):
+        scoped_message_iterator = iter(self.scoped_messages)
+        if self._is_first_system_message():
+            message = cast(
+                ChatCompletionSystemMessageParam,
+                next(scoped_message_iterator).message,
+            )
+            messages.append(
+                system_message(
+                    self.assistant_system_message_template.render(
+                        system_prefix=message["content"]
+                    )
+                )
+            )
+        else:
+            messages.append(
+                system_message(self.assistant_system_message_template.render())
+            )
+
+        for scoped_message in scoped_message_iterator:
message = scoped_message.message
scope = scoped_message.scope

-            if index == 0:
-                if message["role"] == "system":
-                    messages.append(
-                        system_message(
-                            self.assistant_system_message_template.render(
-                                system_prefix=message["content"]
-                            )
-                        )
-                    )
-                else:
-                    messages.append(
-                        system_message(
-                            self.assistant_system_message_template.render()
-                        )
-                    )
-                messages.append(message)
-            elif scope == MessageScope.USER and message["role"] == "assistant":
+            if scope == MessageScope.USER and message["role"] == "assistant":
# Clients see replies in plain text, but the model should understand how to reply appropriately.
content = commands_to_text(
[
@@ -107,51 +109,59 @@ def to_best_effort_messages(
return messages

async def truncate(
-        self, max_prompt_tokens: int, model_client: ModelClient
-    ) -> "History":
-        discarded_messages = await model_client.get_discarded_messages(
-            self.to_protocol_messages(),
-            max_prompt_tokens,
+        self, model_client: ModelClient, max_prompt_tokens: int
+    ) -> Tuple["History", list[int]]:
+        discarded_messages = await self._get_discarded_messages(
+            model_client, max_prompt_tokens
)

-        if discarded_messages > 0:
-            return History(
+        if not discarded_messages:
+            return self, []
+
+        discarded_messages_set = set(discarded_messages)
+        return (
+            History(
assistant_system_message_template=self.assistant_system_message_template,
best_effort_template=self.best_effort_template,
-                scoped_messages=self._skip_messages(discarded_messages),
-            )
-
-        return self
-
-    @property
-    def user_message_count(self) -> int:
-        return self._user_message_count
-
-    def _skip_messages(self, discarded_messages: int) -> list[ScopedMessage]:
-        messages: list[ScopedMessage] = []
-        current_message = self.scoped_messages[0]
-        message_iterator = iter(self.scoped_messages)
-        for _ in range(discarded_messages):
-            current_message = next(message_iterator)
-            while current_message.message["role"] == "system":
-                # System messages should be kept in the history
-                messages.append(current_message)
-                current_message = next(message_iterator)
+                scoped_messages=[
+                    scoped_message
+                    for index, scoped_message in enumerate(self.scoped_messages)
+                    if index not in discarded_messages_set
+                ],
+            ),
+            discarded_messages,
+        )

-        if current_message.scope == MessageScope.INTERNAL:
-            while current_message.scope == MessageScope.INTERNAL:
-                current_message = next(message_iterator)
+    async def _get_discarded_messages(
+        self, model_client: ModelClient, max_prompt_tokens: int
+    ) -> list[int]:
+        discarded_protocol_messages = await model_client.get_discarded_messages(
+            self.to_protocol_messages(),
+            max_prompt_tokens,
+        )

-            # Internal messages (i.e. addon requests/responses) are always followed by an assistant reply
-            assert (
-                current_message.message["role"] == "assistant"
-            ), "Internal messages must be followed by an assistant reply."
+        if discarded_protocol_messages:
+            discarded_protocol_messages.sort()
+            discarded_messages = (
+                discarded_protocol_messages
+                if self._is_first_system_message()
+                else [index - 1 for index in discarded_protocol_messages]
+            )
+            user_indices = set(
+                self.scoped_messages[index].user_index
+                for index in discarded_messages
+            )

-        remaining_messages = list(message_iterator)
-        assert (
-            len(remaining_messages) > 0
-        ), "No user messages left after history truncation."
+            return [
+                index
+                for index, scoped_message in enumerate(self.scoped_messages)
+                if scoped_message.user_index in user_indices
+            ]

-        messages += remaining_messages
-
-        return messages
+        return discarded_protocol_messages
+
+    def _is_first_system_message(self) -> bool:
+        return (
+            len(self.scoped_messages) > 0
+            and self.scoped_messages[0].message["role"] == "system"
+        )
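Two index translations are concentrated here. First, when the history does not begin with a system message, `to_protocol_messages` prepends a synthetic one, so every index the model reports is one greater than the matching position in `scoped_messages`; `_get_discarded_messages` undoes that shift. Second, the result is widened so that all scoped messages sharing a discarded `user_index` are dropped together. A small sketch of the offset step, using assumed values:

```python
# Hypothetical model response: protocol message indices to discard.
discarded_protocol_messages = [1, 3]

# Pretend the history has no leading system message, i.e.
# _is_first_system_message() returned False, so to_protocol_messages()
# prepended a synthetic system message at protocol index 0.
first_is_system = False

discarded_messages = (
    discarded_protocol_messages
    if first_is_system
    else [index - 1 for index in discarded_protocol_messages]
)

# Protocol indices 1 and 3 correspond to scoped messages 0 and 2.
assert discarded_messages == [0, 2]
```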
4 changes: 3 additions & 1 deletion aidial_assistant/commands/run_plugin.py
@@ -87,7 +87,9 @@ def create_command(op: APIOperation):
best_effort_template=ADDON_BEST_EFFORT_TEMPLATE.build(
api_schema=api_schema
),
-            scoped_messages=[ScopedMessage(message=user_message(query))],
+            scoped_messages=[
+                ScopedMessage(message=user_message(query), user_index=0)
+            ],
)

chat = CommandChain(
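Since `user_index` is now a required field on `ScopedMessage`, the single-message history built for an addon sub-chat pins it to 0. A sketch of the equivalent construction (the inline dict stands in for the repo's `user_message` helper):

```python
from aidial_assistant.chain.history import ScopedMessage

# The sub-chat holds exactly one user-visible message, so any truncation
# inside it could only ever point back at index 0.
scoped_messages = [
    ScopedMessage(
        message={"role": "user", "content": "example query"}, user_index=0
    )
]
```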
41 changes: 31 additions & 10 deletions aidial_assistant/model/model_client.py
@@ -1,4 +1,5 @@
from abc import ABC
+from itertools import islice
from typing import Any, AsyncIterator, List

from aidial_sdk.utils.merge_chunks import merge
@@ -16,7 +17,7 @@ class ReasonLengthException(Exception):


class ExtraResultsCallback:
-    def on_discarded_messages(self, discarded_messages: int):
+    def on_discarded_messages(self, discarded_messages: list[int]):
pass

def on_prompt_tokens(self, prompt_tokens: int):
@@ -36,6 +37,21 @@ async def _flush_stream(stream: AsyncIterator[str]):
pass


+def _discarded_messages_count_to_indices(
+    messages: list[ChatCompletionMessageParam], discarded_messages: int
+) -> list[int]:
+    return list(
+        islice(
+            (
+                i
+                for i, message in enumerate(messages)
+                if message["role"] != "system"
+            ),
+            discarded_messages,
+        )
+    )


class ModelClient(ABC):
def __init__(self, client: AsyncOpenAI, model_args: dict[str, Any]):
self.client = client
@@ -70,12 +86,16 @@ async def agenerate(
extra_results_callback.on_prompt_tokens(prompt_tokens)

if extra_results_callback:
-                discarded_messages: int | None = chunk_dict.get(
+                discarded_messages: int | list[int] | None = chunk_dict.get(
"statistics", {}
).get("discarded_messages")
if discarded_messages is not None:
extra_results_callback.on_discarded_messages(
-                        discarded_messages
+                        _discarded_messages_count_to_indices(
+                            messages, discarded_messages
+                        )
+                        if isinstance(discarded_messages, int)
+                        else discarded_messages
)

choice = chunk.choices[0]
@@ -128,15 +148,16 @@ def on_prompt_tokens(self, prompt_tokens: int):
return callback.token_count

# TODO: Use a dedicated endpoint for discarded_messages.
+    # https://github.com/epam/ai-dial-assistant/issues/39
async def get_discarded_messages(
self, messages: list[ChatCompletionMessageParam], max_prompt_tokens: int
-    ) -> int:
+    ) -> list[int]:
class DiscardedMessagesCallback(ExtraResultsCallback):
def __init__(self):
-                self.message_count: int | None = None
+                self.discarded_messages: list[int] | None = None

-            def on_discarded_messages(self, discarded_messages: int):
-                self.message_count = discarded_messages
+            def on_discarded_messages(self, discarded_messages: list[int]):
+                self.discarded_messages = discarded_messages

callback = DiscardedMessagesCallback()
await _flush_stream(
@@ -147,10 +168,10 @@ def on_discarded_messages(self, discarded_messages: int):
max_tokens=1,
)
)
-        if callback.message_count is None:
-            raise Exception("No message count received.")
+        if callback.discarded_messages is None:
+            raise Exception("Discarded messages were not provided.")

-        return callback.message_count
+        return callback.discarded_messages

@property
def total_prompt_tokens(self) -> int:
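For backward compatibility, `agenerate` still accepts the legacy integer form of the `discarded_messages` statistic and converts the count into explicit indices by skipping system messages, which the upstream count never includes. A standalone sketch of that conversion (the message contents are made up; the function body mirrors the one added above):

```python
from itertools import islice


def _discarded_messages_count_to_indices(messages, discarded_messages):
    # The first `discarded_messages` non-system messages are the ones
    # the model dropped; system messages are never counted.
    return list(
        islice(
            (
                i
                for i, message in enumerate(messages)
                if message["role"] != "system"
            ),
            discarded_messages,
        )
    )


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "first question"},
    {"role": "assistant", "content": "first answer"},
    {"role": "user", "content": "second question"},
]

# A legacy count of 2 maps to the two oldest non-system messages.
assert _discarded_messages_count_to_indices(messages, 2) == [1, 2]
```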