-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(routers): reorg, expand, start split search, 77 tests (#16)
* think about TODO next a bit * create expandable routers namespace, reorg, test and generate * start into CLI refactor * adding tests to CLI * add corpus tests * basic coverage boilerplate * test models, rm old api * some .env QoL
- Loading branch information
Showing
46 changed files
with
3,281 additions
and
1,126 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
# .corpora.yaml | ||
|
||
name: "corpora2" | ||
name: "corpora" | ||
url: "https://github.com/skyl/corpora" | ||
|
||
server: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from ninja import Router | ||
|
||
from .auth import BearerAuth | ||
from .routers.corpus import corpus_router | ||
from .routers.corpustextfile import file_router | ||
from .routers.split import split_router | ||
|
||
api = Router(tags=["corpora"], auth=BearerAuth()) | ||
|
||
api.add_router("corpus", corpus_router) | ||
api.add_router("file", file_router) | ||
api.add_router("split", split_router) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
from typing import List | ||
import uuid | ||
|
||
from django.db import IntegrityError | ||
from django.core.exceptions import ValidationError | ||
from ninja import Router, Form, File | ||
from ninja.files import UploadedFile | ||
from ninja.errors import HttpError | ||
from asgiref.sync import sync_to_async | ||
|
||
from ..auth import BearerAuth | ||
from ..lib.dj.decorators import async_raise_not_found | ||
from ..models import Corpus | ||
from ..schema import CorpusSchema, CorpusResponseSchema | ||
from ..tasks import process_tarball | ||
|
||
corpus_router = Router(tags=["corpus"], auth=BearerAuth()) | ||
|
||
|
||
@corpus_router.post( | ||
"", | ||
response={201: CorpusResponseSchema, 400: str, 409: str}, | ||
operation_id="create_corpus", | ||
) | ||
async def create_corpus( | ||
request, | ||
corpus: CorpusSchema = Form(...), | ||
tarball: UploadedFile = File(...), | ||
): | ||
"""Create a new Corpus with an uploaded tarball.""" | ||
tarball_content: bytes = await sync_to_async(tarball.read)() | ||
try: | ||
corpus_instance = await Corpus.objects.acreate( | ||
name=corpus.name, | ||
url=corpus.url, | ||
owner=request.user, | ||
) | ||
except IntegrityError: | ||
raise HttpError(409, "A corpus with this name already exists for this owner.") | ||
except ValidationError: | ||
raise HttpError(400, "Invalid data provided.") | ||
|
||
process_tarball.delay(str(corpus_instance.id), tarball_content) | ||
return 201, corpus_instance | ||
|
||
|
||
@corpus_router.delete("", response={204: str, 404: str}, operation_id="delete_corpus") | ||
@async_raise_not_found | ||
async def delete_corpus(request, corpus_name: str): | ||
"""Delete a Corpus by name.""" | ||
corpus = await Corpus.objects.aget(owner=request.user, name=corpus_name) | ||
await sync_to_async(corpus.delete)() | ||
return 204, "Corpus deleted." | ||
|
||
|
||
@corpus_router.get( | ||
"", response={200: List[CorpusResponseSchema]}, operation_id="list_corpora" | ||
) | ||
async def list_corpora(request): | ||
"""List all Corpora.""" | ||
corpora = await sync_to_async(list)(Corpus.objects.filter(owner=request.user)) | ||
return corpora | ||
|
||
|
||
@corpus_router.get( | ||
"/{corpus_id}", response=CorpusResponseSchema, operation_id="get_corpus" | ||
) | ||
@async_raise_not_found | ||
async def get_corpus(request, corpus_id: uuid.UUID): | ||
"""Retrieve a Corpus by ID.""" | ||
corpus = await Corpus.objects.aget(id=corpus_id) | ||
return corpus |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import uuid | ||
|
||
from django.db import IntegrityError | ||
from ninja import Router | ||
from ninja.errors import HttpError | ||
|
||
from ..lib.files import compute_checksum | ||
from ..lib.dj.decorators import async_raise_not_found | ||
from ..models import Corpus, CorpusTextFile | ||
from ..schema import FileSchema, FileResponseSchema | ||
from ..auth import BearerAuth | ||
|
||
file_router = Router(tags=["file"], auth=BearerAuth()) | ||
|
||
|
||
@file_router.post( | ||
"", response={201: FileResponseSchema, 409: str}, operation_id="create_file" | ||
) | ||
@async_raise_not_found | ||
async def create_file(request, payload: FileSchema): | ||
"""Create a new File within a Corpus.""" | ||
corpus = await Corpus.objects.aget(id=payload.corpus_id) | ||
checksum = compute_checksum(payload.content) | ||
try: | ||
file = await CorpusTextFile.objects.acreate( | ||
path=payload.path, | ||
content=payload.content, | ||
checksum=checksum, | ||
corpus=corpus, | ||
) | ||
except IntegrityError: | ||
# Handle the unique constraint violation for duplicate file paths within the same corpus | ||
raise HttpError(409, "A file with this path already exists in the corpus.") | ||
|
||
return 201, file | ||
|
||
|
||
@file_router.get("/{file_id}", response=FileResponseSchema, operation_id="get_file") | ||
@async_raise_not_found | ||
async def get_file(request, file_id: uuid.UUID): | ||
"""Retrieve a File by ID.""" | ||
file = await CorpusTextFile.objects.select_related("corpus").aget(id=file_id) | ||
return file |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
from typing import List | ||
import uuid | ||
|
||
from ninja import Router | ||
from pgvector.django import CosineDistance | ||
from asgiref.sync import sync_to_async | ||
|
||
from corpora_ai.provider_loader import load_llm_provider | ||
from ..models import Split | ||
from ..schema import SplitResponseSchema, SplitVectorSearchSchema | ||
from ..auth import BearerAuth | ||
|
||
split_router = Router(tags=["split"], auth=BearerAuth()) | ||
|
||
|
||
@split_router.post( | ||
"/search", response=List[SplitResponseSchema], operation_id="vector_search" | ||
) | ||
async def vector_search(request, payload: SplitVectorSearchSchema): | ||
"""Perform a vector similarity search for splits using a provided query vector.""" | ||
query = payload.text | ||
corpus_id = payload.corpus_id | ||
|
||
llm = load_llm_provider() | ||
query_vector = llm.get_embedding(query) | ||
|
||
# Using cosine similarity for the search | ||
similar_splits = await sync_to_async(list)( | ||
Split.objects.filter( | ||
vector__isnull=False, | ||
file__corpus_id=corpus_id, | ||
) | ||
.annotate(similarity=CosineDistance("vector", query_vector)) | ||
.order_by("similarity")[: payload.limit] | ||
) | ||
|
||
return similar_splits | ||
|
||
|
||
@split_router.get("/{split_id}", response=SplitResponseSchema, operation_id="get_split") | ||
async def get_split(request, split_id: uuid.UUID): | ||
"""Retrieve a Split by ID.""" | ||
split = await Split.objects.select_related("file", "file__corpus").aget(id=split_id) | ||
return split | ||
|
||
|
||
@split_router.get( | ||
"/file/{file_id}", | ||
response=List[SplitResponseSchema], | ||
operation_id="list_splits_for_file", | ||
) | ||
async def list_splits_for_file(request, file_id: uuid.UUID): | ||
"""List all Splits for a specific CorpusTextFile.""" | ||
splits = await sync_to_async(list)( | ||
Split.objects.filter(file_id=file_id).order_by("order") | ||
) | ||
return splits |
Oops, something went wrong.