Commit 6386f2b
Add docs splitter and cross encoder based re-ranker (#5); fix the formatting.
1 parent 6c6d36a
Showing 4 changed files with 81 additions and 0 deletions.
@@ -0,0 +1,13 @@
from typing import List

import torch.nn.functional as F
from sentence_transformers import CrossEncoder
from torch import Tensor


def get_scores(query: str, documents: List[str], model_name: str = "cross-encoder/ms-marco-MiniLM-L-2-v2"):
    """Score each document against the query with a cross-encoder and
    return softmax-normalized relevance scores, one per document."""
    # Pass the model name positionally: the keyword name differs across
    # sentence-transformers versions.
    model = CrossEncoder(model_name, max_length=512)
    # A cross-encoder scores each (query, document) pair jointly.
    doc_tuple = [(query, doc) for doc in documents]
    scores = model.predict(doc_tuple)
    # Normalize the raw scores into a probability distribution over documents.
    return F.softmax(Tensor(scores), dim=0).tolist()
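
A minimal usage sketch (not part of this commit), assuming the module above is importable; the query and passages are made up for illustration:

# Hypothetical usage -- query and passages are illustrative only.
query = "how do cross encoders rank documents?"
passages = [
    "Cross encoders jointly encode a query and a document to score relevance.",
    "Tokyo is the capital of Japan.",
    "Re-ranking reorders retrieved candidates with a stronger relevance model.",
]
scores = get_scores(query, passages)
# Pair each passage with its score and print best-first.
for passage, score in sorted(zip(passages, scores), key=lambda p: p[1], reverse=True):
    print(f"{score:.3f}  {passage}")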
@@ -0,0 +1,31 @@
from typing import List

from langchain.schema import Document
from langchain.text_splitter import SentenceTransformersTokenTextSplitter


def langchain_document_to_dict(doc: Document):
    """Convert a LangChain Document to a plain dictionary."""
    return {"page_content": doc.page_content, "metadata": doc.metadata}


def dict_to_langchain_document(doc: dict):
    """Convert a plain dictionary to a LangChain Document."""
    return Document(page_content=doc["page_content"], metadata=doc["metadata"])


def get_split_documents_using_token_based(model_name: str, documents: List[dict], chunk_size: int, chunk_overlap: int):
    """Split documents into chunks with the SentenceTransformers token-based
    splitter, so each chunk fits within the model's token limit."""
    splitter = SentenceTransformersTokenTextSplitter(
        chunk_overlap=chunk_overlap, model_name=model_name, tokens_per_chunk=chunk_size
    )
    # Round-trip through LangChain Documents: the splitter operates on
    # Document objects, while callers pass and receive plain dicts.
    langchain_docs = [dict_to_langchain_document(d) for d in documents]
    split_docs = splitter.split_documents(documents=langchain_docs)
    return [langchain_document_to_dict(d) for d in split_docs]
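
A hedged end-to-end sketch (not part of the commit) of how the two new helpers could compose, assuming both modules are importable together: split a raw document into token-bounded chunks, then re-rank the chunks against a query with get_scores from above. The input document, chunk sizes, and query are illustrative.

# Hypothetical composition of the splitter and the re-ranker.
raw_docs = [
    {
        "page_content": "A long document about retrieval pipelines ...",
        "metadata": {"source": "notes.txt"},
    }
]

# Chunk to at most 128 tokens per piece with a small overlap; the model
# name here is the splitter's default embedding model.
chunks = get_split_documents_using_token_based(
    model_name="sentence-transformers/all-mpnet-base-v2",
    documents=raw_docs,
    chunk_size=128,
    chunk_overlap=16,
)

# Re-rank the chunk texts against the query and keep the best one.
texts = [c["page_content"] for c in chunks]
scores = get_scores("what is a retrieval pipeline?", texts)
print(texts[scores.index(max(scores))])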