Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(prompts): Make the Prompt interface more clear in regard to messages #59

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions packages/ragbits-core/examples/prompt_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class LoremPrompt(Prompt[LoremPromptInput, LoremPromptOutput]):


if __name__ == "__main__":
lorem_prompt = LoremPrompt(LoremPromptInput(theme="business"))
lorem_prompt.add_assistant_message("Lorem Ipsum biznessum dolor copy machinum yearly reportum")
lorem_prompt = LoremPrompt(LoremPromptInput(theme="animals"))
lorem_prompt.add_few_shot("theme: business", "Lorem Ipsum biznessum dolor copy machinum yearly reportum")
print("CHAT:")
print(lorem_prompt.chat)
print()
Expand Down
51 changes: 23 additions & 28 deletions packages/ragbits-core/src/ragbits/core/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=

system_prompt: Optional[str] = None
user_prompt: str
additional_messages: Optional[ChatFormat] = None

# Additional messages to be added to the conversation after the system prompt
few_shots: ChatFormat = []

# function that parses the response from the LLM to specific output type
# if not provided, the class tries to set it automatically based on the output type
Expand Down Expand Up @@ -111,10 +113,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
if self.input_type and input_data is None:
raise ValueError("Input data must be provided")

self.system_message = (
self.rendered_system_prompt = (
self._render_template(self.system_prompt_template, input_data) if self.system_prompt_template else None
)
self.user_message = self._render_template(self.user_prompt_template, input_data)
self.rendered_user_prompt = self._render_template(self.user_prompt_template, input_data)

# Additional few shot examples that can be added dynamically using methods
# (in opposite to the static `few_shots` attribute which is defined in the class)
self._instace_few_shots: ChatFormat = []
super().__init__()

@property
Expand All @@ -126,41 +132,30 @@ def chat(self) -> ChatFormat:
ChatFormat: A list of dictionaries, each containing the role and content of a message.
"""
chat = [
*([{"role": "system", "content": self.system_message}] if self.system_message is not None else []),
{"role": "user", "content": self.user_message},
*(
[{"role": "system", "content": self.rendered_system_prompt}]
if self.rendered_system_prompt is not None
else []
),
*self.few_shots,
*self._instace_few_shots,
{"role": "user", "content": self.rendered_user_prompt},
]
if self.additional_messages:
chat.extend(self.additional_messages)
return chat

def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]":
"""
Add a message from the user to the conversation.

Args:
message (str): The message to add.

Returns:
Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining.
"""
if self.additional_messages is None:
self.additional_messages = []
self.additional_messages.append({"role": "user", "content": message})
return self

def add_assistant_message(self, message: str) -> "Prompt[InputT, OutputT]":
def add_few_shot(self, user_message: str, assistant_message: str) -> "Prompt[InputT, OutputT]":
"""
Add a message from the assistant to the conversation.
Add a few-shot example to the conversation.

Args:
message (str): The message to add.
user_message (str): The message from the user.
assistant_message (str): The message from the assistant.

Returns:
Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining.
"""
if self.additional_messages is None:
self.additional_messages = []
self.additional_messages.append({"role": "assistant", "content": message})
self._instace_few_shots.append({"role": "user", "content": user_message})
self._instace_few_shots.append({"role": "assistant", "content": assistant_message})
return self

def output_schema(self) -> Optional[Dict | Type[BaseModel]]:
Expand Down
84 changes: 69 additions & 15 deletions packages/ragbits-core/tests/unit/prompts/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class TestPrompt(Prompt): # pylint: disable=unused-variable
user_prompt = "Hello"

prompt = TestPrompt()
assert prompt.user_message == "Hello"
assert prompt.rendered_user_prompt == "Hello"
assert prompt.chat == [{"role": "user", "content": "Hello"}]


Expand All @@ -121,8 +121,8 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
user_prompt = "Theme for the song is {{ theme }}."

prompt = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock"))
assert prompt.system_message == "You are a song generator for a adult named Alice."
assert prompt.user_message == "Theme for the song is rock."
assert prompt.rendered_system_prompt == "You are a song generator for a adult named Alice."
assert prompt.rendered_user_prompt == "Theme for the song is rock."
assert prompt.chat == [
{"role": "system", "content": "You are a song generator for a adult named Alice."},
{"role": "user", "content": "Theme for the song is rock."},
Expand All @@ -139,8 +139,8 @@ class TestPrompt(Prompt[str, str]): # type: ignore # pylint: disable=unused-var
user_prompt = "Hello"


def test_adding_messages():
"""Test that messages can be added to the conversation."""
def test_defining_few_shots():
"""Test that few shots can be defined for the prompt."""

class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
"""A test prompt"""
Expand All @@ -149,15 +149,68 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}.
"""
user_prompt = "Theme for the song is {{ theme }}."
few_shots = [
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
]

prompt = TestPrompt(_PromptInput(name="John", age=15, theme="pop"))
prompt.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.")
prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock"))

assert prompt.chat == [
{"role": "system", "content": "You are a song generator for a child named John."},
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
{"role": "user", "content": "Theme for the song is rock."},
]


def test_adding_few_shots():
"""Test that few shots can be added to the conversation."""

class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
"""A test prompt"""

system_prompt = """
You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}.
"""
user_prompt = "Theme for the song is {{ theme }}."

prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock"))
prompt.add_few_shot("Theme for the song is pop.", "It's a really catchy tune.")

assert prompt.chat == [
{"role": "system", "content": "You are a song generator for a child named John."},
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
{"role": "user", "content": "I like it."},
{"role": "user", "content": "Theme for the song is rock."},
]


def test_defining_and_adding_few_shots():
"""Test that few shots can be defined and added to the conversation."""

class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
"""A test prompt"""

system_prompt = """
You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}.
"""
user_prompt = "Theme for the song is {{ theme }}."
few_shots = [
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
]

prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock"))
prompt.add_few_shot("Theme for the song is experimental underground jazz.", "It's quite hard to dance to.")

assert prompt.chat == [
{"role": "system", "content": "You are a song generator for a child named John."},
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
{"role": "user", "content": "Theme for the song is experimental underground jazz."},
{"role": "assistant", "content": "It's quite hard to dance to."},
{"role": "user", "content": "Theme for the song is rock."},
]


Expand All @@ -173,7 +226,7 @@ class TestPrompt(Prompt): # pylint: disable=unused-variable
"""

prompt = TestPrompt()
assert prompt.user_message == "Hello\nWorld"
assert prompt.rendered_user_prompt == "Hello\nWorld"


def test_output_format():
Expand Down Expand Up @@ -220,7 +273,7 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
]


def test_two_instances_do_not_share_additional_messages():
def test_two_instances_do_not_share_few_shots():
"""
Test that two instances of a prompt do not share additional messages.
"""
Expand All @@ -234,20 +287,21 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
user_prompt = "Theme for the song is {{ theme }}."

prompt1 = TestPrompt(_PromptInput(name="John", age=15, theme="pop"))
prompt1.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.")
prompt1.add_few_shot("Theme for the song is 80s disco.", "I can't stop dancing.")

prompt2 = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock"))
prompt2.add_assistant_message("It's a nice tune.")
prompt2.add_few_shot("Theme for the song is 90s pop.", "Why do I know all the words?")

assert prompt1.chat == [
{"role": "system", "content": "You are a song generator for a child named John."},
{"role": "user", "content": "Theme for the song is 80s disco."},
{"role": "assistant", "content": "I can't stop dancing."},
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
{"role": "user", "content": "I like it."},
]

assert prompt2.chat == [
{"role": "system", "content": "You are a song generator for a adult named Alice."},
{"role": "user", "content": "Theme for the song is 90s pop."},
{"role": "assistant", "content": "Why do I know all the words?"},
{"role": "user", "content": "Theme for the song is rock."},
{"role": "assistant", "content": "It's a nice tune."},
]
Loading