diff --git a/examples/document-search/images/bear.jpg b/examples/document-search/images/bear.jpg new file mode 100644 index 000000000..ee107a3ed Binary files /dev/null and b/examples/document-search/images/bear.jpg differ diff --git a/examples/document-search/images/game.jpg b/examples/document-search/images/game.jpg new file mode 100644 index 000000000..143d6959a Binary files /dev/null and b/examples/document-search/images/game.jpg differ diff --git a/examples/document-search/images/tree.jpg b/examples/document-search/images/tree.jpg new file mode 100644 index 000000000..a087729d1 Binary files /dev/null and b/examples/document-search/images/tree.jpg differ diff --git a/examples/document-search/multimodal.py b/examples/document-search/multimodal.py new file mode 100644 index 000000000..88dfb5710 --- /dev/null +++ b/examples/document-search/multimodal.py @@ -0,0 +1,98 @@ +""" +Ragbits Document Search Example: Multimodal Embeddings + +This example demonstrates how to use the `DocumentSearch` to index and search for images and text documents. + +It employs the "multimodalembedding" model from VertexAI. In order to use it, make sure that you are +logged in to Google Cloud (using the `gcloud auth login` command) and that you have the necessary permissions. + +The script performs the following steps: + 1. Create a list of example documents. + 2. Initialize the `VertexAIMultimodelEmbeddings` class (which uses the VertexAI multimodal embeddings). + 3. Initialize the `InMemoryVectorStore` class, which stores the embeddings for the duration of the script. + 4. Initialize the `DocumentSearch` class with the embedder and the vector store. + 5. Ingest the documents into the `DocumentSearch` instance. + 6. List all embeddings in the vector store. + 7. Search for documents using a query. + 8. Print the search results. 
+ +To run the script, execute the following command: + + ```bash + uv run python examples/document-search/multimodal.py + ``` +""" + +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "ragbits-document-search", +# "ragbits-core[litellm]", +# ] +# /// +import asyncio +from pathlib import Path + +from ragbits.core.embeddings.vertex_multimodal import VertexAIMultimodelEmbeddings +from ragbits.core.vector_stores.in_memory import InMemoryVectorStore +from ragbits.document_search import DocumentSearch +from ragbits.document_search.documents.document import DocumentMeta, DocumentType +from ragbits.document_search.documents.sources import LocalFileSource +from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter +from ragbits.document_search.ingestion.providers.dummy import DummyImageProvider + +IMAGES_PATH = Path(__file__).parent / "images" + + +def jpg_example(file_name: str) -> DocumentMeta: + """ + Create a document from a JPG file in the images directory. + """ + return DocumentMeta(document_type=DocumentType.JPG, source=LocalFileSource(path=IMAGES_PATH / file_name)) + + +documents = [ + jpg_example("bear.jpg"), + jpg_example("game.jpg"), + jpg_example("tree.jpg"), + DocumentMeta.create_text_document_from_literal("A beautiful teady bear."), + DocumentMeta.create_text_document_from_literal("The constitution of the United States."), +] + + +async def main() -> None: + """ + Run the example. + """ + embedder = VertexAIMultimodelEmbeddings() + vector_store = InMemoryVectorStore() + router = DocumentProcessorRouter.from_config( + { + # For this example, we want to skip OCR and make sure + # that we test direct image embeddings. 
+ DocumentType.JPG: DummyImageProvider(), + } + ) + + document_search = DocumentSearch( + embedder=embedder, + vector_store=vector_store, + document_processor_router=router, + ) + + await document_search.ingest(documents) + + all_embeddings = await vector_store.list() + for embedding in all_embeddings: + print(f"Embedding: {embedding.metadata['document_meta']}") + print() + + results = await document_search.search("Fluffy teady bear") + print("Results for 'Fluffy teady bear toy':") + for result in results: + document = await result.document_meta.fetch() + print(f"Type: {result.element_type}, Location: {document.local_path}, Text: {result.get_text_representation()}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/base.py b/packages/ragbits-core/src/ragbits/core/embeddings/base.py index ede4fcadf..66c2716dd 100644 --- a/packages/ragbits-core/src/ragbits/core/embeddings/base.py +++ b/packages/ragbits-core/src/ragbits/core/embeddings/base.py @@ -17,3 +17,24 @@ async def embed_text(self, data: list[str]) -> list[list[float]]: Returns: List of embeddings for the given strings. """ + + def image_support(self) -> bool: # noqa: PLR6301 + """ + Check if the model supports image embeddings. + + Returns: + True if the model supports image embeddings, False otherwise. + """ + return False + + async def embed_image(self, images: list[bytes]) -> list[list[float]]: + """ + Creates embeddings for the given images. + + Args: + images: List of images to get embeddings for. + + Returns: + List of embeddings for the given images. 
+ """ + raise NotImplementedError("Image embeddings are not supported by this model.") diff --git a/packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py b/packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py new file mode 100644 index 000000000..8f4bf422c --- /dev/null +++ b/packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py @@ -0,0 +1,170 @@ +import asyncio +import base64 + +try: + import litellm + from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import VertexAIError + from litellm.main import VertexMultimodalEmbedding + + HAS_LITELLM = True +except ImportError: + HAS_LITELLM = False + +from ragbits.core.audit import trace +from ragbits.core.embeddings import Embeddings +from ragbits.core.embeddings.exceptions import ( + EmbeddingResponseError, + EmbeddingStatusError, +) + + +class VertexAIMultimodelEmbeddings(Embeddings): + """ + Client for creating text embeddings using LiteLLM API. + """ + + VERTEX_AI_PREFIX = "vertex_ai/" + + def __init__( + self, + model: str = "multimodalembedding", + api_base: str | None = None, + api_key: str | None = None, + concurency: int = 10, + options: dict | None = None, + ) -> None: + """ + Constructs the embedding client for multimodal VertexAI models. + + Args: + model: One of the VertexAI multimodal models to be used. Default is "multimodalembedding". + api_base: The API endpoint you want to call the model with. + api_key: API key to be used. If not specified, an environment variable will be used. + concurency: The number of concurrent requests to make to the API. + options: Additional options to pass to the API. + + Raises: + ImportError: If the 'litellm' extra requirements are not installed. + ValueError: If the chosen model is not supported by VertexAI multimodal embeddings. 
+ """ + if not HAS_LITELLM: + raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models") + + super().__init__() + if model.startswith(self.VERTEX_AI_PREFIX): + model = model[len(self.VERTEX_AI_PREFIX) :] + + self.model = model + self.api_base = api_base + self.api_key = api_key + self.concurency = concurency + self.options = options or {} + + supported_models = VertexMultimodalEmbedding().SUPPORTED_MULTIMODAL_EMBEDDING_MODELS + if model not in supported_models: + raise ValueError(f"Model {model} is not supported by VertexAI multimodal embeddings") + + async def _embed(self, data: list[dict]) -> list[dict]: + """ + Creates embeddings for the given data. The format is defined in the VertexAI API: + https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings + + Args: + data: List of instances in the format expected by the VertexAI API. + + Returns: + List of embeddings for the given VertexAI instances, each instance is a dictionary + in the format returned by the VertexAI API. + + Raises: + EmbeddingStatusError: If the embedding API returns an error status code. + EmbeddingResponseError: If the embedding API response is invalid. 
+ """ + with trace( + data=data, + model=self.model, + api_base=self.api_base, + options=self.options, + ) as outputs: + semaphore = asyncio.Semaphore(self.concurency) + try: + response = await asyncio.gather( + *[self._call_litellm(instance, semaphore) for instance in data], + ) + except VertexAIError as exc: + raise EmbeddingStatusError(exc.message, exc.status_code) from exc + + outputs.embeddings = [] + for i, embedding in enumerate(response): + if embedding.data is None or not embedding.data: + raise EmbeddingResponseError(f"No embeddings returned for instance {i}") + outputs.embeddings.append(embedding.data[0]) + + return outputs.embeddings + + async def _call_litellm(self, instance: dict, semaphore: asyncio.Semaphore) -> litellm.EmbeddingResponse: + """ + Calls the LiteLLM API to get embeddings for the given data. + + Args: + instance: Single VertexAI instance to get embeddings for. + semaphore: Semaphore to limit the number of concurrent requests. + + Returns: + List of embeddings for the given LiteLLM instances. + """ + async with semaphore: + response = await litellm.aembedding( + input=[instance], + model=f"{self.VERTEX_AI_PREFIX}{self.model}", + api_base=self.api_base, + api_key=self.api_key, + **self.options, + ) + + return response + + async def embed_text(self, data: list[str]) -> list[list[float]]: + """ + Creates embeddings for the given strings. + + Args: + data: List of strings to get embeddings for. + + Returns: + List of embeddings for the given strings. + + Raises: + EmbeddingStatusError: If the embedding API returns an error status code. + EmbeddingResponseError: If the embedding API response is invalid. + """ + response = await self._embed([{"text": text} for text in data]) + return [embedding["textEmbedding"] for embedding in response] + + def image_support(self) -> bool: # noqa: PLR6301 + """ + Check if the model supports image embeddings. + + Returns: + True if the model supports image embeddings, False otherwise. 
+ """ + return True + + async def embed_image(self, images: list[bytes]) -> list[list[float]]: + """ + Creates embeddings for the given images. + + Args: + images: List of images to get embeddings for. + + Returns: + List of embeddings for the given images. + + Raises: + EmbeddingStatusError: If the embedding API returns an error status code. + EmbeddingResponseError: If the embedding API response is invalid. + """ + images_b64 = (base64.b64encode(image).decode() for image in images) + response = await self._embed([{"image": {"bytesBase64Encoded": image}} for image in images_b64]) + + return [embedding["imageEmbedding"] for embedding in response] diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index 75cd8390b..521ea0ff5 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -1,3 +1,4 @@ +import warnings from collections.abc import Sequence from typing import Any @@ -8,7 +9,7 @@ from ragbits.core.vector_stores import VectorStore, get_vector_store from ragbits.core.vector_stores.base import VectorStoreOptions from ragbits.document_search.documents.document import Document, DocumentMeta -from ragbits.document_search.documents.element import Element +from ragbits.document_search.documents.element import Element, ImageElement from ragbits.document_search.documents.sources import Source from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter from ragbits.document_search.ingestion.providers.base import BaseProvider @@ -170,5 +171,22 @@ async def insert_elements(self, elements: list[Element]) -> None: elements: The list of Elements to insert. 
""" vectors = await self.embedder.embed_text([element.get_text_for_embedding() for element in elements]) + + image_elements = [element for element in elements if isinstance(element, ImageElement)] entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors, strict=False)] + + if image_elements and self.embedder.image_support(): + image_vectors = await self.embedder.embed_image([element.image_bytes for element in image_elements]) + entries.extend( + [ + element.to_vector_db_entry(vector) + for element, vector in zip(image_elements, image_vectors, strict=False) + ] + ) + elif image_elements: + warnings.warn( + f"Image elements are not supported by the embedder {self.embedder}. " + f"Skipping {len(image_elements)} image elements." + ) + await self.vector_store.store(entries) diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py index 21b080f2d..893d7d3cb 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/element.py @@ -41,6 +41,7 @@ def id(self) -> str: """ id_components = [ self.document_meta.id, + self.element_type, self.get_text_for_embedding(), self.get_text_representation(), str(self.location), diff --git a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py index 09fe917c7..d65e84b22 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/ingestion/providers/dummy.py @@ -3,7 +3,7 @@ DocumentType, TextDocument, ) -from ragbits.document_search.documents.element import Element, TextElement +from ragbits.document_search.documents.element import Element, 
ImageElement, TextElement from ragbits.document_search.ingestion.providers.base import BaseProvider @@ -31,3 +31,37 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]: if isinstance(document, TextDocument): return [TextElement(content=document.content, document_meta=document_meta)] return [] + + +class DummyImageProvider(BaseProvider): + """ + This is a simple provider that returns an ImageElement with the content of the image + and empty text metadata. + """ + + SUPPORTED_DOCUMENT_TYPES = {DocumentType.JPG, DocumentType.PNG} + + async def process(self, document_meta: DocumentMeta) -> list[Element]: + """ + Process the image document. + + Args: + document_meta: The document to process. + + Returns: + List with a single ImageElement containing the content of the image. + """ + self.validate_document_type(document_meta.document_type) + + document = await document_meta.fetch() + image_path = document.local_path + with open(image_path, "rb") as f: + image_bytes = f.read() + return [ + ImageElement( + description="", + ocr_extracted_text="", + image_bytes=image_bytes, + document_meta=document_meta, + ) + ] diff --git a/uv.lock b/uv.lock index 65fe62007..a3b1b1b87 100644 --- a/uv.lock +++ b/uv.lock @@ -2611,7 +2611,7 @@ name = "nvidia-cudnn-cu12" version = "8.9.2.26" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ff/74/a2e2be7fb83aaedec84f391f082cf765dfb635e7caa9b49065f73e4835d8/nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9", size = 731725872 }, @@ -2638,9 +2638,9 @@ name = "nvidia-cusolver-cu12" version = "11.4.5.107" source = { registry = "https://pypi.org/simple" } 
dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, @@ -2651,7 +2651,7 @@ name = "nvidia-cusparse-cu12" version = "12.1.0.106" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, @@ -4892,7 +4892,7 @@ name = "triton" version = "2.2.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "python_full_version < '3.12'" }, + { name = "filelock", marker = "(python_full_version < '3.12' and platform_machine != 'aarch64' and platform_system != 'Darwin') or (python_full_version < '3.12' and platform_system != 'Darwin' and 
platform_system != 'Linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/95/05/ed974ce87fe8c8843855daa2136b3409ee1c126707ab54a8b72815c08b49/triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5", size = 167900779 },