diff --git a/packages/ragbits-core/src/ragbits/core/prompt/prompt.py b/packages/ragbits-core/src/ragbits/core/prompt/prompt.py index 443fb280..75b3c093 100644 --- a/packages/ragbits-core/src/ragbits/core/prompt/prompt.py +++ b/packages/ragbits-core/src/ragbits/core/prompt/prompt.py @@ -22,7 +22,7 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass= system_prompt: Optional[str] = None user_prompt: str - additional_messages: ChatFormat = [] + additional_messages: Optional[ChatFormat] = None # function that parses the response from the LLM to specific output type # if not provided, the class tries to set it automatically based on the output type @@ -125,10 +125,13 @@ def chat(self) -> ChatFormat: Returns: ChatFormat: A list of dictionaries, each containing the role and content of a message. """ - return [ + chat = [ *([{"role": "system", "content": self.system_message}] if self.system_message is not None else []), {"role": "user", "content": self.user_message}, - ] + self.additional_messages + ] + if self.additional_messages: + chat.extend(self.additional_messages) + return chat def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]": """ @@ -140,6 +143,8 @@ def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]": Returns: Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining. """ + if self.additional_messages is None: + self.additional_messages = [] self.additional_messages.append({"role": "user", "content": message}) return self @@ -153,6 +158,8 @@ def add_assistant_message(self, message: str) -> "Prompt[InputT, OutputT]": Returns: Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining. """ + if self.additional_messages is None: + self.additional_messages = [] self.additional_messages.append({"role": "assistant", "content": message}) return self diff --git a/packages/ragbits-core/tests/unit/prompts/test_prompt.py b/packages/ragbits-core/tests/unit/prompts/test_prompt.py index f0e00ebf..557849f0 100644 --- a/packages/ragbits-core/tests/unit/prompts/test_prompt.py +++ b/packages/ragbits-core/tests/unit/prompts/test_prompt.py @@ -218,3 +218,36 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable {"role": "system", "content": "You are a song generator for a adult named John."}, {"role": "user", "content": "Theme for the song is pop."}, ] + + +def test_two_instances_do_not_share_additional_messages(): + """ + Test that two instances of a prompt do not share additional messages. + """ + + class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable + """A test prompt""" + + system_prompt = """ + You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}. + """ + user_prompt = "Theme for the song is {{ theme }}." + + prompt1 = TestPrompt(_PromptInput(name="John", age=15, theme="pop")) + prompt1.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.") + + prompt2 = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock")) + prompt2.add_assistant_message("It's a nice tune.") + + assert prompt1.chat == [ + {"role": "system", "content": "You are a song generator for a child named John."}, + {"role": "user", "content": "Theme for the song is pop."}, + {"role": "assistant", "content": "It's a really catchy tune."}, + {"role": "user", "content": "I like it."}, + ] + + assert prompt2.chat == [ + {"role": "system", "content": "You are a song generator for a adult named Alice."}, + {"role": "user", "content": "Theme for the song is rock."}, + {"role": "assistant", "content": "It's a nice tune."}, + ]