-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: avoid writing business logic in
__init__
- Loading branch information
1 parent
3ab074e
commit 1c61f8a
Showing
6 changed files
with
246 additions
and
241 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,70 +1 @@ | ||
from typing import Any | ||
|
||
from huggingface_hub import snapshot_download | ||
from numpy import float64 | ||
from numpy.typing import NDArray | ||
from sentence_transformers import SentenceTransformer | ||
|
||
from server.features.embeddings.flag_embedding import FlagEmbedding | ||
from server.types import ComputeTypes | ||
|
||
|
||
class Embedding(SentenceTransformer):
    """
    Summary
    -------
    a SentenceTransformer for 'BAAI/bge-base-en-v1.5' whose first module is swapped for a
    FlagEmbedding built from the 'winstxnhdw/bge-base-en-v1.5-ct2' snapshot

    Methods
    -------
    encode_normalise(sentences: str | list[str]) -> NDArray[float64]
        encode a sentence or list of sentences into a normalised embedding

    encode_query(sentence: str) -> NDArray[float64]
        encode a sentence for searching relevant passages
    """
    def __init__(self, *args: Any, compute_type: ComputeTypes = 'float32', **kwargs: Any):
        """
        Summary
        -------
        initialise the base model, then replace its first module with a FlagEmbedding

        Parameters
        ----------
        *args (Any) : positional arguments forwarded to SentenceTransformer
        compute_type (ComputeTypes) : the compute type handed to FlagEmbedding
        **kwargs (Any) : keyword arguments forwarded to SentenceTransformer
        """
        super().__init__('BAAI/bge-base-en-v1.5', *args, **kwargs)

        original_module = self[0]

        # local_files_only=True: the snapshot is resolved from the local cache only
        converted_model_path = snapshot_download('winstxnhdw/bge-base-en-v1.5-ct2', local_files_only=True)

        self[0] = FlagEmbedding(original_module, converted_model_path, compute_type=compute_type)


    def encode_normalise(self, sentences: str | list[str]) -> NDArray[float64]:
        """
        Summary
        -------
        encode one or more sentences with `normalize_embeddings=True`

        Parameters
        ----------
        sentences (str | list[str]) : the sentence(s) to encode

        Returns
        -------
        embeddings (NDArray[float64]) : the normalised embeddings
        """
        embeddings = self.encode(sentences, normalize_embeddings=True)
        return embeddings


    def encode_query(self, sentence: str) -> NDArray[float64]:
        """
        Summary
        -------
        prefix the sentence with the retrieval instruction and encode it as a normalised embedding

        Parameters
        ----------
        sentence (str) : the sentence to encode

        Returns
        -------
        embeddings (NDArray[float64]) : the normalised embeddings
        """
        instructed_query = f'Represent this sentence for searching relevant passages: {sentence}'
        return self.encode_normalise(instructed_query)
from server.features.embeddings.embedding import Embedding as Embedding |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from typing import Any | ||
|
||
from huggingface_hub import snapshot_download | ||
from numpy import float64 | ||
from numpy.typing import NDArray | ||
from sentence_transformers import SentenceTransformer | ||
|
||
from server.features.embeddings.flag_embedding import FlagEmbedding | ||
from server.types import ComputeTypes | ||
|
||
|
||
class Embedding(SentenceTransformer):
    """
    Summary
    -------
    a SentenceTransformer for 'BAAI/bge-base-en-v1.5' whose first module is swapped for a
    FlagEmbedding built from the 'winstxnhdw/bge-base-en-v1.5-ct2' snapshot

    Methods
    -------
    encode_normalise(sentences: str | list[str]) -> NDArray[float64]
        encode a sentence or list of sentences into a normalised embedding

    encode_query(sentence: str) -> NDArray[float64]
        encode a sentence for searching relevant passages
    """
    def __init__(self, *args: Any, compute_type: ComputeTypes = 'float32', **kwargs: Any):
        """
        Summary
        -------
        initialise the base model, then replace its first module with a FlagEmbedding

        Parameters
        ----------
        *args (Any) : positional arguments forwarded to SentenceTransformer
        compute_type (ComputeTypes) : the compute type handed to FlagEmbedding
        **kwargs (Any) : keyword arguments forwarded to SentenceTransformer
        """
        super().__init__('BAAI/bge-base-en-v1.5', *args, **kwargs)

        original_module = self[0]

        # local_files_only=True: the snapshot is resolved from the local cache only
        converted_model_path = snapshot_download('winstxnhdw/bge-base-en-v1.5-ct2', local_files_only=True)

        self[0] = FlagEmbedding(original_module, converted_model_path, compute_type=compute_type)


    def encode_normalise(self, sentences: str | list[str]) -> NDArray[float64]:
        """
        Summary
        -------
        encode one or more sentences with `normalize_embeddings=True`

        Parameters
        ----------
        sentences (str | list[str]) : the sentence(s) to encode

        Returns
        -------
        embeddings (NDArray[float64]) : the normalised embeddings
        """
        embeddings = self.encode(sentences, normalize_embeddings=True)
        return embeddings


    def encode_query(self, sentence: str) -> NDArray[float64]:
        """
        Summary
        -------
        prefix the sentence with the retrieval instruction and encode it as a normalised embedding

        Parameters
        ----------
        sentence (str) : the sentence to encode

        Returns
        -------
        embeddings (NDArray[float64]) : the normalised embeddings
        """
        instructed_query = f'Represent this sentence for searching relevant passages: {sentence}'
        return self.encode_normalise(instructed_query)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,59 +1,3 @@ | ||
