diff --git a/packages/ragbits-core/examples/prompt_example.py b/packages/ragbits-core/examples/prompt_example.py index 51ce05ea..64c37c36 100644 --- a/packages/ragbits-core/examples/prompt_example.py +++ b/packages/ragbits-core/examples/prompt_example.py @@ -43,8 +43,8 @@ class LoremPrompt(Prompt[LoremPromptInput, LoremPromptOutput]): if __name__ == "__main__": - lorem_prompt = LoremPrompt(LoremPromptInput(theme="business")) - lorem_prompt.add_assistant_message("Lorem Ipsum biznessum dolor copy machinum yearly reportum") + lorem_prompt = LoremPrompt(LoremPromptInput(theme="animals")) + lorem_prompt.add_few_shot("theme: business", "Lorem Ipsum biznessum dolor copy machinum yearly reportum") print("CHAT:") print(lorem_prompt.chat) print() diff --git a/packages/ragbits-core/src/ragbits/core/prompt/prompt.py b/packages/ragbits-core/src/ragbits/core/prompt/prompt.py index 75b3c093..876a68a4 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/prompt.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/prompt.py @@ -22,7 +22,9 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass= system_prompt: Optional[str] = None user_prompt: str - additional_messages: Optional[ChatFormat] = None + + # Additional messages to be added to the conversation after the system prompt + few_shots: ChatFormat = [] # function that parses the response from the LLM to specific output type # if not provided, the class tries to set it automatically based on the output type @@ -111,10 +113,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: if self.input_type and input_data is None: raise ValueError("Input data must be provided") - self.system_message = ( + self.rendered_system_prompt = ( self._render_template(self.system_prompt_template, input_data) if self.system_prompt_template else None ) - self.user_message = self._render_template(self.user_prompt_template, input_data) + self.rendered_user_prompt = self._render_template(self.user_prompt_template, input_data) + + # Additional few shot examples that can be added dynamically using methods + # (in opposite to the static `few_shots` attribute which is defined in the class) + self._instace_few_shots: ChatFormat = [] super().__init__() @property @@ -126,41 +132,30 @@ def chat(self) -> ChatFormat: ChatFormat: A list of dictionaries, each containing the role and content of a message. """ chat = [ - *([{"role": "system", "content": self.system_message}] if self.system_message is not None else []), - {"role": "user", "content": self.user_message}, + *( + [{"role": "system", "content": self.rendered_system_prompt}] + if self.rendered_system_prompt is not None + else [] + ), + *self.few_shots, + *self._instace_few_shots, + {"role": "user", "content": self.rendered_user_prompt}, ] - if self.additional_messages: - chat.extend(self.additional_messages) return chat - def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]": - """ - Add a message from the user to the conversation. - - Args: - message (str): The message to add. - - Returns: - Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining. - """ - if self.additional_messages is None: - self.additional_messages = [] - self.additional_messages.append({"role": "user", "content": message}) - return self - - def add_assistant_message(self, message: str) -> "Prompt[InputT, OutputT]": + def add_few_shot(self, user_message: str, assistant_message: str) -> "Prompt[InputT, OutputT]": """ - Add a message from the assistant to the conversation. + Add a few-shot example to the conversation. Args: - message (str): The message to add. + user_message (str): The message from the user. + assistant_message (str): The message from the assistant. Returns: Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining. """ - if self.additional_messages is None: - self.additional_messages = [] - self.additional_messages.append({"role": "assistant", "content": message}) + self._instace_few_shots.append({"role": "user", "content": user_message}) + self._instace_few_shots.append({"role": "assistant", "content": assistant_message}) return self def output_schema(self) -> Optional[Dict | Type[BaseModel]]: diff --git a/packages/ragbits-core/tests/unit/prompts/test_prompt.py b/packages/ragbits-core/tests/unit/prompts/test_prompt.py index 557849f0..3af149bc 100644 --- a/packages/ragbits-core/tests/unit/prompts/test_prompt.py +++ b/packages/ragbits-core/tests/unit/prompts/test_prompt.py @@ -105,7 +105,7 @@ class TestPrompt(Prompt): # pylint: disable=unused-variable user_prompt = "Hello" prompt = TestPrompt() - assert prompt.user_message == "Hello" + assert prompt.rendered_user_prompt == "Hello" assert prompt.chat == [{"role": "user", "content": "Hello"}] @@ -121,8 +121,8 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable user_prompt = "Theme for the song is {{ theme }}." prompt = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock")) - assert prompt.system_message == "You are a song generator for a adult named Alice." - assert prompt.user_message == "Theme for the song is rock." + assert prompt.rendered_system_prompt == "You are a song generator for a adult named Alice." + assert prompt.rendered_user_prompt == "Theme for the song is rock." assert prompt.chat == [ {"role": "system", "content": "You are a song generator for a adult named Alice."}, {"role": "user", "content": "Theme for the song is rock."}, @@ -139,8 +139,8 @@ class TestPrompt(Prompt[str, str]): # type: ignore # pylint: disable=unused-var user_prompt = "Hello" -def test_adding_messages(): - """Test that messages can be added to the conversation.""" +def test_defining_few_shots(): + """Test that few shots can be defined for the prompt.""" class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable """A test prompt""" @@ -149,15 +149,68 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}. """ user_prompt = "Theme for the song is {{ theme }}." + few_shots = [ + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + ] - prompt = TestPrompt(_PromptInput(name="John", age=15, theme="pop")) - prompt.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.") + prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock")) + + assert prompt.chat == [ + {"role": "system", "content": "You are a song generator for a child named John."}, + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + {"role": "user", "content": "Theme for the song is rock."}, + ] + + +def test_adding_few_shots(): + """Test that few shots can be added to the conversation.""" + + class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable + """A test prompt""" + + system_prompt = """ + You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}. + """ + user_prompt = "Theme for the song is {{ theme }}." + + prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock")) + prompt.add_few_shot("Theme for the song is pop.", "It's a really catchy tune.") assert prompt.chat == [ {"role": "system", "content": "You are a song generator for a child named John."}, {"role": "user", "content": "Theme for the song is pop."}, {"role": "assistant", "content": "It's a really catchy tune."}, - {"role": "user", "content": "I like it."}, + {"role": "user", "content": "Theme for the song is rock."}, + ] + + +def test_defining_and_adding_few_shots(): + """Test that few shots can be defined and added to the conversation.""" + + class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable + """A test prompt""" + + system_prompt = """ + You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}. + """ + user_prompt = "Theme for the song is {{ theme }}." + few_shots = [ + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + ] + + prompt = TestPrompt(_PromptInput(name="John", age=15, theme="rock")) + prompt.add_few_shot("Theme for the song is experimental underground jazz.", "It's quite hard to dance to.") + + assert prompt.chat == [ + {"role": "system", "content": "You are a song generator for a child named John."}, + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + {"role": "user", "content": "Theme for the song is experimental underground jazz."}, + {"role": "assistant", "content": "It's quite hard to dance to."}, + {"role": "user", "content": "Theme for the song is rock."}, ] @@ -173,7 +226,7 @@ class TestPrompt(Prompt): # pylint: disable=unused-variable """ prompt = TestPrompt() - assert prompt.user_message == "Hello\nWorld" + assert prompt.rendered_user_prompt == "Hello\nWorld" def test_output_format(): @@ -220,7 +273,7 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable ] -def test_two_instances_do_not_share_additional_messages(): +def test_two_instances_do_not_share_few_shots(): """ Test that two instances of a prompt do not share additional messages. """ @@ -234,20 +287,21 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable user_prompt = "Theme for the song is {{ theme }}." prompt1 = TestPrompt(_PromptInput(name="John", age=15, theme="pop")) - prompt1.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.") + prompt1.add_few_shot("Theme for the song is 80s disco.", "I can't stop dancing.") prompt2 = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock")) - prompt2.add_assistant_message("It's a nice tune.") + prompt2.add_few_shot("Theme for the song is 90s pop.", "Why do I know all the words?") assert prompt1.chat == [ {"role": "system", "content": "You are a song generator for a child named John."}, + {"role": "user", "content": "Theme for the song is 80s disco."}, + {"role": "assistant", "content": "I can't stop dancing."}, {"role": "user", "content": "Theme for the song is pop."}, - {"role": "assistant", "content": "It's a really catchy tune."}, - {"role": "user", "content": "I like it."}, ] assert prompt2.chat == [ {"role": "system", "content": "You are a song generator for a adult named Alice."}, + {"role": "user", "content": "Theme for the song is 90s pop."}, + {"role": "assistant", "content": "Why do I know all the words?"}, {"role": "user", "content": "Theme for the song is rock."}, - {"role": "assistant", "content": "It's a nice tune."}, ]