Skip to content

Commit

Permalink
Fix list as default for additional_messages var
Browse files Browse the repository at this point in the history
  • Loading branch information
akonarski-ds committed Oct 2, 2024
1 parent 57413e9 commit 5419630
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 3 deletions.
13 changes: 10 additions & 3 deletions packages/ragbits-core/src/ragbits/core/prompt/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Prompt(Generic[InputT, OutputT], BasePromptWithParser[OutputT], metaclass=

system_prompt: Optional[str] = None
user_prompt: str
additional_messages: ChatFormat = []
additional_messages: Optional[ChatFormat] = None

# function that parses the response from the LLM to specific output type
# if not provided, the class tries to set it automatically based on the output type
Expand Down Expand Up @@ -125,10 +125,13 @@ def chat(self) -> ChatFormat:
Returns:
ChatFormat: A list of dictionaries, each containing the role and content of a message.
"""
return [
chat = [
*([{"role": "system", "content": self.system_message}] if self.system_message is not None else []),
{"role": "user", "content": self.user_message},
] + self.additional_messages
]
if self.additional_messages:
chat.extend(self.additional_messages)
return chat

def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]":
"""
Expand All @@ -140,6 +143,8 @@ def add_user_message(self, message: str) -> "Prompt[InputT, OutputT]":
Returns:
Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining.
"""
if self.additional_messages is None:
self.additional_messages = []
self.additional_messages.append({"role": "user", "content": message})
return self

Expand All @@ -153,6 +158,8 @@ def add_assistant_message(self, message: str) -> "Prompt[InputT, OutputT]":
Returns:
Prompt[InputT, OutputT]: The current prompt instance in order to allow chaining.
"""
if self.additional_messages is None:
self.additional_messages = []
self.additional_messages.append({"role": "assistant", "content": message})
return self

Expand Down
33 changes: 33 additions & 0 deletions packages/ragbits-core/tests/unit/prompts/test_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,3 +218,36 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
{"role": "system", "content": "You are a song generator for a adult named John."},
{"role": "user", "content": "Theme for the song is pop."},
]


def test_two_instances_do_not_share_additional_messages():
"""
Test that two instances of a prompt do not share additional messages.
"""

class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
"""A test prompt"""

system_prompt = """
You are a song generator for a {% if age > 18 %}adult{% else %}child{% endif %} named {{ name }}.
"""
user_prompt = "Theme for the song is {{ theme }}."

prompt1 = TestPrompt(_PromptInput(name="John", age=15, theme="pop"))
prompt1.add_assistant_message("It's a really catchy tune.").add_user_message("I like it.")

prompt2 = TestPrompt(_PromptInput(name="Alice", age=30, theme="rock"))
prompt2.add_assistant_message("It's a nice tune.")

assert prompt1.chat == [
{"role": "system", "content": "You are a song generator for a child named John."},
{"role": "user", "content": "Theme for the song is pop."},
{"role": "assistant", "content": "It's a really catchy tune."},
{"role": "user", "content": "I like it."},
]

assert prompt2.chat == [
{"role": "system", "content": "You are a song generator for a adult named Alice."},
{"role": "user", "content": "Theme for the song is rock."},
{"role": "assistant", "content": "It's a nice tune."},
]

0 comments on commit 5419630

Please sign in to comment.