from typing import Generator | ||
from uuid import uuid4 | ||
|
||
from fastapi import UploadFile | ||
from fitz import Document as FitzDocument | ||
|
||
from server.features.extraction.models import Document | ||
from server.features.extraction.models.document import Section | ||
|
||
|
||
def extract_text(file_name: str, file_type: str, file: bytes) -> Document:
    """
    Summary
    -------
    open an in-memory file and collect the text of every page into a Document

    Parameters
    ----------
    file_name (str): the name of the file, used for section links and as the semantic identifier
    file_type (str): the type of the file, passed to the parser as `filetype`
    file (bytes): the raw bytes of the file

    Returns
    -------
    document (Document): the parsed document, holding one Section per page
    """
    sections = []

    with FitzDocument(stream=file, filetype=file_type) as document:
        for page in document:
            # each section links back to its page as '<file_name>#<page.number>'
            page_link = f'{file_name}#{page.number}'  # type: ignore
            page_text = page.get_text(sort=True)  # type: ignore
            sections.append(Section(link=page_link, content=page_text))

    document_id = str(uuid4())

    return Document(id=document_id, sections=sections, semantic_identifier=file_name)
|
||
|
||
def extract_texts_from_requests(requests: list[UploadFile]) -> Generator[Document | None, None, None]:
    """
    Summary
    -------
    extract the text from a list of requests

    Parameters
    ----------
    requests (list[UploadFile]): the requests to extract the text from

    Yields
    ------
    documents (Document | None): the parsed document, or None when the upload has no filename
        or its filename carries no '.'-separated extension to use as the file type
    """
    for request in requests:
        if not request.filename:
            yield None
            continue

        # split on the LAST dot only, so 'report.v2.pdf' -> ('report.v2', 'pdf')
        file_name, separator, file_type = request.filename.rpartition('.')

        # without a dot there is no file type to hand to extract_text; previously this
        # raised a TypeError (missing `file_type`) instead of yielding a result
        if not separator:
            yield None
            continue

        # the upload is only read once we know it can be parsed
        yield extract_text(file_name, file_type, file=request.file.read())
