Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Vertex embeddings to RAG package. #33593

Merged
merged 1 commit into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

"""Tests for apache_beam.ml.rag.embeddings.huggingface."""

import shutil
import tempfile
import unittest

Expand Down Expand Up @@ -73,6 +74,9 @@ def setUp(self):
})
]

def tearDown(self) -> None:
  # Remove the per-test artifact directory, including everything written
  # beneath it, so repeated runs do not accumulate temporary files.
  artifact_dir = self.artifact_location
  shutil.rmtree(artifact_dir)

def test_embedding_pipeline(self):
expected = [
Chunk(
Expand Down
97 changes: 97 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Vertex AI Python SDK is required for this module.
# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk # pylint: disable=line-too-long
# to install Vertex AI Python SDK.

"""RAG-specific embedding implementations using Vertex AI models."""

from typing import Optional

from google.auth.credentials import Credentials

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.rag.embeddings.base import create_rag_adapter
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.transforms.base import EmbeddingsManager
from apache_beam.ml.transforms.base import _TextEmbeddingHandler
from apache_beam.ml.transforms.embeddings.vertex_ai import DEFAULT_TASK_TYPE
from apache_beam.ml.transforms.embeddings.vertex_ai import _VertexAITextEmbeddingHandler

try:
import vertexai
except ImportError:
vertexai = None


class VertexAITextEmbeddings(EmbeddingsManager):
  """Embeddings manager that embeds RAG ``Chunk`` elements via Vertex AI.

  The manager is constructed with the RAG type adapter from
  ``create_rag_adapter()`` and delegates the actual embedding work to
  ``_VertexAITextEmbeddingHandler``.
  """
  def __init__(
      self,
      model_name: str,
      *,
      title: Optional[str] = None,
      task_type: str = DEFAULT_TASK_TYPE,
      project: Optional[str] = None,
      location: Optional[str] = None,
      credentials: Optional[Credentials] = None,
      **kwargs):
    """Configures Vertex AI text embeddings for semantic search and RAG.

    Args:
      model_name: Name of the Vertex AI text embedding model.
      title: Optional title for the text content.
      task_type: Task type for the embeddings; defaults to
        ``DEFAULT_TASK_TYPE``.
      project: GCP project ID.
      location: GCP location.
      credentials: Optional GCP credentials.
      **kwargs: Forwarded unchanged to ``EmbeddingsManager`` (for example
        ModelHandler ``inference_args``).

    Raises:
      ImportError: If the ``vertexai`` package could not be imported.
    """
    # The module-level guarded import leaves ``vertexai`` as None when the
    # SDK is missing; fail early with installation guidance in that case.
    if not vertexai:
      missing_sdk_message = (
          "vertexai is required to use VertexAITextEmbeddings. "
          "Please install it with `pip install google-cloud-aiplatform`")
      raise ImportError(missing_sdk_message)

    rag_adapter = create_rag_adapter()
    super().__init__(type_adapter=rag_adapter, **kwargs)

    # Stored as given; they are only read when the model handler is built.
    self.model_name = model_name
    self.title = title
    self.task_type = task_type
    self.project = project
    self.location = location
    self.credentials = credentials

  def get_model_handler(self):
    """Builds the Vertex AI handler from the settings stored on this manager.

    Returns:
      A ``_VertexAITextEmbeddingHandler`` configured with the model name,
      title, task type, project, location and credentials.
    """
    handler_config = {
        "model_name": self.model_name,
        "title": self.title,
        "task_type": self.task_type,
        "project": self.project,
        "location": self.location,
        "credentials": self.credentials,
    }
    return _VertexAITextEmbeddingHandler(**handler_config)

  def get_ptransform_for_processing(
      self, **kwargs
  ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
    """Creates the transform that runs embedding inference over chunks.

    Args:
      **kwargs: Accepted for interface compatibility; not used here.

    Returns:
      A ``RunInference`` transform wrapping this manager in a
      ``_TextEmbeddingHandler``, with its output type declared as ``Chunk``.
    """
    embedding_handler = _TextEmbeddingHandler(self)
    inference = RunInference(
        model_handler=embedding_handler, inference_args=self.inference_args)
    return inference.with_output_types(Chunk)
110 changes: 110 additions & 0 deletions sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for apache_beam.ml.rag.embeddings.vertex_ai."""

import shutil
import tempfile
import unittest

import apache_beam as beam
from apache_beam.ml.rag.types import Chunk
from apache_beam.ml.rag.types import Content
from apache_beam.ml.rag.types import Embedding
from apache_beam.ml.transforms.base import MLTransform
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

# pylint: disable=ungrouped-imports
try:
import vertexai # pylint: disable=unused-import
from apache_beam.ml.rag.embeddings.vertex_ai import VertexAITextEmbeddings
VERTEX_AI_AVAILABLE = True
except ImportError:
VERTEX_AI_AVAILABLE = False


def chunk_approximately_equals(expected, actual):
  """Loosely compares two Chunks, for use as an ``equals_fn``.

  Embedding values are not compared. Two chunks match when their ``id``,
  ``metadata`` and ``content`` are equal, both carry a dense embedding of
  the same length, and every component of the actual dense embedding is a
  ``float``.

  Args:
    expected: Chunk with the reference id/metadata/content and a placeholder
      dense embedding of the expected dimensionality.
    actual: Chunk to check against ``expected``.

  Returns:
    True when the chunks match as described above. False otherwise,
    including when either argument is not a ``Chunk`` or has no dense
    embedding.
  """
  if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
    return False

  same_payload = (
      expected.id == actual.id and expected.metadata == actual.metadata and
      expected.content == actual.content)
  if not same_payload:
    return False

  # Chunks can be built without an embedding (the input fixtures in this file
  # are). Treat a missing embedding as a mismatch so the assertion fails
  # cleanly instead of raising AttributeError/TypeError inside the matcher.
  if expected.embedding is None or actual.embedding is None:
    return False

  expected_dense = expected.embedding.dense_embedding
  actual_dense = actual.embedding.dense_embedding
  if expected_dense is None or actual_dense is None:
    return False

  return (
      len(expected_dense) == len(actual_dense) and
      all(isinstance(x, float) for x in actual_dense))


@unittest.skipIf(
    not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
class VertexAITextEmbeddingsTest(unittest.TestCase):
  """Runs VertexAITextEmbeddings through MLTransform on two sample chunks."""
  def setUp(self):
    # (id, text) pairs for the fixture chunks. Each chunk gets its own
    # metadata dict rather than sharing one instance.
    samples = (
        ("1", "This is a test sentence."),
        ("2", "Another example."),
    )
    self.artifact_location = tempfile.mkdtemp(prefix='vertex_ai_')
    self.test_chunks = [
        Chunk(
            content=Content(text=text),
            id=chunk_id,
            metadata={
                "source": "test.txt", "language": "en"
            }) for chunk_id, text in samples
    ]

  def tearDown(self) -> None:
    # Drop the artifact directory created in setUp.
    shutil.rmtree(self.artifact_location)

  def test_embedding_pipeline(self):
    # gecko@002 produces 768-dimensional embeddings. The zero vectors below
    # are placeholders: the comparison function checks length and element
    # type only, not the embedding values.
    embedding_size = 768
    samples = (
        ("1", "This is a test sentence."),
        ("2", "Another example."),
    )
    expected = [
        Chunk(
            id=chunk_id,
            embedding=Embedding(dense_embedding=[0.0] * embedding_size),
            metadata={
                "source": "test.txt", "language": "en"
            },
            content=Content(text=text)) for chunk_id, text in samples
    ]

    embedder = VertexAITextEmbeddings(model_name="textembedding-gecko@002")
    matcher = equal_to(expected, equals_fn=chunk_approximately_equals)

    with TestPipeline() as p:
      ml_transform = MLTransform(
          write_artifact_location=self.artifact_location).with_transform(
              embedder)
      embeddings = p | beam.Create(self.test_chunks) | ml_transform

      assert_that(embeddings, matcher)


if __name__ == '__main__':
unittest.main()
Loading