Skip to content

Commit

Permalink
feat(prompts): Make the Prompt interface more clear in regard to me…
Browse files Browse the repository at this point in the history
…ssages
  • Loading branch information
ludwiktrammer committed Oct 3, 2024
1 parent 2857978 commit de34ee2
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 43 deletions.
51 changes: 23 additions & 28 deletions packages/ragbits-core/src/ragbits/core/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=

system_prompt: Optional[str] = None
user_prompt: str
additional_messages: Optional[ChatFormat] = None

# Additional messages to be added to the conversation after the system prompt
few_shots: ChatFormat = []

# function that parses the response from the LLM to specific output type
# if not provided, the class tries to set it automatically based on the output type
Expand Down Expand Up @@ -111,10 +113,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
if self.input_type and input_data is None:
raise ValueError("Input data must be provided")

self.system_message = (
self.rendered_system_prompt = (
self._render_template(self.system_prompt_template, input_data) if self.system_prompt_template else None
)
self.user_message = self._render_template(self.user_prompt_template, input_data)
self.rendered_user_prompt = self._render_template(self.user_prompt_template, input_data)

# Additional few shot examples that can be added dynamically using methods
# (in opposite to the static `few_shots` attribute which is defined in the class)
self._instace_few_shots: ChatFormat = []
super().__init__()

@property
Expand All @@ -126,41 +132,30 @@ def chat(self) -> ChatFormat:
ChatFormat: A list of dictionaries, each containing the role and content of a message.
"""
chat = [
*([{"role": "system", "content": self.system_message}] if self.system_message is not None else []),
{"role": "user", "content": self.user_message},
*(
[{"role": "system", "content": self.rendered_system_prompt}]
if self.rendered_system_prompt is not None
else []
),
*self.few_shots,
*self._instace_few_shots,
{"role": "user", "content": self.rendered_user_prompt},
]
if self.additional_messages:
chat.extend(self.additional_messages)
return chat

def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]":
"""
Add a message from the user to the conversation.
Args:
message (str): The message to add.
Returns:
Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining.
"""
if self.additional_messages is None:
self.additional_messages = []
self.additional_messages.append({"role": "user", "content": message})
return self

def add_assistant_message(self, message: str) -> "Prompt[InputT, OutputT]":
def add_few_shot(self, user_message: str, assistant_message: str) -> "Prompt[InputT, OutputT]":
"""
Add a message from the assistant to the conversation.
Add a few-shot example to the conversation.
Args:
message (str): The message to add.
user_message (str): The message from the user.
assistant_message (str): The message from the assistant.
Returns:
Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining.
"""
if self.additional_messages is None:
self.additional_messages = []
self.additional_messages.append({"role": "assistant", "content": message})
self._instace_few_shots.append({"role": "user", "content": user_message})
self._instace_few_shots.append({"role": "assistant", "content": assistant_message})
return self

def output_schema(self) -> Optional[Dict | Type[BaseModel]]:
Expand Down
84 changes: 69 additions & 15 deletions packages/ragbits-core/tests/unit/prompts/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class TestPrompt(Prompt): # pylint: disable=unused-variable
user_prompt = "Hello"

prompt = TestPrompt()
assert prompt.user_message == "Hello"
assert prompt.rendered_user_prompt == "Hello"
assert prompt.chat == [{"role": "user", "content": "Hello"}]


Expand All @@ -121,8 +121,8 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
user_prompt = "Theme for the song is {{ theme }}."

prompt = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock"))
assert prompt.system_message == "You are a song generator for a adult named Alice."
assert prompt.user_message == "Theme for the song is rock."
assert prompt.rendered_system_prompt == "You are a song generator for a adult named Alice."
assert prompt.rendered_user_prompt == "Theme for the song is rock."
assert prompt.chat == [
{"role": "system", "content": "You are a song generator for a adult named Alice."},
{"role": "user", "content": "Theme for the song is rock."},
Expand All @@ -139,8 +139,8 @@ class TestPrompt(Prompt[str, str]): # type: ignore # pylint: disable=unused-var
user_prompt = "Hello"


def test_adding_messages():
"""Test that messages can be added to the conversation."""
def test_defining_few_shots():
"""Test that few shots can be defined for the prompt."""

class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
"""A test prompt"""
Expand All @@ -149,15 +149,68 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}.
"""
user_prompt = "Theme for the song is {{ theme }}."
few_shots = [
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
]

prompt = TestPrompt(_PromptInput(name="John", age=15, theme="pop"))
prompt.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.")
prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock"))

assert prompt.chat == [
{"role": "system", "content": "You are a song generator for a child named John."},
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
{"role": "user", "content": "Theme for the song is rock."},
]


def test_adding_few_shots():
"""Test that few shots can be added to the conversation."""

class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
"""A test prompt"""

system_prompt = """
You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}.
"""
user_prompt = "Theme for the song is {{ theme }}."

prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock"))
prompt.add_few_shot("Theme for the song is pop.", "It's a really catchy tune.")

assert prompt.chat == [
{"role": "system", "content": "You are a song generator for a child named John."},
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
{"role": "user", "content": "I like it."},
{"role": "user", "content": "Theme for the song is rock."},
]


def test_defining_and_adding_few_shots():
"""Test that few shots can be defined and added to the conversation."""

class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
"""A test prompt"""

system_prompt = """
You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}.
"""
user_prompt = "Theme for the song is {{ theme }}."
few_shots = [
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
]

prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock"))
prompt.add_few_shot("Theme for the song is experimental underground jazz.", "It's quite hard to dance to.")

assert prompt.chat == [
{"role": "system", "content": "You are a song generator for a child named John."},
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
{"role": "user", "content": "Theme for the song is experimental underground jazz."},
{"role": "assistant", "content": "It's quite hard to dance to."},
{"role": "user", "content": "Theme for the song is rock."},
]


Expand All @@ -173,7 +226,7 @@ class TestPrompt(Prompt): # pylint: disable=unused-variable
"""

prompt = TestPrompt()
assert prompt.user_message == "Hello\nWorld"
assert prompt.rendered_user_prompt == "Hello\nWorld"


def test_output_format():
Expand Down Expand Up @@ -220,7 +273,7 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
]


def test_two_instances_do_not_share_additional_messages():
def test_two_instances_do_not_share_few_shots():
"""
Test that two instances of a prompt do not share additional messages.
"""
Expand All @@ -234,20 +287,21 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
user_prompt = "Theme for the song is {{ theme }}."

prompt1 = TestPrompt(_PromptInput(name="John", age=15, theme="pop"))
prompt1.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.")
prompt1.add_few_shot("Theme for the song is 80s disco.", "I can't stop dancing.")

prompt2 = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock"))
prompt2.add_assistant_message("It's a nice tune.")
prompt2.add_few_shot("Theme for the song is 90s pop.", "Why do I know all the words?")

assert prompt1.chat == [
{"role": "system", "content": "You are a song generator for a child named John."},
{"role": "user", "content": "Theme for the song is 80s disco."},
{"role": "assistant", "content": "I can't stop dancing."},
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
{"role": "user", "content": "I like it."},
]

assert prompt2.chat == [
{"role": "system", "content": "You are a song generator for a adult named Alice."},
{"role": "user", "content": "Theme for the song is 90s pop."},
{"role": "assistant", "content": "Why do I know all the words?"},
{"role": "user", "content": "Theme for the song is rock."},
{"role": "assistant", "content": "It's a nice tune."},
]

0 comments on commit de34ee2

Please sign in to comment.