from server.features.extraction.extract_text import ( | ||
extract_texts_from_requests as extract_texts_from_requests, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from typing import Generator | ||
from uuid import uuid4 | ||
|
||
from fastapi import UploadFile | ||
from fitz import Document as FitzDocument | ||
|
||
from server.features.extraction.models import Document | ||
from server.features.extraction.models.document import Section | ||
|
||
|
||
def extract_text(file_name: str, file_type: str, file: bytes) -> Document:
    """
    Summary
    -------
    open an in-memory file and collect the text of every page into a Document

    Parameters
    ----------
    file_name (str): the name of the file, used for section links and as the semantic identifier
    file_type (str): the type of the file, passed to the parser as `filetype`
    file (bytes): the raw bytes of the file

    Returns
    -------
    document (Document): the parsed document, holding one Section per page
    """
    sections = []

    with FitzDocument(stream=file, filetype=file_type) as document:
        for page in document:
            # each section links back to its page as '<file_name>#<page.number>'
            page_link = f'{file_name}#{page.number}'  # type: ignore
            page_text = page.get_text(sort=True)  # type: ignore
            sections.append(Section(link=page_link, content=page_text))

    document_id = str(uuid4())

    return Document(id=document_id, sections=sections, semantic_identifier=file_name)
|
||
|
||
def extract_texts_from_requests(requests: list[UploadFile]) -> Generator[Document | None, None, None]:
    """
    Summary
    -------
    extract the text from a list of requests

    Parameters
    ----------
    requests (list[UploadFile]): the requests to extract the text from

    Yields
    ------
    documents (Document | None): the parsed document, or None when the upload has no filename
        or its filename carries no '.'-separated extension to use as the file type
    """
    for request in requests:
        if not request.filename:
            yield None
            continue

        # split on the LAST dot only, so 'report.v2.pdf' -> ('report.v2', 'pdf')
        file_name, separator, file_type = request.filename.rpartition('.')

        # without a dot there is no file type to hand to extract_text; previously this
        # raised a TypeError (missing `file_type`) instead of yielding a result
        if not separator:
            yield None
            continue

        # the upload is only read once we know it can be parsed
        yield extract_text(file_name, file_type, file=request.file.read())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,112 +1 @@ | ||
from typing import Generator, Iterable | ||
|
||
from ctranslate2 import Generator as LLMGenerator | ||
from transformers.models.llama import LlamaTokenizerFast | ||
|
||
from server.features.llm.types import Message | ||
from server.helpers import huggingface_download | ||
|
||
|
||
class LLM:
    """
    Summary
    -------
    a static class for generating text with a Large Language Model

    Methods
    -------
    load()
        load the generator and tokeniser, and prepare the static prompt

    query(messages: Iterable[Message]) -> Message | None
        query the model

    generate(tokens_list: Iterable[list[str]]) -> Generator[str, None, None]
        generate text from a series/single prompt(s)
    """
    generator: LLMGenerator
    tokeniser: LlamaTokenizerFast
    max_generation_length: int
    max_prompt_length: int
    static_prompt: list[str]

    @classmethod
    def load(cls):
        """
        Summary
        -------
        resolve the model path via `huggingface_download`, then load the generator and tokeniser
        from it and pre-tokenise the static example exchange
        """
        model_path = huggingface_download('winstxnhdw/Mistral-7B-Instruct-v0.1-ct2-int8')

        cls.generator = LLMGenerator(model_path, device='cpu', compute_type='auto', inter_threads=1)
        cls.tokeniser = LlamaTokenizerFast.from_pretrained(model_path, local_files_only=True)

        # a fixed example exchange that is passed to every generation as `static_prompt`
        example_exchange = (
            {
                'content': 'You are given the following chat history. Answer the question based on the context provided as truthfully as you are able to. If you do not know the answer, you may respond with "I do not know". What is the Baloney Detection Kit?',
                'role': 'user'
            },
            {
                'content': 'The Baloney Detection Kit is a a set of cognitive tools and techniques created by Carl Sagan, that fortify the mind against penetration by falsehoods.',
                'role': 'assistant'
            }
        )

        system_prompt = cls.tokeniser.apply_chat_template(example_exchange, tokenize=False)
        static_tokens = cls.tokeniser(system_prompt).tokens()

        cls.static_prompt = static_tokens
        cls.max_generation_length = 512

        # the prompt budget is what remains of 4096 tokens after reserving the generation length
        # and the static prompt (4096 is presumably the model's context window - TODO confirm)
        cls.max_prompt_length = 4096 - cls.max_generation_length - len(cls.static_prompt)


    @classmethod
    def query(cls, messages: Iterable[Message]) -> Message | None:
        """
        Summary
        -------
        render the messages with the chat template, tokenise them and generate a single answer

        Parameters
        ----------
        messages (Iterable[Message]) : the messages to query the model with

        Returns
        -------
        answer (Message | None) : the assistant's answer, or None if the tokenised prompt is longer
            than `max_prompt_length`
        """
        rendered_prompt: str = cls.tokeniser.apply_chat_template(messages, tokenize=False)
        prompt_tokens = cls.tokeniser(rendered_prompt).tokens()

        if len(prompt_tokens) > cls.max_prompt_length:
            return None

        # a single prompt is submitted, so only the first generated answer is taken
        answers = cls.generate([prompt_tokens])

        return {
            'role': 'assistant',
            'content': next(answers)
        }


    @classmethod
    def generate(cls, tokens_list: Iterable[list[str]]) -> Generator[str, None, None]:
        """
        Summary
        -------
        generate text from a series/single prompt(s)

        Parameters
        ----------
        tokens_list (Iterable[list[str]]) : the tokenised prompt(s) to generate text from

        Yields
        -------
        answer (str) : the decoded first sequence of each generation result
        """
        # NOTE(review): `sampling_topk` is left at the library default - confirm that the
        # top-p and temperature settings below take effect as intended
        results = cls.generator.generate_iterable(
            tokens_list,
            repetition_penalty=1.2,
            max_length=cls.max_generation_length,
            static_prompt=cls.static_prompt,
            include_prompt_in_result=False,
            sampling_topp=0.9,
            sampling_temperature=0.9
        )

        return (cls.tokeniser.decode(result.sequences_ids[0]) for result in results)
from server.features.llm.llm import LLM as LLM |
Oops, something went wrong.