diff --git a/examples/conversations/recontextualize_message.py b/examples/conversations/recontextualize_message.py new file mode 100644 index 00000000..23248f81 --- /dev/null +++ b/examples/conversations/recontextualize_message.py @@ -0,0 +1,52 @@ +""" +Ragbits Conversations Example: Recontextualize Last Message + +This example demonstrates how to use the `StandaloneMessageCompressor` compressor to recontextualize +the last message in a conversation history. +""" + +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "ragbits-conversations", +# ] +# /// + +import asyncio + +from ragbits.conversations.history.compressors.llm import StandaloneMessageCompressor +from ragbits.core.llms.litellm import LiteLLM +from ragbits.core.prompt import ChatFormat + +# Example conversation history +conversation: ChatFormat = [ + {"role": "user", "content": "Who's working on Friday?"}, + {"role": "assistant", "content": "Jim"}, + {"role": "user", "content": "Where is he based?"}, + {"role": "assistant", "content": "At our California Head Office"}, + {"role": "user", "content": "Is he a senior staff member?"}, + {"role": "assistant", "content": "Yes, he's a senior manager"}, + {"role": "user", "content": "What's his phone number (including the prefix for his state)?"}, +] + + +async def main() -> None: + """ + Main function to demonstrate the StandaloneMessageCompressor compressor. 
+# Ragbits Conversations
"https://github.com/deepsense-ai/ragbits" +"Bug Reports" = "https://github.com/deepsense-ai/ragbits/issues" +"Documentation" = "https://ragbits.deepsense.ai/" +"Source" = "https://github.com/deepsense-ai/ragbits" + +[project.optional-dependencies] +[tool.uv] +dev-dependencies = [ + "pre-commit~=3.8.0", + "pytest~=8.3.3", + "pytest-cov~=5.0.0", + "pytest-asyncio~=0.24.0", + "pip-licenses>=4.0.0,<5.0.0" +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/ragbits"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" diff --git a/packages/ragbits-conversations/src/ragbits/conversations/__init__.py b/packages/ragbits-conversations/src/ragbits/conversations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/__init__.py b/packages/ragbits-conversations/src/ragbits/conversations/history/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py new file mode 100644 index 00000000..66ac169a --- /dev/null +++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py @@ -0,0 +1,4 @@ +from .base import ConversationHistoryCompressor +from .llm import StandaloneMessageCompressor + +__all__ = ["ConversationHistoryCompressor", "StandaloneMessageCompressor"] diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/base.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/base.py new file mode 100644 index 00000000..9b5816e1 --- /dev/null +++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/base.py @@ -0,0 +1,31 @@ +from abc import ABC, 
abstractmethod +from typing import ClassVar + +from ragbits.conversations.history import compressors +from ragbits.core.prompt.base import ChatFormat +from ragbits.core.utils.config_handling import WithConstructionConfig + + +class ConversationHistoryCompressor(WithConstructionConfig, ABC): + """ + An abstract class for conversation history compressors, + i.e. class that takes the entire conversation history + and returns a single string representation of it. + + The exact logic of what the string should include and represent + depends on the specific implementation. + + Usually used to provide LLM additional context from the conversation history. + """ + + default_module: ClassVar = compressors + configuration_key: ClassVar = "history_compressor" + + @abstractmethod + async def compress(self, conversation: ChatFormat) -> str: + """ + Compresses the conversation history to a single string. + + Args: + conversation: List of dicts with "role" and "content" keys, representing the chat history so far. + """ diff --git a/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py new file mode 100644 index 00000000..70797609 --- /dev/null +++ b/packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py @@ -0,0 +1,86 @@ +from pydantic import BaseModel + +from ragbits.conversations.history.compressors import ConversationHistoryCompressor +from ragbits.core.llms.base import LLM +from ragbits.core.prompt import ChatFormat, Prompt + + +class LastMessageAndHistory(BaseModel): + """ + A class representing the last message and the history of messages. + """ + + last_message: str + history: list[str] + + +class StandaloneMessageCompressorPrompt(Prompt[LastMessageAndHistory, str]): + """ + A prompt for recontextualizing the last message in the history. 
Initialize the StandaloneMessageCompressor compressor with an LLM.
+ """ + if len(conversation) == 0: + raise ValueError("Conversation history is empty.") + + last_message = conversation[-1] + if last_message["role"] != "user": + raise ValueError("StandaloneMessageCompressor expects the last message to be from the user.") + + # Only include "user" and "assistant" messages in the history + other_messages = [message for message in conversation[:-1] if message["role"] in ["user", "assistant"]] + + if not other_messages: + # No history to use for recontextualization, simply return the user message + return last_message["content"] + + history = [f"{message['role']}: {message['content']}" for message in other_messages[-self._history_len :]] + + input_data = LastMessageAndHistory(last_message=last_message["content"], history=history) + prompt = self._prompt(input_data) + response = await self._llm.generate(prompt) + return response diff --git a/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py b/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py new file mode 100644 index 00000000..715d8828 --- /dev/null +++ b/packages/ragbits-conversations/tests/unit/history/test_llm_compressor.py @@ -0,0 +1,111 @@ +import pytest + +from ragbits.conversations.history.compressors.llm import LastMessageAndHistory, StandaloneMessageCompressor +from ragbits.core.llms.mock import MockLLM, MockLLMOptions +from ragbits.core.prompt import ChatFormat +from ragbits.core.prompt.prompt import Prompt + + +class MockPrompt(Prompt[LastMessageAndHistory, str]): + user_prompt = "mock prompt" + + +async def test_messages_included(): + conversation: ChatFormat = [ + {"role": "user", "content": "foo1"}, + {"role": "assistant", "content": "foo2"}, + {"role": "user", "content": "foo3"}, + ] + llm = MockLLM(default_options=MockLLMOptions(response="some answer")) + compressor = StandaloneMessageCompressor(llm) + answer = await compressor.compress(conversation) + assert answer == "some answer" + user_prompt = llm.calls[0][1] + assert 
user_prompt["role"] == "user" + content = user_prompt["content"] + assert "foo1" in content + assert "foo2" in content + assert "foo3" in content + + +async def test_no_messages(): + conversation: ChatFormat = [] + compressor = StandaloneMessageCompressor(MockLLM()) + + with pytest.raises(ValueError): + await compressor.compress(conversation) + + +async def test_last_message_not_user(): + conversation: ChatFormat = [ + {"role": "assistant", "content": "foo2"}, + ] + compressor = StandaloneMessageCompressor(MockLLM()) + + with pytest.raises(ValueError): + await compressor.compress(conversation) + + +async def test_history_len(): + conversation: ChatFormat = [ + {"role": "user", "content": "foo1"}, + {"role": "assistant", "content": "foo2"}, + {"role": "user", "content": "foo3"}, + {"role": "user", "content": "foo4"}, + {"role": "user", "content": "foo5"}, + ] + llm = MockLLM() + compressor = StandaloneMessageCompressor(llm, history_len=3) + await compressor.compress(conversation) + user_prompt = llm.calls[0][1] + assert user_prompt["role"] == "user" + content = user_prompt["content"] + + # The rephrased message should be included + assert "foo5" in content + + # Three previous messages should be included + assert "foo2" in content + assert "foo3" in content + assert "foo4" in content + + # Earlier messages should not be included + assert "foo1" not in content + + +async def test_only_user_and_assistant_messages_in_history(): + conversation: ChatFormat = [ + {"role": "user", "content": "foo4"}, + {"role": "system", "content": "foo1"}, + {"role": "unknown", "content": "foo2"}, + {"role": "assistant", "content": "foo3"}, + {"role": "user", "content": "foo4"}, + {"role": "assistant", "content": "foo5"}, + {"role": "user", "content": "foo6"}, + ] + llm = MockLLM() + compressor = StandaloneMessageCompressor(llm, history_len=4) + await compressor.compress(conversation) + user_prompt = llm.calls[0][1] + assert user_prompt["role"] == "user" + content = user_prompt["content"] 
+ assert "foo4" in content + assert "foo5" in content + assert "foo6" in content + assert "foo3" in content + assert "foo1" not in content + assert "foo2" not in content + + +async def test_changing_prompt(): + conversation: ChatFormat = [ + {"role": "user", "content": "foo1"}, + {"role": "assistant", "content": "foo2"}, + {"role": "user", "content": "foo3"}, + ] + llm = MockLLM() + compressor = StandaloneMessageCompressor(llm, prompt=MockPrompt) + await compressor.compress(conversation) + user_prompt = llm.calls[0][0] + assert user_prompt["role"] == "user" + assert user_prompt["content"] == "mock prompt" diff --git a/packages/ragbits-core/src/ragbits/core/llms/mock.py b/packages/ragbits-core/src/ragbits/core/llms/mock.py new file mode 100644 index 00000000..9a82489f --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/llms/mock.py @@ -0,0 +1,75 @@ +from collections.abc import AsyncGenerator + +from pydantic import BaseModel + +from ragbits.core.options import Options +from ragbits.core.prompt import ChatFormat +from ragbits.core.types import NOT_GIVEN, NotGiven + +from .base import LLM + + +class MockLLMOptions(Options): + """ + Options for the MockLLM class. + """ + + response: str | NotGiven = NOT_GIVEN + response_stream: list[str] | NotGiven = NOT_GIVEN + + +class MockLLM(LLM[MockLLMOptions]): + """ + Class for mocking interactions with LLMs - useful for testing. + """ + + options_cls = MockLLMOptions + + def __init__(self, model_name: str = "mock", default_options: MockLLMOptions | None = None) -> None: + """ + Constructs a new MockLLM instance. + + Args: + model_name: Name of the model to be used. + default_options: Default options to be used. 
+ """ + super().__init__(model_name, default_options=default_options) + self.calls: list[ChatFormat] = [] + + async def _call( # noqa: PLR6301 + self, + conversation: ChatFormat, + options: MockLLMOptions, + json_mode: bool = False, + output_schema: type[BaseModel] | dict | None = None, + ) -> str: + """ + Mocks the call to the LLM, using the response from the options if provided. + """ + self.calls.append(conversation) + if not isinstance(options.response, NotGiven): + return options.response + return "mocked response" + + async def _call_streaming( # noqa: PLR6301 + self, + conversation: ChatFormat, + options: MockLLMOptions, + json_mode: bool = False, + output_schema: type[BaseModel] | dict | None = None, + ) -> AsyncGenerator[str, None]: + """ + Mocks the call to the LLM, using the response from the options if provided. + """ + self.calls.append(conversation) + + async def generator() -> AsyncGenerator[str, None]: + if not isinstance(options.response_stream, NotGiven): + for response in options.response_stream: + yield response + elif not isinstance(options.response, NotGiven): + yield options.response + else: + yield "mocked response" + + return generator() diff --git a/packages/ragbits/pyproject.toml b/packages/ragbits/pyproject.toml index cd32f813..29e436a0 100644 --- a/packages/ragbits/pyproject.toml +++ b/packages/ragbits/pyproject.toml @@ -31,7 +31,14 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Software Development :: Libraries :: Python Modules", ] -dependencies = ["ragbits-document-search==0.6.0", "ragbits-cli==0.6.0", "ragbits-evaluate==0.6.0", "ragbits-guardrails==0.6.0", "ragbits-core==0.6.0"] +dependencies = [ + "ragbits-document-search==0.6.0", + "ragbits-cli==0.6.0", + "ragbits-evaluate==0.6.0", + "ragbits-guardrails==0.6.0", + "ragbits-core==0.6.0", + "ragbits-conversations==0.6.0", +] [project.urls] "Homepage" = "https://github.com/deepsense-ai/ragbits" diff --git a/pyproject.toml b/pyproject.toml 
index 2502b26b..ea1b2263 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ dependencies = [ "ragbits-document-search[gcs,huggingface,distributed]", "ragbits-evaluate[relari]", "ragbits-guardrails[openai]", + "ragbits-conversations", ] [tool.uv] @@ -38,6 +39,7 @@ ragbits-core = { workspace = true } ragbits-document-search = { workspace = true } ragbits-evaluate = {workspace = true} ragbits-guardrails = {workspace = true} +ragbits-conversations = {workspace = true} [tool.uv.workspace] members = [ @@ -46,6 +48,7 @@ members = [ "packages/ragbits-document-search", "packages/ragbits-evaluate", "packages/ragbits-guardrails", + "packages/ragbits-conversations", ] [tool.pytest] @@ -93,6 +96,7 @@ mypy_path = [ "packages/ragbits-document-search/src", "packages/ragbits-evaluate/src", "packages/ragbits-guardrails/src", + "packages/ragbits-conversations/src", ] exclude = ["scripts"] diff --git a/uv.lock b/uv.lock index 299460ca..2424ad2e 100644 --- a/uv.lock +++ b/uv.lock @@ -11,6 +11,7 @@ resolution-markers = [ [manifest] members = [ "ragbits-cli", + "ragbits-conversations", "ragbits-core", "ragbits-document-search", "ragbits-evaluate", @@ -3845,6 +3846,31 @@ requires-dist = [ { name = "typer", specifier = ">=0.12.5" }, ] +[[package]] +name = "ragbits-conversations" +version = "0.6.0" +source = { editable = "packages/ragbits-conversations" } + +[package.dev-dependencies] +dev = [ + { name = "pip-licenses" }, + { name = "pre-commit" }, + { name = "pytest" }, + { name = "pytest-asyncio" }, + { name = "pytest-cov" }, +] + +[package.metadata] + +[package.metadata.requires-dev] +dev = [ + { name = "pip-licenses", specifier = ">=4.0.0,<5.0.0" }, + { name = "pre-commit", specifier = "~=3.8.0" }, + { name = "pytest", specifier = "~=8.3.3" }, + { name = "pytest-asyncio", specifier = "~=0.24.0" }, + { name = "pytest-cov", specifier = "~=5.0.0" }, +] + [[package]] name = "ragbits-core" version = "0.6.0" @@ -4055,6 +4081,7 @@ version = "0.1.0" source = { virtual = "." 
} dependencies = [ { name = "ragbits-cli" }, + { name = "ragbits-conversations" }, { name = "ragbits-core", extra = ["chroma", "lab", "local", "otel", "qdrant"] }, { name = "ragbits-document-search", extra = ["distributed", "gcs", "huggingface"] }, { name = "ragbits-evaluate", extra = ["relari"] }, @@ -4084,6 +4111,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "ragbits-cli", editable = "packages/ragbits-cli" }, + { name = "ragbits-conversations", editable = "packages/ragbits-conversations" }, { name = "ragbits-core", extras = ["chroma", "lab", "local", "otel", "qdrant"], editable = "packages/ragbits-core" }, { name = "ragbits-document-search", extras = ["gcs", "huggingface", "distributed"], editable = "packages/ragbits-document-search" }, { name = "ragbits-evaluate", extras = ["relari"], editable = "packages/ragbits-evaluate" },