Add docs splitter and cross encoder based re-ranker; (#5)
Fix the formatting;
ranjan-stha authored Nov 6, 2024
1 parent 6c6d36a commit 6386f2b
Showing 4 changed files with 81 additions and 0 deletions.
36 changes: 36 additions & 0 deletions app.py
@@ -1,3 +1,4 @@
import json
from enum import Enum
from typing import List, Optional, Union

@@ -10,6 +11,10 @@
OpenAIEmbeddingModel,
SentenceTransformerEmbeddingModel,
)
from reranker import get_scores

# from langchain.schema import Document
from splitter import get_split_documents_using_token_based

load_dotenv()

@@ -35,6 +40,22 @@ class RequestSchemaForEmbeddings(BaseModel):
base_url: Optional[str] = None


class RequestSchemaForTextSplitter(BaseModel):
"""Request Schema"""

model: str
documents: str
chunk_size: int
chunk_overlap: int


class RequestSchemaForReRankers(BaseModel):
"""Request Schema"""

query: str
documents: List[str]


@app.get("/")
async def home():
"""Returns a message"""
@@ -70,3 +91,18 @@ def generate(em_model, texts):
elif type_model == EmbeddingModelType.OPENAI:
embedding_model = OpenAIEmbeddingModel(model=name_model)
return generate(em_model=embedding_model, texts=texts)


@app.post("/split_docs_based_on_tokens")
async def get_split_docs(item: RequestSchemaForTextSplitter):
"""Splits the documents using the model tokenization method"""
docs = json.loads(item.documents)
return get_split_documents_using_token_based(
model_name=item.model, documents=docs, chunk_size=item.chunk_size, chunk_overlap=item.chunk_overlap
)


@app.post("/docs_reranking_scores")
async def get_reranked_docs(item: RequestSchemaForReRankers):
"""Get reranked documents"""
return get_scores(item.query, item.documents)
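
For context, a minimal sketch of how the two new endpoints might be called once the service is running on port 8000 (the mapping in docker-compose.yml). The model name, sample payloads, and the requests dependency are illustrative assumptions, not part of this commit:

import json

import requests  # assumed to be installed in the calling environment

BASE_URL = "http://localhost:8000"  # port mapping from docker-compose.yml

# /split_docs_based_on_tokens expects `documents` as a JSON-encoded string
docs = [
    {"page_content": "First example document ...", "metadata": {"source": "a.txt"}},
    {"page_content": "Second example document ...", "metadata": {"source": "b.txt"}},
]
split_resp = requests.post(
    f"{BASE_URL}/split_docs_based_on_tokens",
    json={
        "model": "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model name
        "documents": json.dumps(docs),
        "chunk_size": 128,
        "chunk_overlap": 16,
    },
)
print(split_resp.json())  # list of {"page_content": ..., "metadata": ...} chunks

# /docs_reranking_scores returns one softmax-normalized score per document
rerank_resp = requests.post(
    f"{BASE_URL}/docs_reranking_scores",
    json={"query": "What does the splitter endpoint do?", "documents": ["doc one", "doc two"]},
)
print(rerank_resp.json())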
1 change: 1 addition & 0 deletions docker-compose.yml
@@ -3,6 +3,7 @@ services:
build: .
volumes:
- embedding_models:/opt/models
- .:/code
command: bash -c 'uvicorn app:app --host=0.0.0.0 --port=8000'
ports:
- "8000:8000"
13 changes: 13 additions & 0 deletions reranker.py
@@ -0,0 +1,13 @@
from typing import List

import torch.nn.functional as F
from sentence_transformers import CrossEncoder
from torch import Tensor


def get_scores(query: str, documents: List[str], model_name: str = "cross-encoder/ms-marco-MiniLM-L-2-v2"):
"""Get the scores"""
model = CrossEncoder(model_name=model_name, max_length=512)
doc_tuple = [(query, doc) for doc in documents]
scores = model.predict(doc_tuple)
return F.softmax(Tensor(scores), dim=0).tolist()
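
A small usage sketch of get_scores; the query and documents below are made up, and the default cross-encoder/ms-marco-MiniLM-L-2-v2 model is downloaded on first use:

from reranker import get_scores

scores = get_scores(
    query="how are documents split into chunks?",
    documents=[
        "SentenceTransformersTokenTextSplitter splits text by model tokens.",
        "Docker Compose maps container port 8000 to the host.",
    ],
)
# `scores` is a list of floats summing to 1.0 (softmax over the cross-encoder
# outputs), one entry per input document; higher means more relevant to the query.
print(scores)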
31 changes: 31 additions & 0 deletions splitter.py
@@ -0,0 +1,31 @@
from typing import List

from langchain.schema import Document
from langchain.text_splitter import SentenceTransformersTokenTextSplitter


def langchain_document_to_dict(doc: Document):
"""
Converts a langchain Document to a dictionary
"""
return {"page_content": doc.page_content, "metadata": doc.metadata}


def dict_to_langchain_document(doc: dict):
"""
Converts a dictionary to a Langchain Document
"""
return Document(page_content=doc["page_content"], metadata=doc["metadata"])


def get_split_documents_using_token_based(model_name: str, documents: List[dict], chunk_size: int, chunk_overlap: int):
"""
Splits documents into chunks using the Sentence Transformers token-based splitter.
"""
splitter = SentenceTransformersTokenTextSplitter(
chunk_overlap=chunk_overlap, model_name=model_name, tokens_per_chunk=chunk_size
)
langchain_docs = [dict_to_langchain_document(d) for d in documents]
split_docs = splitter.split_documents(documents=langchain_docs)
return [langchain_document_to_dict(d) for d in split_docs]
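
And a direct-call sketch for the splitter helper, bypassing the HTTP endpoint; the model name and sample document are assumptions for illustration:

from splitter import get_split_documents_using_token_based

chunks = get_split_documents_using_token_based(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model
    documents=[{"page_content": "A long document ...", "metadata": {"source": "a.txt"}}],
    chunk_size=128,   # tokens per chunk
    chunk_overlap=16,
)
# Each chunk is a plain dict with "page_content" and "metadata", so the result
# can be returned directly as JSON by the FastAPI endpoint above.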
