Skip to content

Commit

Permalink
feat(api): chat with the corpus + refactor schema (#57)
Browse files Browse the repository at this point in the history
* workon schema

* refactor schema serverside python

* update rust client

* genall finish
  • Loading branch information
skyl authored Nov 19, 2024
1 parent 6379ca3 commit 01c2748
Show file tree
Hide file tree
Showing 29 changed files with 593 additions and 178 deletions.
10 changes: 10 additions & 0 deletions genall.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#! /bin/bash
# Top-level code generation entry point: delegates to each language
# subdirectory's own genall.sh (python first, then rust).
# set -e: abort immediately if any sub-script exits non-zero.
set -e

# Regenerate the python package artifacts (see py/genall.sh).
pushd py
./genall.sh
popd

# Regenerate the rust client artifacts (see rs/genall.sh).
pushd rs
./genall.sh
popd
66 changes: 64 additions & 2 deletions py/packages/corpora/routers/corpus.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, List, Optional
import uuid
from typing import Dict, List, Optional

from django.db import IntegrityError
from django.core.exceptions import ValidationError
Expand All @@ -11,15 +11,25 @@
from asgiref.sync import sync_to_async
from pydantic import BaseModel

from corpora.schema.chat import CorpusChatSchema, get_additional_context
from corpora_ai.llm_interface import ChatCompletionTextMessage
from corpora_ai.provider_loader import load_llm_provider

from ..auth import BearerAuth
from ..lib.dj.decorators import async_raise_not_found
from ..models import Corpus
from ..schema import CorpusSchema, CorpusResponseSchema
from ..schema.core import CorpusSchema, CorpusResponseSchema
from ..tasks.sync import process_tarball

corpus_router = Router(tags=["corpus"], auth=BearerAuth())


# Base system prompt for the /chat endpoint. Per-request context
# (voice/purpose/structure/directions) is appended to it via
# get_additional_context when the message list is assembled.
CHAT_SYSTEM_MESSAGE = (
    "Use the context provided to generate a response to the user. "
    "Be imaginative and creative, but stay within the context. "
)


class CorpusUpdateFilesSchema(BaseModel):
delete_files: Optional[List[str]] = None

Expand Down Expand Up @@ -116,3 +126,55 @@ async def get_corpus(request, corpus_id: uuid.UUID):
"""Retrieve a Corpus by ID."""
corpus = await Corpus.objects.aget(id=corpus_id)
return corpus


@corpus_router.post(
    "/chat",
    response=str,
    operation_id="chat",
)
@async_raise_not_found
async def chat(request, payload: CorpusChatSchema):
    """Chat with the Corpus.

    Assembles an LLM conversation from:
      1. a system message: the corpus name, CHAT_SYSTEM_MESSAGE, and any
         voice/purpose/structure/directions context from the payload;
      2. a synthetic user message carrying relevant-split search context
         pulled from the corpus; and
      3. the client-supplied message history (the client owns the
         conversation state).

    Returns the provider's text completion as a plain string.
    Raises 404 (via async_raise_not_found) when the corpus is missing.
    """
    corpus = await Corpus.objects.aget(id=payload.corpus_id)
    # TODO: last 2 messages? Eventually we need to worry about
    # token count limits.
    # Ideally we might roll-up a summary of the entire conversation.
    # But, in the current design, we let the client decide the messages.
    # A separate endpoint could be used by the client to "compress conversation"
    split_context = await sync_to_async(corpus.get_relevant_splits_context)(
        "\n".join(message.text for message in payload.messages[-2:])
    )

    all_messages = [
        ChatCompletionTextMessage(
            role="system",
            # NOTE: CorpusChatSchema has no `path` field (only
            # CorpusFileChatSchema does), so unlike the workon endpoint
            # this prompt must not reference a focus file.
            text=f"You are in the {corpus.name} corpus. "
            f"{CHAT_SYSTEM_MESSAGE}"
            f"{get_additional_context(payload)}",
        ),
        # Alternatively we use multiple system messages?
        # ChatCompletionTextMessage(role="system", text=VOICE_TEXT),
        # .corpora/VOICE.md
        # .corpora/PURPOSE.md
        # .corpora/STRUCTURE.md
        # .corpora/{ext}/DIRECTIONS.md
        ChatCompletionTextMessage(
            role="user",
            text=(
                f"I searched the broader corpus and found the following context:\n"
                f"---\n{split_context}\n---"
            ),
        ),
        *[
            ChatCompletionTextMessage(role=msg.role, text=msg.text)
            for msg in payload.messages
        ],
    ]

    llm = load_llm_provider()
    resp = llm.get_text_completion(all_messages)
    return resp
2 changes: 1 addition & 1 deletion py/packages/corpora/routers/corpustextfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from ..lib.files import compute_checksum
from ..lib.dj.decorators import async_raise_not_found
from ..models import Corpus, CorpusTextFile
from ..schema import FileSchema, FileResponseSchema
from ..schema.core import FileSchema, FileResponseSchema
from ..auth import BearerAuth

file_router = Router(tags=["file"], auth=BearerAuth())
Expand Down
48 changes: 2 additions & 46 deletions py/packages/corpora/routers/plan.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List
from ninja import Router, Schema
from asgiref.sync import sync_to_async

from corpora.schema.chat import CorpusChatSchema, get_additional_context
from corpora_ai.llm_interface import ChatCompletionTextMessage
from corpora_ai.provider_loader import load_llm_provider
from corpora.auth import BearerAuth
Expand All @@ -26,55 +26,11 @@ class IssueSchema(Schema):
body: str


class MessageSchema(Schema):
role: str # e.g., "user", "system", "assistant"
text: str


# TODO: DRY this out with workon.py? It's a bit different tho.
class IssueRequestSchema(Schema):
corpus_id: str
messages: List[MessageSchema]
voice: str = ""
purpose: str = ""
structure: str = ""
directions: str = ""


def get_additional_context(payload: IssueRequestSchema) -> str:
# TODO: more automatically expandable implementation
# without the ifs
context = ""
if any(
[
payload.voice,
payload.purpose,
payload.structure,
payload.directions,
]
):
context += "\n\nADDITIONAL CONTEXT:\n\n"

if payload.voice:
context += f"VOICE:\n\n{payload.voice}\n\n"

if payload.purpose:
context += f"PURPOSE of corpus:\n\n{payload.purpose}\n\n"

if payload.structure:
context += f"STRUCTURE of corpus:\n\n{payload.structure}\n\n"

if payload.directions:
context += f"DIRECTIONS for issue:\n\n{payload.directions}\n\n"

return context


plan_router = Router(tags=["plan"], auth=BearerAuth())


@plan_router.post("/issue", response=IssueSchema, operation_id="get_issue")
async def get_issue(request, payload: IssueRequestSchema):
async def get_issue(request, payload: CorpusChatSchema):
corpus = await Corpus.objects.aget(id=payload.corpus_id)

# TODO: split context could be ... ?
Expand Down
2 changes: 1 addition & 1 deletion py/packages/corpora/routers/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from asgiref.sync import sync_to_async

from ..models import Corpus, Split
from ..schema import SplitResponseSchema, SplitVectorSearchSchema
from ..schema.core import SplitResponseSchema, SplitVectorSearchSchema
from ..auth import BearerAuth

split_router = Router(tags=["split"], auth=BearerAuth())
Expand Down
56 changes: 2 additions & 54 deletions py/packages/corpora/routers/workon.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List
from ninja import Router, Schema
from asgiref.sync import sync_to_async

from corpora.schema.chat import CorpusFileChatSchema, get_additional_context
from corpora_ai.llm_interface import ChatCompletionTextMessage
from corpora_ai.provider_loader import load_llm_provider
from corpora.models import Corpus
Expand All @@ -11,22 +11,6 @@
workon_router = Router(tags=["workon"], auth=BearerAuth())


class MessageSchema(Schema):
role: str # e.g., "user", "system", "assistant"
text: str


class CorpusFileChatSchema(Schema):
corpus_id: str
messages: List[MessageSchema]
path: str
# optional additional context: voice, purpose, structure, directions
voice: str = ""
purpose: str = ""
structure: str = ""
directions: str = ""


FILE_EDITOR_SYSTEM_MESSAGE = (
"You are editing the file and must return only the new revision of the file. "
"Do not include any additional context, explanations, or surrounding text. "
Expand All @@ -38,36 +22,6 @@ class FileRevisionResponse(Schema):
new_file_revision: str


def get_additional_context(payload: CorpusFileChatSchema) -> str:
# TODO: more automatically expandable implementation
# without the ifs
context = ""
if any(
[
payload.voice,
payload.purpose,
payload.structure,
payload.directions,
]
):
context += "\n\nADDITIONAL CONTEXT:\n\n"

if payload.voice:
context += f"VOICE:\n\n{payload.voice}\n\n"

if payload.purpose:
context += f"PURPOSE of corpus:\n\n{payload.purpose}\n\n"

if payload.structure:
context += f"STRUCTURE of corpus:\n\n{payload.structure}\n\n"

if payload.directions:
ext = payload.path.split(".")[-1]
context += f"DIRECTIONS for {ext} filetype:\n\n{payload.directions}\n\n"

return context


@workon_router.post("/file", response=str, operation_id="file")
async def file(request, payload: CorpusFileChatSchema):
corpus = await Corpus.objects.aget(id=payload.corpus_id)
Expand All @@ -89,14 +43,8 @@ async def file(request, payload: CorpusFileChatSchema):
text=f"You are focused on the file: {payload.path} "
f"in the {corpus.name} corpus. "
f"{FILE_EDITOR_SYSTEM_MESSAGE}"
f"{get_additional_context(payload)}",
f"{get_additional_context(payload, ext=payload.path.split('.')[-1])}",
),
# Alternatively we use multiple system messages?
# ChatCompletionTextMessage(role="system", text=VOICE_TEXT),
# .corpora/VOICE.md
# .corpora/PURPOSE.md
# .corpora/STRUCTURE.md
# .corpora/{ext}/DIRECTIONS.md
ChatCompletionTextMessage(
role="user",
text=(
Expand Down
Empty file.
50 changes: 50 additions & 0 deletions py/packages/corpora/schema/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from typing import List
from ninja import Schema


class MessageSchema(Schema):
    """A single chat message in a client-managed conversation history."""

    role: str  # e.g., "user", "system", "assistant"
    text: str  # the message content


class CorpusChatSchema(Schema):
    """Request payload for chatting about a corpus.

    The client supplies the full message history; empty-string defaults
    mean a context section is simply omitted by get_additional_context.
    """

    corpus_id: str
    messages: List[MessageSchema]
    # optional additional context: voice, purpose, structure, directions
    voice: str = ""
    purpose: str = ""
    structure: str = ""
    directions: str = ""


class CorpusFileChatSchema(CorpusChatSchema):
    """CorpusChatSchema plus the path of the file being worked on."""

    path: str  # corpus-relative path of the focus file


def get_additional_context(payload: "CorpusChatSchema", ext: str = "") -> str:
    """Build the optional "ADDITIONAL CONTEXT" suffix for a system prompt.

    Emits one labeled section per non-empty context field on *payload*
    (voice, purpose, structure, directions), preceded by a single
    "ADDITIONAL CONTEXT" header. Returns "" when every field is empty.

    Args:
        payload: any object with str attributes ``voice``, ``purpose``,
            ``structure`` and ``directions`` (e.g. CorpusChatSchema).
        ext: file extension interpolated into the DIRECTIONS label
            (used by file-focused endpoints; defaults to "").

    Returns:
        The formatted context block, or the empty string.
    """
    # Data-driven section table (resolves the old TODO about avoiding
    # a chain of ifs); order matters and matches the original output.
    sections = [
        ("VOICE", payload.voice),
        ("PURPOSE of corpus", payload.purpose),
        ("STRUCTURE of corpus", payload.structure),
        (f"DIRECTIONS for {ext} filetype", payload.directions),
    ]
    present = [(label, text) for label, text in sections if text]
    if not present:
        return ""
    parts = ["\n\nADDITIONAL CONTEXT:\n\n"]
    parts.extend(f"{label}:\n\n{text}\n\n" for label, text in present)
    return "".join(parts)
File renamed without changes.
6 changes: 3 additions & 3 deletions py/packages/corpora_cli/commands/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import typer
from prompt_toolkit.shortcuts import PromptSession

from corpora_client.models.issue_request_schema import IssueRequestSchema
from corpora_client.models.corpus_chat_schema import CorpusChatSchema
from corpora_client.models.message_schema import MessageSchema
from corpora_pm.providers.provider_loader import Corpus, load_provider
from corpora_cli.context import ContextObject
Expand Down Expand Up @@ -66,7 +66,7 @@ def issue(ctx: typer.Context):
directions = get_file_content_or_create(".corpora/md/DIRECTIONS.md")

draft_issue = c.plan_api.get_issue(
IssueRequestSchema(
CorpusChatSchema(
messages=messages,
corpus_id=c.config["id"],
voice=voice,
Expand Down Expand Up @@ -158,7 +158,7 @@ def update_issue(ctx: typer.Context, issue_number: int):
directions = get_file_content_or_create(".corpora/md/DIRECTIONS.md")

updated_issue = c.plan_api.get_issue(
IssueRequestSchema(
CorpusChatSchema(
messages=messages,
corpus_id=c.config["id"],
voice=voice,
Expand Down
15 changes: 7 additions & 8 deletions py/packages/corpora_client/README.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 01c2748

Please sign in to comment.