feat(document-search): Support for ingesting images #172

Merged
merged 7 commits on Nov 13, 2024
Binary file added examples/document-search/images/bear.jpg
Binary file added examples/document-search/images/game.jpg
Binary file added examples/document-search/images/tree.jpg
98 changes: 98 additions & 0 deletions examples/document-search/multimodal.py
@@ -0,0 +1,98 @@
"""
Ragbits Document Search Example: Multimodal Embeddings

This example demonstrates how to use `DocumentSearch` to index and search through image and text documents.

It employs the "multimodalembedding" model from VertexAI. To use it, make sure that you are
logged in to Google Cloud (using the `gcloud auth login` command) and that you have the necessary permissions.

The script performs the following steps:
1. Create a list of example documents.
2. Initialize the `VertexAIMultimodelEmbeddings` class (which uses the VertexAI multimodal embeddings).
3. Initialize the `InMemoryVectorStore` class, which stores the embeddings for the duration of the script.
4. Initialize the `DocumentSearch` class with the embedder and the vector store.
5. Ingest the documents into the `DocumentSearch` instance.
6. List all embeddings in the vector store.
7. Search for documents using a query.
8. Print the search results.

To run the script, execute the following command:

```bash
uv run python examples/document-search/multimodal.py
```
"""

# /// script
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-document-search",
# "ragbits-core[litellm]",
# ]
# ///
import asyncio
from pathlib import Path

from ragbits.core.embeddings.vertex_multimodal import VertexAIMultimodelEmbeddings
from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.dummy import DummyImageProvider

IMAGES_PATH = Path(__file__).parent / "images"


def jpg_example(file_name: str) -> DocumentMeta:
"""
Create a document from a JPG file in the images directory.
"""
return DocumentMeta(document_type=DocumentType.JPG, source=LocalFileSource(path=IMAGES_PATH / file_name))


documents = [
jpg_example("bear.jpg"),
jpg_example("game.jpg"),
jpg_example("tree.jpg"),
DocumentMeta.create_text_document_from_literal("A beautiful teady bear."),
DocumentMeta.create_text_document_from_literal("The constitution of the United States."),
]


async def main() -> None:
"""
Run the example.
"""
embedder = VertexAIMultimodelEmbeddings()
vector_store = InMemoryVectorStore()
router = DocumentProcessorRouter.from_config(
{
# For this example, we want to skip OCR and make sure
# that we test direct image embeddings.
DocumentType.JPG: DummyImageProvider(),
}
)

document_search = DocumentSearch(
embedder=embedder,
vector_store=vector_store,
document_processor_router=router,
)

await document_search.ingest(documents)

all_embeddings = await vector_store.list()
for embedding in all_embeddings:
print(f"Embedding: {embedding.metadata['document_meta']}")
print()

    results = await document_search.search("Fluffy teddy bear")
    print("Results for 'Fluffy teddy bear':")
for result in results:
document = await result.document_meta.fetch()
print(f"Type: {result.element_type}, Location: {document.local_path}, Text: {result.get_text_representation()}")


if __name__ == "__main__":
asyncio.run(main())
21 changes: 21 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/base.py
@@ -17,3 +17,24 @@ async def embed_text(self, data: list[str]) -> list[list[float]]:
Returns:
List of embeddings for the given strings.
"""

def image_support(self) -> bool: # noqa: PLR6301
"""
Check if the model supports image embeddings.

Returns:
True if the model supports image embeddings, False otherwise.
"""
return False

async def embed_image(self, images: list[bytes]) -> list[list[float]]:
"""
Creates embeddings for the given images.

Args:
images: List of images to get embeddings for.

Returns:
List of embeddings for the given images.
"""
raise NotImplementedError("Image embeddings are not supported by this model.")
170 changes: 170 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/vertex_multimodal.py
@@ -0,0 +1,170 @@
import asyncio
import base64

try:
import litellm
from litellm.llms.vertex_ai_and_google_ai_studio.common_utils import VertexAIError
from litellm.main import VertexMultimodalEmbedding

HAS_LITELLM = True
except ImportError:
HAS_LITELLM = False

from ragbits.core.audit import trace
from ragbits.core.embeddings import Embeddings
from ragbits.core.embeddings.exceptions import (
EmbeddingResponseError,
EmbeddingStatusError,
)


class VertexAIMultimodelEmbeddings(Embeddings):
"""
    Client for creating multimodal (text and image) embeddings with VertexAI models through the LiteLLM API.
"""

VERTEX_AI_PREFIX = "vertex_ai/"

def __init__(
self,
model: str = "multimodalembedding",
api_base: str | None = None,
api_key: str | None = None,
        concurrency: int = 10,
options: dict | None = None,
) -> None:
"""
Constructs the embedding client for multimodal VertexAI models.

Args:
model: One of the VertexAI multimodal models to be used. Default is "multimodalembedding".
api_base: The API endpoint you want to call the model with.
api_key: API key to be used. If not specified, an environment variable will be used.
            concurrency: The number of concurrent requests to make to the API.
options: Additional options to pass to the API.

Raises:
ImportError: If the 'litellm' extra requirements are not installed.
ValueError: If the chosen model is not supported by VertexAI multimodal embeddings.
"""
if not HAS_LITELLM:
raise ImportError("You need to install the 'litellm' extra requirements to use LiteLLM embeddings models")

super().__init__()
if model.startswith(self.VERTEX_AI_PREFIX):
model = model[len(self.VERTEX_AI_PREFIX) :]

self.model = model
self.api_base = api_base
self.api_key = api_key
        self.concurrency = concurrency
self.options = options or {}

supported_models = VertexMultimodalEmbedding().SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
if model not in supported_models:
raise ValueError(f"Model {model} is not supported by VertexAI multimodal embeddings")

async def _embed(self, data: list[dict]) -> list[dict]:
"""
Creates embeddings for the given data. The format is defined in the VertexAI API:
https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/get-multimodal-embeddings

Args:
data: List of instances in the format expected by the VertexAI API.

Returns:
            List of embeddings for the given VertexAI instances; each entry is a dictionary
            in the format returned by the VertexAI API.

Raises:
EmbeddingStatusError: If the embedding API returns an error status code.
EmbeddingResponseError: If the embedding API response is invalid.
"""
with trace(
data=data,
model=self.model,
api_base=self.api_base,
options=self.options,
) as outputs:
            semaphore = asyncio.Semaphore(self.concurrency)
try:
response = await asyncio.gather(
*[self._call_litellm(instance, semaphore) for instance in data],
)
except VertexAIError as exc:
raise EmbeddingStatusError(exc.message, exc.status_code) from exc

outputs.embeddings = []
for i, embedding in enumerate(response):
                if not embedding.data:  # covers both None and an empty list
raise EmbeddingResponseError(f"No embeddings returned for instance {i}")
outputs.embeddings.append(embedding.data[0])

return outputs.embeddings

async def _call_litellm(self, instance: dict, semaphore: asyncio.Semaphore) -> litellm.EmbeddingResponse:
"""
Calls the LiteLLM API to get embeddings for the given data.

Args:
instance: Single VertexAI instance to get embeddings for.
semaphore: Semaphore to limit the number of concurrent requests.

Returns:
            The embedding response for the given instance.
"""
async with semaphore:
response = await litellm.aembedding(
input=[instance],
model=f"{self.VERTEX_AI_PREFIX}{self.model}",
api_base=self.api_base,
api_key=self.api_key,
**self.options,
)

return response

async def embed_text(self, data: list[str]) -> list[list[float]]:
"""
Creates embeddings for the given strings.

Args:
data: List of strings to get embeddings for.

Returns:
List of embeddings for the given strings.

Raises:
EmbeddingStatusError: If the embedding API returns an error status code.
EmbeddingResponseError: If the embedding API response is invalid.
"""
response = await self._embed([{"text": text} for text in data])
return [embedding["textEmbedding"] for embedding in response]

def image_support(self) -> bool: # noqa: PLR6301
"""
Check if the model supports image embeddings.

Returns:
True if the model supports image embeddings, False otherwise.
"""
return True

async def embed_image(self, images: list[bytes]) -> list[list[float]]:
"""
Creates embeddings for the given images.

Args:
images: List of images to get embeddings for.

Returns:
List of embeddings for the given images.

Raises:
EmbeddingStatusError: If the embedding API returns an error status code.
EmbeddingResponseError: If the embedding API response is invalid.
"""
images_b64 = (base64.b64encode(image).decode() for image in images)
response = await self._embed([{"image": {"bytesBase64Encoded": image}} for image in images_b64])

return [embedding["imageEmbedding"] for embedding in response]
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Sequence
from typing import Any

@@ -8,7 +9,7 @@
from ragbits.core.vector_stores import VectorStore, get_vector_store
from ragbits.core.vector_stores.base import VectorStoreOptions
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.element import Element
from ragbits.document_search.documents.element import Element, ImageElement
from ragbits.document_search.documents.sources import Source
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.base import BaseProvider
@@ -170,5 +171,22 @@ async def insert_elements(self, elements: list[Element]) -> None:
elements: The list of Elements to insert.
"""
vectors = await self.embedder.embed_text([element.get_text_for_embedding() for element in elements])

image_elements = [element for element in elements if isinstance(element, ImageElement)]
entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors, strict=False)]

if image_elements and self.embedder.image_support():
image_vectors = await self.embedder.embed_image([element.image_bytes for element in image_elements])
entries.extend(
[
element.to_vector_db_entry(vector)
for element, vector in zip(image_elements, image_vectors, strict=False)
]
)
elif image_elements:
warnings.warn(
f"Image elements are not supported by the embedder {self.embedder}. "
f"Skipping {len(image_elements)} image elements."
)

await self.vector_store.store(entries)
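The net effect: every element is stored once via its text representation, and each image element gets a second entry embedded from its raw bytes when the embedder supports it. A minimal sketch of the pairing, with plain lists standing in for ragbits types:

```python
elements = ["text-1", "image-1", "text-2"]        # one ImageElement among three
text_vectors = [[0.1], [0.2], [0.3]]              # embed_text output, one per element
image_vectors = [[0.9]]                           # embed_image output, one per image

entries = list(zip(elements, text_vectors))       # every element -> text-based entry
entries += list(zip(["image-1"], image_vectors))  # image elements -> extra image entry
assert len(entries) == 4                          # the image element is stored twice
```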
@@ -41,6 +41,7 @@ def id(self) -> str:
"""
id_components = [
self.document_meta.id,
self.element_type,
self.get_text_for_embedding(),
self.get_text_representation(),
str(self.location),
@@ -3,7 +3,7 @@
DocumentType,
TextDocument,
)
from ragbits.document_search.documents.element import Element, TextElement
from ragbits.document_search.documents.element import Element, ImageElement, TextElement
from ragbits.document_search.ingestion.providers.base import BaseProvider


@@ -31,3 +31,37 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]:
if isinstance(document, TextDocument):
return [TextElement(content=document.content, document_meta=document_meta)]
return []


class DummyImageProvider(BaseProvider):
"""
This is a simple provider that returns an ImageElement with the content of the image
and empty text metadata.
"""

SUPPORTED_DOCUMENT_TYPES = {DocumentType.JPG, DocumentType.PNG}

async def process(self, document_meta: DocumentMeta) -> list[Element]:
"""
Process the image document.

Args:
document_meta: The document to process.

Returns:
List with a single ImageElement containing the content of the image.
"""
self.validate_document_type(document_meta.document_type)

document = await document_meta.fetch()
image_path = document.local_path
with open(image_path, "rb") as f:
image_bytes = f.read()
return [
ImageElement(
description="",
ocr_extracted_text="",
image_bytes=image_bytes,
document_meta=document_meta,
)
]
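The provider can also be exercised on its own, outside the router. A minimal sketch (the image path is illustrative and assumes the example assets above):

```python
import asyncio
from pathlib import Path

from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.sources import LocalFileSource
from ragbits.document_search.ingestion.providers.dummy import DummyImageProvider


async def main() -> None:
    meta = DocumentMeta(
        document_type=DocumentType.JPG,
        source=LocalFileSource(path=Path("examples/document-search/images/bear.jpg")),
    )
    elements = await DummyImageProvider().process(meta)
    # A single ImageElement carrying the raw bytes; no OCR or description is produced.
    print(type(elements[0]).__name__, len(elements[0].image_bytes))


asyncio.run(main())
```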