diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py b/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py index 1d8fb9d8..ca9dcb6c 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/litellm.py @@ -38,10 +38,10 @@ def __init__( api_version: The API version for the call. Raises: - ImportError: If the litellm package is not installed. + ImportError: If the 'litellm' extra requirements are not installed. """ if not HAS_LITELLM: - raise ImportError("You need to install litellm package to use LiteLLM models") + raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models") super().__init__() self.model = model diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/local.py b/packages/ragbits-core/src/ragbits/core/embeddings/local.py new file mode 100644 index 00000000..8a4f52ba --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/embeddings/local.py @@ -0,0 +1,81 @@ +from typing import Iterator, Optional + +try: + import torch + import torch.nn.functional as F + from transformers import AutoModel, AutoTokenizer + + HAS_LOCAL_EMBEDDINGS = True +except ImportError: + HAS_LOCAL_EMBEDDINGS = False + +from ragbits.core.embeddings.base import Embeddings + + +class LocalEmbeddings(Embeddings): + """ + Class for interaction with any encoder available in HuggingFace. + """ + + def __init__( + self, + model_name: str, + api_key: Optional[str] = None, + ) -> None: + """ + Constructs a new local LLM instance. + + Args: + model_name: Name of the model to use. + api_key: The API key for Hugging Face authentication. + + Raises: + ImportError: If the 'local' extra requirements are not installed. + """ + if not HAS_LOCAL_EMBEDDINGS: + raise ImportError("You need to install the 'local' extra requirements to use local embeddings models") + + super().__init__() + + self.hf_api_key = api_key + self.model_name = model_name + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.model = AutoModel.from_pretrained(self.model_name, token=self.hf_api_key).to(self.device) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, token=self.hf_api_key) + + async def embed_text(self, data: list[str], batch_size: int = 1) -> list[list[float]]: + """ + Calls the appropriate encoder endpoint with the given data and options. + + Args: + data: List of strings to get embeddings for. + batch_size: Batch size. + + Returns: + List of embeddings for the given strings. + """ + embeddings = [] + for batch in self._batch(data, batch_size): + batch_dict = self.tokenizer( + batch, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt" + ).to(self.device) + with torch.no_grad(): + outputs = self.model(**batch_dict) + batch_embeddings = self._average_pool(outputs.last_hidden_state, batch_dict["attention_mask"]) + batch_embeddings = F.normalize(batch_embeddings, p=2, dim=1) + embeddings.extend(batch_embeddings.to("cpu").tolist()) + + torch.cuda.empty_cache() + return embeddings + + @staticmethod + def _batch(data: list[str], batch_size: int) -> Iterator[list[str]]: + length = len(data) + for ndx in range(0, length, batch_size): + yield data[ndx : min(ndx + batch_size, length)] + + @staticmethod + def _average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) + return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py b/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py index b504c9de..f1620d8c 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/llms/clients/litellm.py @@ -64,10 +64,10 @@ def __init__( use_structured_output: Whether to request a structured output from the model. Default is False. Raises: - ImportError: If the litellm package is not installed. + ImportError: If the 'litellm' extra requirements are not installed. """ if not HAS_LITELLM: - raise ImportError("You need to install litellm package to use LiteLLM models") + raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM models") super().__init__(model_name) self.base_url = base_url diff --git a/packages/ragbits-core/src/ragbits/core/llms/clients/local.py b/packages/ragbits-core/src/ragbits/core/llms/clients/local.py index 50c4fc54..d3a1d0f6 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/clients/local.py +++ b/packages/ragbits-core/src/ragbits/core/llms/clients/local.py @@ -55,6 +55,9 @@ def __init__( Args: model_name: Name of the model to use. hf_api_key: The Hugging Face API key for authentication. + + Raises: + ImportError: If the 'local' extra requirements are not installed. """ if not HAS_LOCAL_LLM: raise ImportError("You need to install the 'local' extra requirements to use local LLM models") diff --git a/packages/ragbits-core/src/ragbits/core/llms/litellm.py b/packages/ragbits-core/src/ragbits/core/llms/litellm.py index 86bb4d07..00524113 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/litellm.py +++ b/packages/ragbits-core/src/ragbits/core/llms/litellm.py @@ -48,10 +48,10 @@ def __init__( from the model. Default is False. Can only be combined with models that support structured output. Raises: - ImportError: If the litellm package is not installed. + ImportError: If the 'litellm' extra requirements are not installed. """ if not HAS_LITELLM: - raise ImportError("You need to install litellm package to use LiteLLM models") + raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM models") super().__init__(model_name, default_options) self.base_url = base_url diff --git a/packages/ragbits-core/src/ragbits/core/llms/local.py b/packages/ragbits-core/src/ragbits/core/llms/local.py index 1aa5b99c..cf3cacbe 100644 --- a/packages/ragbits-core/src/ragbits/core/llms/local.py +++ b/packages/ragbits-core/src/ragbits/core/llms/local.py @@ -35,6 +35,9 @@ def __init__( model_name: Name of the model to use. This should be a model from the CausalLM class. default_options: Default options for the LLM. api_key: The API key for Hugging Face authentication. + + Raises: + ImportError: If the 'local' extra requirements are not installed. """ if not HAS_LOCAL_LLM: raise ImportError("You need to install the 'local' extra requirements to use local LLM models")