diff --git a/docs/reference/chat.md b/docs/reference/chat.md
new file mode 100644
index 0000000..0b1b193
--- /dev/null
+++ b/docs/reference/chat.md
@@ -0,0 +1,20 @@
+# Chat history
+
+## Filter message
+
+In some situations you may want to filter the messages before building the prompt, for instance to use RAG. In this case you can subclass `Chat` and override the `filter` method:
+
+
+```python
+from prompts import Chat
+
+class RAGChat(Chat):
+
+    def filter(self):
+        filtered_messages = []
+        for message in self.history:
+            if message.role == "user" and "Hi" in message.content:
+                filtered_messages.append(message)
+
+        return filtered_messages
+```
diff --git a/mkdocs.yml b/mkdocs.yml
index 9d1f2b8..4c788a2 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -76,3 +76,4 @@ nav:
       - Prompt template: reference/template.md
       - Dispatch: reference/dispatch.md
      - Special tokens: reference/special_tokens.md
+      - Chat History: reference/chat.md
diff --git a/prompts/chat.py b/prompts/chat.py
new file mode 100644
index 0000000..e601432
--- /dev/null
+++ b/prompts/chat.py
@@ -0,0 +1,86 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import List, Optional
+
+from pydantic import BaseModel
+from typing_extensions import TypedDict
+
+
+class Document(TypedDict):
+    title: str
+    text: str
+
+
+class Role(Enum):
+    system = "system"
+    user = "user"
+    assistant = "assistant"
+
+
+@dataclass
+class Message:
+    role: Role
+    content: str
+
+
+class Chat:
+    def __init__(
+        self,
+        system_msg: Optional[str] = None,
+        tools: Optional[List[BaseModel]] = None,
+        documents: Optional[List[Document]] = None,
+        history: Optional[List[Message]] = None,
+    ):
+        self.history = history if history is not None else []
+        self.system = system_msg
+        self.tools = tools
+        self.documents = documents
+
+    @property
+    def trimmed_history(self):
+        return self.history
+
+    def __add__(self, other: Message):
+        history = self.history.copy()
+        history.append(other)
+        return Chat(self.system, self.tools, self.documents, history=history)
+
+    def __radd__(self, other: Message):
+        history = self.history.copy()
+        history.append(other)
+        return Chat(self.system, self.tools, self.documents, history=history)
+
+    def __iadd__(self, other: Message):
+        self.history.append(other)
+        return self
+
+    def __getitem__(self, key):
+        if isinstance(key, int):
+            return self.history[key]
+        else:
+            raise KeyError()
+
+    def render(self, model_name: str):
+        """Render the conversation using the model's chat template.
+
+        TODO: Do this ourselves.
+
+        Parameters
+        ----------
+        model_name
+            The name of the model whose chat template we need to use.
+
+        """
+        from transformers import AutoTokenizer
+
+        conversation = []
+        if self.system is not None:
+            conversation.append({"role": "system", "content": self.system})
+        for message in self.trimmed_history:
+            conversation.append({"role": message.role, "content": message.content})
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        return self.tokenizer.apply_chat_template(
+            conversation, self.tools, self.documents
+        )
diff --git a/prompts/tokens.py b/prompts/tokens.py
new file mode 100644
index 0000000..953ad9f
--- /dev/null
+++ b/prompts/tokens.py
@@ -0,0 +1,29 @@
+from dataclasses import dataclass, field
+from typing import Dict, Optional
+
+
+@dataclass
+class Limits:
+    begin: str = ""
+    end: str = ""
+
+
+@dataclass
+class Special:
+    sequence: Limits = field(default_factory=Limits)
+    user: Limits = field(default_factory=Limits)
+    assistant: Limits = field(default_factory=Limits)
+    system: Limits = field(default_factory=Limits)
+
+
+SPECIAL_TOKENS: Dict[Optional[str], Special] = {
+    None: Special(),
+    "google/gemma-2-9b": Special(Limits("<bos>", "<eos>")),
+    "openai-community/gpt2": Special(Limits("", "<|endoftext|>")),
+    "mistralai/Mistral-7B-v0.1": Special(Limits("<s>", "</s>")),
+    "mistralai/Mistral-7B-Instruct-v0.1": Special(
+        Limits("<s>", "</s>"),
+        Limits("[INST]", "[/INST]"),
+        Limits("", "</s>"),
+    ),
+}
diff --git a/pyproject.toml b/pyproject.toml
index 326a183..b72372c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ version = "0.1.0"
 description = "Large Language Models prompting library"
 authors = [{name = "The Outlines developers", email = "contact@dottxt.co"}]
 requires-python = ">= 3.8"
-dependencies = ["jinja2"]
+dependencies = ["jinja2", "pydantic", "transformers"]
 
 [build-system]
 requires = ["setuptools>=45", "setuptools_scm[toml]>=6.2"]
@@ -35,5 +35,5 @@ file="README.md"
 content-type = "text/markdown"
 
 [[tool.mypy.overrides]]
-module = ["jinja2", "pytest"]
+module = ["jinja2", "pydantic", "pytest", "transformers"]
 ignore_missing_imports = true
diff --git a/tests/test_chat.py b/tests/test_chat.py
new file mode 100644
index 0000000..295dc30
--- /dev/null
+++ b/tests/test_chat.py
@@ -0,0 +1,7 @@
+from prompts.chat import Chat, Message
+
+
+def test_simple():
+    chat = Chat("system message")
+    new_chat = chat + Message("user", "new user message")
+    new_chat += Message("assistant", "new assistant message")