feat(conversations): add last message recontextualizer #271

Merged · 7 commits · Jan 13, 2025
52 changes: 52 additions & 0 deletions examples/conversations/recontextualize_message.py
@@ -0,0 +1,52 @@
"""
Ragbits Conversations Example: Recontextualize Last Message

This example demonstrates how to use the `StandaloneMessageCompressor` compressor to recontextualize
the last message in a conversation history.
"""

# /// script
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-conversations",
# ]
# ///

import asyncio

from ragbits.conversations.history.compressors.llm import StandaloneMessageCompressor
from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.prompt import ChatFormat

# Example conversation history
conversation: ChatFormat = [
{"role": "user", "content": "Who's working on Friday?"},
{"role": "assistant", "content": "Jim"},
{"role": "user", "content": "Where is he based?"},
{"role": "assistant", "content": "At our California Head Office"},
{"role": "user", "content": "Is he a senior staff member?"},
{"role": "assistant", "content": "Yes, he's a senior manager"},
{"role": "user", "content": "What's his phone number (including the prefix for his state)?"},
]


async def main() -> None:
"""
Main function to demonstrate the StandaloneMessageCompressor compressor.
"""
# Initialize the LiteLLM client
llm = LiteLLM("gpt-4o")

# Initialize the StandaloneMessageCompressor compressor
compressor = StandaloneMessageCompressor(llm, history_len=10)

# Compress the conversation history
recontextualized_message = await compressor.compress(conversation)

# Print the recontextualized message
print("Recontextualized Message:")
print(recontextualized_message)


if __name__ == "__main__":
asyncio.run(main())
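
With the conversation above, the compressor is expected to fold the pronoun and the state reference into the question itself. The exact wording depends on the model, but the printed output should look roughly like this (illustrative only):

# Recontextualized Message:
# What is Jim's phone number, including the prefix for California?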
1 change: 1 addition & 0 deletions packages/ragbits-conversations/README.md
@@ -0,0 +1 @@
# Ragbits Conversations
63 changes: 63 additions & 0 deletions packages/ragbits-conversations/pyproject.toml
@@ -0,0 +1,63 @@
[project]
name = "ragbits-conversations"
version = "0.6.0"
description = "Building blocks for rapid development of GenAI applications"
readme = "README.md"
requires-python = ">=3.10"
license = "MIT"
authors = [
{ name = "deepsense.ai", email = "[email protected]"}
]
keywords = [
"Retrieval Augmented Generation",
"RAG",
"Large Language Models",
"LLMs",
"Generative AI",
"GenAI",
"Prompt Management"
]
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Console",
"Intended Audience :: Science/Research",
"License :: OSI Approved :: MIT License",
"Natural Language :: English",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules",
]
dependencies = []

[project.urls]
"Homepage" = "https://github.com/deepsense-ai/ragbits"
"Bug Reports" = "https://github.com/deepsense-ai/ragbits/issues"
"Documentation" = "https://ragbits.deepsense.ai/"
"Source" = "https://github.com/deepsense-ai/ragbits"

[project.optional-dependencies]
[tool.uv]
dev-dependencies = [
"pre-commit~=3.8.0",
"pytest~=8.3.3",
"pytest-cov~=5.0.0",
"pytest-asyncio~=0.24.0",
"pip-licenses>=4.0.0,<5.0.0"
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["src/ragbits"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
Empty file.
4 changes: 4 additions & 0 deletions packages/ragbits-conversations/src/ragbits/conversations/history/compressors/__init__.py
@@ -0,0 +1,4 @@
from .base import ConversationHistoryCompressor
from .llm import StandaloneMessageCompressor

__all__ = ["ConversationHistoryCompressor", "StandaloneMessageCompressor"]
31 changes: 31 additions & 0 deletions packages/ragbits-conversations/src/ragbits/conversations/history/compressors/base.py
@@ -0,0 +1,31 @@
from abc import ABC, abstractmethod
from typing import ClassVar

from ragbits.conversations.history import compressors
from ragbits.core.prompt.base import ChatFormat
from ragbits.core.utils.config_handling import WithConstructionConfig


class ConversationHistoryCompressor(WithConstructionConfig, ABC):
    """
    An abstract class for conversation history compressors,
    i.e. a class that takes the entire conversation history
    and returns a single string representation of it.

    The exact logic of what the string should include and represent
    depends on the specific implementation.

    Usually used to provide the LLM with additional context from the conversation history.
    """

    default_module: ClassVar = compressors
    configuration_key: ClassVar = "history_compressor"

    @abstractmethod
    async def compress(self, conversation: ChatFormat) -> str:
        """
        Compresses the conversation history into a single string.

        Args:
            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
        """
86 changes: 86 additions & 0 deletions packages/ragbits-conversations/src/ragbits/conversations/history/compressors/llm.py
@@ -0,0 +1,86 @@
from pydantic import BaseModel

from ragbits.conversations.history.compressors import ConversationHistoryCompressor
from ragbits.core.llms.base import LLM
from ragbits.core.prompt import ChatFormat, Prompt


class LastMessageAndHistory(BaseModel):
    """
    A model representing the last message and the history of preceding messages.
    """

    last_message: str
    history: list[str]


class StandaloneMessageCompressorPrompt(Prompt[LastMessageAndHistory, str]):
    """
    A prompt for recontextualizing the last message in the history.
    """

    system_prompt = """
Given a new message and a history of the conversation, create a standalone version of the message.
If the message references any context from history, it should be added to the message itself.
Return only the recontextualized message.
Do NOT return the history, do NOT answer the question, and do NOT add context irrelevant to the message.
"""

    user_prompt = """
Message:
{{ last_message }}

History:
{% for message in history %}
* {{ message }}
{% endfor %}
"""


class StandaloneMessageCompressor(ConversationHistoryCompressor):
    """
    A compressor that uses an LLM to recontextualize the last message in the history,
    i.e. to create a standalone version of the message that includes the necessary context.
    """

    def __init__(self, llm: LLM, history_len: int = 5, prompt: type[Prompt[LastMessageAndHistory, str]] | None = None):
        """
        Initialize the StandaloneMessageCompressor with an LLM.

        Args:
            llm: An LLM instance to handle recontextualizing the last message.
            history_len: The number of previous messages to include in the history.
            prompt: The prompt to use for recontextualizing the last message.
        """
        self._llm = llm
        self._history_len = history_len
        self._prompt = prompt or StandaloneMessageCompressorPrompt

    async def compress(self, conversation: ChatFormat) -> str:
        """
        Recontextualize the last message in the conversation history.

        Args:
            conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
                The most recent message should be from the user.
        """
        if len(conversation) == 0:
            raise ValueError("Conversation history is empty.")

        last_message = conversation[-1]
        if last_message["role"] != "user":
            raise ValueError("StandaloneMessageCompressor expects the last message to be from the user.")

        # Only include "user" and "assistant" messages in the history
        other_messages = [message for message in conversation[:-1] if message["role"] in ["user", "assistant"]]

        if not other_messages:
            # No history to use for recontextualization, simply return the user message
            return last_message["content"]

        history = [f"{message['role']}: {message['content']}" for message in other_messages[-self._history_len :]]

        input_data = LastMessageAndHistory(last_message=last_message["content"], history=history)
        prompt = self._prompt(input_data)
        return await self._llm.generate(prompt)
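
Because the constructor accepts any `Prompt[LastMessageAndHistory, str]` subclass, the wording can be swapped without touching the compressor. A minimal sketch, assuming the imports above and an existing `llm` instance (`TersePrompt` is a hypothetical name, not part of this PR):

class TersePrompt(Prompt[LastMessageAndHistory, str]):
    """Hypothetical alternative prompt with stricter output instructions."""

    system_prompt = """
Rewrite the message so it can stand alone without the history.
Output only the rewritten message.
"""

    user_prompt = """
Message: {{ last_message }}
History:
{% for message in history %}
* {{ message }}
{% endfor %}
"""


compressor = StandaloneMessageCompressor(llm, prompt=TersePrompt)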
@@ -0,0 +1,111 @@
import pytest

from ragbits.conversations.history.compressors.llm import LastMessageAndHistory, StandaloneMessageCompressor
from ragbits.core.llms.mock import MockLLM, MockLLMOptions
from ragbits.core.prompt import ChatFormat
from ragbits.core.prompt.prompt import Prompt


class MockPrompt(Prompt[LastMessageAndHistory, str]):
    user_prompt = "mock prompt"


async def test_messages_included():
    conversation: ChatFormat = [
        {"role": "user", "content": "foo1"},
        {"role": "assistant", "content": "foo2"},
        {"role": "user", "content": "foo3"},
    ]
    llm = MockLLM(default_options=MockLLMOptions(response="some answer"))
    compressor = StandaloneMessageCompressor(llm)
    answer = await compressor.compress(conversation)
    assert answer == "some answer"
    user_prompt = llm.calls[0][1]
    assert user_prompt["role"] == "user"
    content = user_prompt["content"]
    assert "foo1" in content
    assert "foo2" in content
    assert "foo3" in content


async def test_no_messages():
    conversation: ChatFormat = []
    compressor = StandaloneMessageCompressor(MockLLM())

    with pytest.raises(ValueError):
        await compressor.compress(conversation)


async def test_last_message_not_user():
    conversation: ChatFormat = [
        {"role": "assistant", "content": "foo2"},
    ]
    compressor = StandaloneMessageCompressor(MockLLM())

    with pytest.raises(ValueError):
        await compressor.compress(conversation)


async def test_history_len():
    conversation: ChatFormat = [
        {"role": "user", "content": "foo1"},
        {"role": "assistant", "content": "foo2"},
        {"role": "user", "content": "foo3"},
        {"role": "user", "content": "foo4"},
        {"role": "user", "content": "foo5"},
    ]
    llm = MockLLM()
    compressor = StandaloneMessageCompressor(llm, history_len=3)
    await compressor.compress(conversation)
    user_prompt = llm.calls[0][1]
    assert user_prompt["role"] == "user"
    content = user_prompt["content"]

    # The rephrased message should be included
    assert "foo5" in content

    # Three previous messages should be included
    assert "foo2" in content
    assert "foo3" in content
    assert "foo4" in content

    # Earlier messages should not be included
    assert "foo1" not in content


async def test_only_user_and_assistant_messages_in_history():
    conversation: ChatFormat = [
        {"role": "user", "content": "foo4"},
        {"role": "system", "content": "foo1"},
        {"role": "unknown", "content": "foo2"},
        {"role": "assistant", "content": "foo3"},
        {"role": "user", "content": "foo4"},
        {"role": "assistant", "content": "foo5"},
        {"role": "user", "content": "foo6"},
    ]
    llm = MockLLM()
    compressor = StandaloneMessageCompressor(llm, history_len=4)
    await compressor.compress(conversation)
    user_prompt = llm.calls[0][1]
    assert user_prompt["role"] == "user"
    content = user_prompt["content"]
    assert "foo4" in content
    assert "foo5" in content
    assert "foo6" in content
    assert "foo3" in content
    assert "foo1" not in content
    assert "foo2" not in content


async def test_changing_prompt():
    conversation: ChatFormat = [
        {"role": "user", "content": "foo1"},
        {"role": "assistant", "content": "foo2"},
        {"role": "user", "content": "foo3"},
    ]
    llm = MockLLM()
    compressor = StandaloneMessageCompressor(llm, prompt=MockPrompt)
    await compressor.compress(conversation)
    user_prompt = llm.calls[0][0]
    assert user_prompt["role"] == "user"
    assert user_prompt["content"] == "mock prompt"
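
One branch the suite leaves uncovered is the early return taken when the conversation holds only the user message. A test along these lines (the name is illustrative) would pin that behavior down:

async def test_no_history_returns_message_unchanged():
    conversation: ChatFormat = [{"role": "user", "content": "foo1"}]
    compressor = StandaloneMessageCompressor(MockLLM())
    # With no prior messages there is nothing to recontextualize,
    # so the message comes back verbatim and the LLM is never called.
    assert await compressor.compress(conversation) == "foo1"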