refactor: avoid writing business logic in __init__
winstxnhdw committed Oct 26, 2023
1 parent 3ab074e commit 1c61f8a
Showing 6 changed files with 246 additions and 241 deletions.
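All three refactored packages follow the same pattern: the implementation moves into a named module, and the package's __init__.py keeps only an explicit re-export. A minimal sketch of the resulting layout (the comments are illustrative, not part of the diff):

# server/features/embeddings/embedding.py holds the class body, so it can
# be imported and tested directly without going through the package root

# server/features/embeddings/__init__.py shrinks to a manifest; the
# redundant `as Embedding` alias marks the name as an explicit re-export,
# the convention strict type checkers (e.g. mypy with
# --no-implicit-reexport) expect for names exposed from a package
from server.features.embeddings.embedding import Embedding as Embedding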
71 changes: 1 addition & 70 deletions server/features/embeddings/__init__.py
@@ -1,70 +1 @@
from typing import Any

from huggingface_hub import snapshot_download
from numpy import float64
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer

from server.features.embeddings.flag_embedding import FlagEmbedding
from server.types import ComputeTypes


class Embedding(SentenceTransformer):
"""
Summary
-------
wrapper around a SentenceTransformer which routes the forward pass through a CTranslate2 FlagEmbedding
Methods
-------
encode_normalise(sentences: str | list[str]) -> NDArray[float64]
encode a sentence or list of sentences into a normalised embedding
"""
def __init__(
self,
*args: Any,
compute_type: ComputeTypes = 'float32',
**kwargs: Any
):

super().__init__('BAAI/bge-base-en-v1.5', *args, **kwargs)

self[0] = FlagEmbedding(
self[0],
snapshot_download('winstxnhdw/bge-base-en-v1.5-ct2', local_files_only=True),
compute_type=compute_type
)


def encode_normalise(self, sentences: str | list[str]) -> NDArray[float64]:
"""
Summary
-------
encode a sentence or list of sentences into a normalised embedding
Parameters
----------
sentences (str | list[str]) : the sentence(s) to encode
Returns
-------
embeddings (NDArray[float64]) : the normalised embeddings
"""
return self.encode(sentences, normalize_embeddings=True)


def encode_query(self, sentence: str) -> NDArray[float64]:
"""
Summary
-------
encode a sentence for searching relevant passages
Parameters
----------
sentence (str) : the sentence to encode
Returns
-------
embeddings (NDArray[float64]) : the normalised embeddings
"""
return self.encode_normalise(f'Represent this sentence for searching relevant passages: {sentence}')
from server.features.embeddings.embedding import Embedding as Embedding
70 changes: 70 additions & 0 deletions server/features/embeddings/embedding.py
@@ -0,0 +1,70 @@
from typing import Any

from huggingface_hub import snapshot_download
from numpy import float64
from numpy.typing import NDArray
from sentence_transformers import SentenceTransformer

from server.features.embeddings.flag_embedding import FlagEmbedding
from server.types import ComputeTypes


class Embedding(SentenceTransformer):
"""
Summary
-------
wrapper around a SentenceTransformer which routes the forward pass through a CTranslate2 FlagEmbedding
Methods
-------
encode_normalise(sentences: str | list[str]) -> NDArray[float64]
encode a sentence or list of sentences into a normalised embedding
"""
def __init__(
self,
*args: Any,
compute_type: ComputeTypes = 'float32',
**kwargs: Any
):

super().__init__('BAAI/bge-base-en-v1.5', *args, **kwargs)

self[0] = FlagEmbedding(
self[0],
snapshot_download('winstxnhdw/bge-base-en-v1.5-ct2', local_files_only=True),
compute_type=compute_type
)


def encode_normalise(self, sentences: str | list[str]) -> NDArray[float64]:
"""
Summary
-------
encode a sentence or list of sentences into a normalised embedding
Parameters
----------
sentences (str | list[str]) : the sentence(s) to encode
Returns
-------
embeddings (NDArray[float64]) : the normalised embeddings
"""
return self.encode(sentences, normalize_embeddings=True)


def encode_query(self, sentence: str) -> NDArray[float64]:
"""
Summary
-------
encode a sentence for searching relevant passages
Parameters
----------
sentence (str) : the sentence to encode
Returns
-------
embeddings (NDArray[float64]) : the normalised embeddings
"""
return self.encode_normalise(f'Represent this sentence for searching relevant passages: {sentence}')
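A usage sketch, not part of the commit: it assumes the 'winstxnhdw/bge-base-en-v1.5-ct2' snapshot is already in the local Hugging Face cache, since snapshot_download is called with local_files_only=True.

from server.features.embeddings import Embedding

embedding = Embedding(compute_type='float32')

# index side: normalised embeddings for a batch of passages
passages = embedding.encode_normalise(['first passage', 'second passage'])

# query side: encode_query prepends the BGE retrieval instruction
query = embedding.encode_query('what is the baloney detection kit?')

# on normalised vectors, cosine similarity reduces to a dot product
print(passages @ query)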
62 changes: 3 additions & 59 deletions server/features/extraction/__init__.py
@@ -1,59 +1,3 @@
from typing import Generator
from uuid import uuid4

from fastapi import UploadFile
from fitz import Document as FitzDocument

from server.features.extraction.models import Document
from server.features.extraction.models.document import Section


def extract_text(file_name: str, file_type: str, file: bytes) -> Document:
"""
Summary
-------
extract the text from a file
Parameters
----------
file_name (str): the name of the file
file_type (str): the type of the file
file (bytes): the file
Returns
-------
document (Document): the parsed document
"""
with FitzDocument(stream=file, filetype=file_type) as document:
sections = [
Section(link=f'{file_name}#{page.number}', content=page.get_text(sort=True)) # type: ignore
for page in document
]

return Document(
id=str(uuid4()),
sections=sections,
semantic_identifier=file_name
)


def extract_texts_from_requests(requests: list[UploadFile]) -> Generator[Document | None, None, None]:
"""
Summary
-------
extract the text from a list of requests
Parameters
----------
requests (list[UploadFile]): the requests to extract the text from
Yields
------
documents (Document | None): the parsed document, or None if the request has no filename
"""
for request in requests:
yield (
extract_text(*request.filename.rsplit('.', 1), file=request.file.read())
if request.filename
else None
)
from server.features.extraction.extract_text import (
extract_texts_from_requests as extract_texts_from_requests,
)
59 changes: 59 additions & 0 deletions server/features/extraction/extract_text.py
@@ -0,0 +1,59 @@
from typing import Generator
from uuid import uuid4

from fastapi import UploadFile
from fitz import Document as FitzDocument

from server.features.extraction.models import Document
from server.features.extraction.models.document import Section


def extract_text(file_name: str, file_type: str, file: bytes) -> Document:
"""
Summary
-------
extract the text from a file
Parameters
----------
file_name (str): the name of the file
file_type (str): the type of the file
file (bytes): the file
Returns
-------
document (Document): the parsed document
"""
with FitzDocument(stream=file, filetype=file_type) as document:
sections = [
Section(link=f'{file_name}#{page.number}', content=page.get_text(sort=True)) # type: ignore
for page in document
]

return Document(
id=str(uuid4()),
sections=sections,
semantic_identifier=file_name
)


def extract_texts_from_requests(requests: list[UploadFile]) -> Generator[Document | None, None, None]:
"""
Summary
-------
extract the text from a list of requests
Parameters
----------
requests (list[UploadFile]): the requests to extract the text from
Yields
------
documents (Document | None): the parsed document, or None if the request has no filename
"""
for request in requests:
yield (
extract_text(*request.filename.rsplit('.', 1), file=request.file.read())
if request.filename
else None
)
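A hypothetical wiring of the new module into a FastAPI route; the path and response shape are assumptions for illustration, and only extract_texts_from_requests comes from this commit.

from fastapi import FastAPI, UploadFile

from server.features.extraction import extract_texts_from_requests

app = FastAPI()

@app.post('/documents')
def upload_documents(requests: list[UploadFile]) -> list[str | None]:
    # the generator yields None for uploads without a filename,
    # so the response preserves which files were skipped
    return [
        document.id if document else None
        for document in extract_texts_from_requests(requests)
    ]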
113 changes: 1 addition & 112 deletions server/features/llm/__init__.py
@@ -1,112 +1 @@
from typing import Generator, Iterable

from ctranslate2 import Generator as LLMGenerator
from transformers.models.llama import LlamaTokenizerFast

from server.features.llm.types import Message
from server.helpers import huggingface_download


class LLM:
"""
Summary
-------
a static class for generating text with a Large Language Model
Methods
-------
load()
download and load the language model
query(messages: Iterable[Message]) -> Message | None
query the model
generate(tokens_list: Iterable[list[str]]) -> Generator[str, None, None]
generate text from one or more tokenised prompts
"""
generator: LLMGenerator
tokeniser: LlamaTokenizerFast
max_generation_length: int
max_prompt_length: int
static_prompt: list[str]

@classmethod
def load(cls):
"""
Summary
-------
download and load the language model
"""
model_path = huggingface_download('winstxnhdw/Mistral-7B-Instruct-v0.1-ct2-int8')
cls.generator = LLMGenerator(model_path, device='cpu', compute_type='auto', inter_threads=1)
cls.tokeniser = LlamaTokenizerFast.from_pretrained(model_path, local_files_only=True)

system_prompt = cls.tokeniser.apply_chat_template((
{
'content': 'You are given the following chat history. Answer the question based on the context provided as truthfully as you are able to. If you do not know the answer, you may respond with "I do not know". What is the Baloney Detection Kit?',
'role': 'user'
},
{
'content': 'The Baloney Detection Kit is a set of cognitive tools and techniques, created by Carl Sagan, that fortify the mind against penetration by falsehoods.',
'role': 'assistant'
}
), tokenize=False)

cls.static_prompt = cls.tokeniser(system_prompt).tokens()
cls.max_generation_length = 512
cls.max_prompt_length = 4096 - cls.max_generation_length - len(cls.static_prompt)


@classmethod
def query(cls, messages: Iterable[Message]) -> Message | None:
"""
Summary
-------
query the model
Parameters
----------
messages (Iterable[Message]) : the messages to query the model with
Returns
-------
answer (Message | None) : the answer to the query
"""
prompts: str = cls.tokeniser.apply_chat_template(messages, tokenize=False)
tokens = cls.tokeniser(prompts).tokens()

if len(tokens) > cls.max_prompt_length:
return None

return {
'role': 'assistant',
'content': next(cls.generate([tokens]))
}


@classmethod
def generate(cls, tokens_list: Iterable[list[str]]) -> Generator[str, None, None]:
"""
Summary
-------
generate text from one or more tokenised prompts
Parameters
----------
tokens_list (Iterable[list[str]]) : the tokenised prompt(s) to generate text from
Yields
------
answer (str) : the generated answer
"""
return (
cls.tokeniser.decode(result.sequences_ids[0]) for result in cls.generator.generate_iterable(
tokens_list,
repetition_penalty=1.2,
max_length=cls.max_generation_length,
static_prompt=cls.static_prompt,
include_prompt_in_result=False,
sampling_topp=0.9,
sampling_temperature=0.9
)
)
from server.features.llm.llm import LLM as LLM
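A usage sketch under stated assumptions: LLM.load() can reach the Hugging Face Hub (or finds a cached snapshot), and Message is a dict with 'role' and 'content' keys, as the value constructed in query() suggests.

from server.features.llm import LLM

LLM.load()  # one-off: fetch the CTranslate2 model and tokenise the static prompt

answer = LLM.query([
    {'role': 'user', 'content': 'What is the Baloney Detection Kit?'}
])

# query() returns None when the tokenised prompt exceeds max_prompt_length
print(answer['content'] if answer else 'prompt too long')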
