
Commit

feat(chat): hacking on some RAG chat, playing with TUI (#61)
skyl authored Nov 21, 2024
1 parent 1010c6a commit ab7b60b
Showing 16 changed files with 497 additions and 22 deletions.
2 changes: 2 additions & 0 deletions docker-compose.yaml
@@ -9,6 +9,7 @@ services:
PYTHONPATH: "/workspace/py/packages"
REDIS_URL: "redis://corpora-redis:6379/0"
OPENAI_API_KEY: "${OPENAI_API_KEY}"
+ OPENAI_AZURE_ENDPOINT: "${OPENAI_AZURE_ENDPOINT}"
command: python manage.py runserver 0.0.0.0:8877
working_dir: /workspace/py/packages/corpora_proj
depends_on:
@@ -31,6 +32,7 @@ services:
PYTHONPATH: "/workspace/py/packages"
REDIS_URL: "redis://corpora-redis:6379/0"
OPENAI_API_KEY: "${OPENAI_API_KEY}"
+ OPENAI_AZURE_ENDPOINT: "${OPENAI_AZURE_ENDPOINT}"
depends_on:
corpora-redis:
condition: service_healthy
3 changes: 3 additions & 0 deletions py/packages/corpora/models.py
@@ -57,6 +57,9 @@ def get_relevant_splits(self, text: str, limit: int = 10):
from corpora_ai.provider_loader import load_llm_provider

llm = load_llm_provider()
+ # better_text = llm.get_synthetic_embedding_text(text)
+ # print(f"better_text: {better_text}")
+ # vector = llm.get_embedding(better_text)
vector = llm.get_embedding(text)
return (
Split.objects.filter(
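The hunk above is truncated before the ranking clause. A minimal sketch of how the full method plausibly reads, assuming pgvector's Django helpers for cosine ordering; the join path (file__corpus) and the vector field name are guesses, not confirmed by this diff:

# Sketch only: reconstructed around the visible lines; field names are assumptions.
from pgvector.django import CosineDistance

def get_relevant_splits(self, text: str, limit: int = 10):
    from corpora_ai.provider_loader import load_llm_provider

    llm = load_llm_provider()
    vector = llm.get_embedding(text)
    return (
        Split.objects.filter(file__corpus=self)  # assumed join path
        .exclude(vector=None)                    # assumed vector field name
        .order_by(CosineDistance("vector", vector))[:limit]
    )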
6 changes: 2 additions & 4 deletions py/packages/corpora/routers/corpus.py
@@ -14,6 +14,7 @@
from corpora.schema.chat import CorpusChatSchema, get_additional_context
from corpora_ai.llm_interface import ChatCompletionTextMessage
from corpora_ai.provider_loader import load_llm_provider
+ from corpora_ai.prompts import CHAT_SYSTEM_MESSAGE

from ..auth import BearerAuth
from ..lib.dj.decorators import async_raise_not_found
@@ -24,9 +25,6 @@
corpus_router = Router(tags=["corpus"], auth=BearerAuth())


CHAT_SYSTEM_MESSAGE = "Be imaginative and creative in answering the user's questions. "


class CorpusUpdateFilesSchema(BaseModel):
delete_files: Optional[List[str]] = None

@@ -81,7 +79,7 @@ async def chat(request, payload: CorpusChatSchema):
all_messages = [
ChatCompletionTextMessage(
role="system",
text=f"You can explain everything in the {corpus.name} corpus. "
text=f"You are helping the user understand or evolve the **{corpus.name}** project. "
f"{CHAT_SYSTEM_MESSAGE}"
f"{get_additional_context(payload)}",
),
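The remainder of the chat handler is collapsed in this view. A hedged sketch of how the assembled messages presumably reach the provider, using only names visible in this commit; the shape of the user turns on the payload is an assumption:

# Sketch only: the collapsed handler body presumably ends like this.
llm = load_llm_provider()
all_messages = [
    ChatCompletionTextMessage(
        role="system",
        text=f"You are helping the user understand or evolve the **{corpus.name}** project. "
        f"{CHAT_SYSTEM_MESSAGE}"
        f"{get_additional_context(payload)}",
    ),
    # ...conversation turns from the payload would follow (structure assumed)...
]
return llm.get_text_completion(all_messages)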
2 changes: 1 addition & 1 deletion py/packages/corpora/tasks/sync.py
@@ -28,7 +28,7 @@ def process_tarball(corpus_id: str, tarball: bytes) -> None:
corpus_file.save()
corpus_file.splits.all().delete()

- generate_summary_task.delay(corpus_file.id)
+ # generate_summary_task.delay(corpus_file.id)
split_file_task.delay(corpus_file.id)


3 changes: 2 additions & 1 deletion py/packages/corpora/tasks/test_sync.py
@@ -57,7 +57,8 @@ def test_process_tarball(
# assert mock_file.checksum == "mock_checksum"
mock_file.save.assert_called_once()
mock_file.splits.all().delete.assert_called_once()
- mock_summary_task.assert_called_once_with(mock_file.id)
+ # over-specified - we don't even use the summary in the app yet
+ # mock_summary_task.assert_called_once_with(mock_file.id)
mock_split_task.assert_called_once_with(mock_file.id)

@mock.patch("corpora.models.CorpusTextFile.objects.get")
23 changes: 22 additions & 1 deletion py/packages/corpora_ai/llm_interface.py
@@ -3,7 +3,10 @@
from typing import List, Type, TypeVar
from pydantic import BaseModel

- from corpora_ai.prompts import SUMMARIZE_SYSTEM_MESSAGE
+ from corpora_ai.prompts import (
+     SUMMARIZE_SYSTEM_MESSAGE,
+     SYNTHETIC_EMBEDDING_SYSTEM_MESSAGE,
+ )

T = TypeVar("T", bound=BaseModel)

@@ -87,3 +90,21 @@ def get_summary(self, text: str) -> str:
),
]
)

+     def get_synthetic_embedding_text(self, text: str) -> str:
+         """
+         Given a short prompt, generate a synthetic embedding text
+         that is more likely to match splits in the corpus.
+         """
+         return self.get_text_completion(
+             [
+                 ChatCompletionTextMessage(
+                     role="system",
+                     text=SYNTHETIC_EMBEDDING_SYSTEM_MESSAGE,
+                 ),
+                 ChatCompletionTextMessage(
+                     role="user",
+                     text=text,
+                 ),
+             ]
+         )
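The commented-out experiment in models.py above shows where this new method is meant to plug in: expand a terse query first, then embed the expanded text. A usage sketch (the query string is illustrative):

from corpora_ai.provider_loader import load_llm_provider

llm = load_llm_provider()
better_text = llm.get_synthetic_embedding_text("how do I run the tests?")
vector = llm.get_embedding(better_text)  # embed the enriched query, not the raw one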
15 changes: 15 additions & 0 deletions py/packages/corpora_ai/prompts.py
@@ -7,3 +7,18 @@
"short, cohesive, and representative of the original text's core message, making "
"it suitable for semantic search and relevance matching."
)

+ SYNTHETIC_EMBEDDING_SYSTEM_MESSAGE = (
+     "Transform the input text to maximize its utility for vector matching within the larger corpus. "
+     "Expand short text with detailed descriptions, relevant context, specific terminology, and meaningful questions to enrich it. "
+     "For longer or verbose input, refine and compress it while preserving essential terms, intent, and relevance for search. "
+     "Focus on improving embedding precision by enhancing or maintaining key vocabulary while keeping token usage efficient."
+ )
+
+ CHAT_SYSTEM_MESSAGE = (
+     "You are an active collaborator with the user, working together to evolve and improve the corpus. "
+     "Treat the corpus as a shared resource that you have access to, avoiding speculative statements like 'assuming X' or 'if Y is being used.' "
+     "Instead, operate with the understanding that you and the user can directly investigate and refine any part of the corpus as needed. "
+     "If additional information is missing or unclear, propose exploring specific parts of the corpus or ask the user directly. "
+     "Provide precise, actionable insights to help refine and expand the corpus in ways that align with its goals."
+ )
3 changes: 3 additions & 0 deletions py/packages/corpora_ai/provider_loader.py
@@ -19,6 +19,8 @@ def load_llm_provider() -> Optional[LLMBaseInterface]:
Optional[LLMBaseInterface]: An instance of the best available LLM provider.
"""
provider_name = os.getenv("LLM_PROVIDER", "openai")
+ # TODO: we need to specify the model in the interface really
+ # model_name = os.getenv("LLM_MODEL", "gpt-4o-mini")

# Check for the OpenAI provider
if provider_name == "openai" and OpenAIClient:
@@ -27,6 +29,7 @@ def load_llm_provider() -> Optional[LLMBaseInterface]:
raise ValueError("OPENAI_API_KEY environment variable is not set.")
return OpenAIClient(
api_key=api_key,
+ # completion_model=model_name,
azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT", None),
)

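One plausible way the commented-out model plumbing could land once the TODO is resolved; the LLM_MODEL variable name and its default come from the commented line above, and completion_model is already a real parameter on OpenAIClient (see llm_client.py below):

# Sketch only: replays the current function body with the TODO applied.
provider_name = os.getenv("LLM_PROVIDER", "openai")
model_name = os.getenv("LLM_MODEL", "gpt-4o-mini")

if provider_name == "openai" and OpenAIClient:
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        raise ValueError("OPENAI_API_KEY environment variable is not set.")
    return OpenAIClient(
        api_key=api_key,
        completion_model=model_name,
        azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT", None),
    )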
8 changes: 6 additions & 2 deletions py/packages/corpora_ai_openai/llm_client.py
@@ -13,6 +13,10 @@ class OpenAIClient(LLMBaseInterface):
def __init__(
self,
api_key: str,
+ # TODO: we probably do need some way
+ # to specify at runtime which model to use ;/
+ # I think we will have to expand the interface with options
+ # completion_model: str = "gpt-4o-mini",
completion_model: str = "gpt-4o",
embedding_model: str = "text-embedding-3-small",
azure_endpoint: str = None,
@@ -21,8 +25,8 @@ def __init__(
self.client = AzureOpenAI(
api_key=api_key,
azure_endpoint=azure_endpoint,
- # What's the behavior of not pinning the API version?
- # api_version="2024-10-01-preview",
+ # TODO: we should make this a parameter or what?
+ api_version="2024-10-01-preview",
)
else:
self.client = OpenAI(api_key=api_key)
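A minimal usage sketch of the constructor as it stands after this change: the same class serves plain OpenAI and Azure OpenAI, switching on whether azure_endpoint is supplied (the environment variable names match the docker-compose change above):

import os

# Plain OpenAI: completions default to gpt-4o.
client = OpenAIClient(api_key=os.environ["OPENAI_API_KEY"])

# Azure OpenAI: a non-empty azure_endpoint routes construction to AzureOpenAI
# with api_version pinned to 2024-10-01-preview.
azure_client = OpenAIClient(
    api_key=os.environ["OPENAI_API_KEY"],
    azure_endpoint=os.environ["OPENAI_AZURE_ENDPOINT"],
)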
(remaining 7 changed files not shown)

