Skip to content

Commit

Permalink
feat: add local embeddings (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
akotyla authored and mhordynski committed Sep 17, 2024
1 parent 2f1ceac commit 5a6ebe7
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 6 deletions.
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/embeddings/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ def __init__(
api_version: The API version for the call.
Raises:
ImportError: If the litellm package is not installed.
ImportError: If the 'litellm' extra requirements are not installed.
"""
if not HAS_LITELLM:
raise ImportError("You need to install litellm package to use LiteLLM models")
raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models")

super().__init__()
self.model = model
Expand Down
81 changes: 81 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from typing import Iterator, Optional

try:
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

HAS_LOCAL_EMBEDDINGS = True
except ImportError:
HAS_LOCAL_EMBEDDINGS = False

from ragbits.core.embeddings.base import Embeddings


class LocalEmbeddings(Embeddings):
    """
    Class for interaction with any encoder available in HuggingFace.

    The model and tokenizer are loaded once in the constructor. Embeddings are computed as the
    attention-masked mean of the encoder's last hidden state, followed by L2 normalization.
    """

    def __init__(
        self,
        model_name: str,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Constructs a new local embeddings instance.

        Args:
            model_name: Name of the model to use.
            api_key: The API key for Hugging Face authentication.

        Raises:
            ImportError: If the 'local' extra requirements are not installed.
        """
        if not HAS_LOCAL_EMBEDDINGS:
            raise ImportError("You need to install the 'local' extra requirements to use local embeddings models")

        super().__init__()

        self.hf_api_key = api_key
        self.model_name = model_name

        # Use the GPU when one is visible to torch, otherwise fall back to CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = AutoModel.from_pretrained(self.model_name, token=self.hf_api_key).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_api_key)

    async def embed_text(self, data: list[str], batch_size: int = 1) -> list[list[float]]:
        """
        Computes embeddings for the given texts with the locally loaded encoder.

        Inputs longer than the tokenizer's `model_max_length` are truncated.

        Args:
            data: List of strings to get embeddings for.
            batch_size: Number of strings encoded per forward pass. Must be at least 1.

        Returns:
            List of L2-normalized embeddings, one per input string, in input order.

        Raises:
            ValueError: If `batch_size` is smaller than 1.
        """
        # A non-positive batch size would otherwise either fail inside range() with an
        # unrelated message (0) or silently produce no embeddings at all (negative).
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        embeddings: list[list[float]] = []
        for batch in self._batch(data, batch_size):
            batch_dict = self.tokenizer(
                batch, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
            ).to(self.device)
            # Inference only: skip gradient tracking.
            with torch.no_grad():
                outputs = self.model(**batch_dict)
                batch_embeddings = self._average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
                batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1)
            embeddings.extend(batch_embeddings.to("cpu").tolist())

        torch.cuda.empty_cache()
        return embeddings

    @staticmethod
    def _batch(data: list[str], batch_size: int) -> Iterator[list[str]]:
        """
        Yields consecutive slices of `data` holding at most `batch_size` items each.
        """
        # Slicing already clamps at the end of the list, so the last batch may be shorter.
        for start in range(0, len(data), batch_size):
            yield data[start : start + batch_size]

    @staticmethod
    def _average_pool(last_hidden_states: "torch.Tensor", attention_mask: "torch.Tensor") -> "torch.Tensor":
        """
        Averages the hidden states over dim 1, ignoring positions where the attention mask is 0.
        """
        # NOTE: the annotations are strings on purpose. Unquoted `torch.Tensor` is evaluated when
        # the class body runs, which raises NameError on import if torch is not installed and
        # bypasses the friendly ImportError raised in __init__.
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        # Clamp so a row with no attended tokens yields a zero vector instead of NaN (0 / 0).
        token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
        return last_hidden.sum(dim=1) / token_counts
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ def __init__(
use_structured_output: Whether to request a structured output from the model. Default is False.
Raises:
ImportError: If the litellm package is not installed.
ImportError: If the 'litellm' extra requirements are not installed.
"""
if not HAS_LITELLM:
raise ImportError("You need to install litellm package to use LiteLLM models")
raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM models")

super().__init__(model_name)
self.base_url = base_url
Expand Down
3 changes: 3 additions & 0 deletions packages/ragbits-core/src/ragbits/core/llms/clients/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(
Args:
model_name: Name of the model to use.
hf_api_key: The Hugging Face API key for authentication.
Raises:
ImportError: If the 'local' extra requirements are not installed.
"""
if not HAS_LOCAL_LLM:
raise ImportError("You need to install the 'local' extra requirements to use local LLM models")
Expand Down
4 changes: 2 additions & 2 deletions packages/ragbits-core/src/ragbits/core/llms/litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def __init__(
from the model. Default is False. Can only be combined with models that support structured output.
Raises:
ImportError: If the litellm package is not installed.
ImportError: If the 'litellm' extra requirements are not installed.
"""
if not HAS_LITELLM:
raise ImportError("You need to install litellm package to use LiteLLM models")
raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM models")

super().__init__(model_name, default_options)
self.base_url = base_url
Expand Down
3 changes: 3 additions & 0 deletions packages/ragbits-core/src/ragbits/core/llms/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def __init__(
model_name: Name of the model to use. This should be a model from the CausalLM class.
default_options: Default options for the LLM.
api_key: The API key for Hugging Face authentication.
Raises:
ImportError: If the 'local' extra requirements are not installed.
"""
if not HAS_LOCAL_LLM:
raise ImportError("You need to install the 'local' extra requirements to use local LLM models")
Expand Down

0 comments on commit 5a6ebe7

Please sign in to comment.