diff --git a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py
index aa63d13025a1..f0b9316dcee8 100644
--- a/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py
+++ b/sdks/python/apache_beam/ml/rag/embeddings/huggingface_test.py
@@ -16,6 +16,7 @@
 
 """Tests for apache_beam.ml.rag.embeddings.huggingface."""
 
+import shutil
 import tempfile
 import unittest
 
@@ -73,6 +74,9 @@ def setUp(self):
             })
     ]
 
+  def tearDown(self) -> None:
+    shutil.rmtree(self.artifact_location)  # temp dir, presumably from setUp
+
   def test_embedding_pipeline(self):
     expected = [
         Chunk(
diff --git a/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py
new file mode 100644
index 000000000000..c960cd66a2a4
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai.py
@@ -0,0 +1,97 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Vertex AI Python SDK is required for this module.
+# Follow https://cloud.google.com/vertex-ai/docs/python-sdk/use-vertex-ai-python-sdk  # pylint: disable=line-too-long
+# to install Vertex AI Python SDK.
+
+"""RAG-specific embedding implementations using Vertex AI models."""
+
+from typing import Optional
+
+from google.auth.credentials import Credentials
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.rag.embeddings.base import create_rag_adapter
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.transforms.base import EmbeddingsManager
+from apache_beam.ml.transforms.base import _TextEmbeddingHandler
+from apache_beam.ml.transforms.embeddings.vertex_ai import DEFAULT_TASK_TYPE
+from apache_beam.ml.transforms.embeddings.vertex_ai import _VertexAITextEmbeddingHandler
+
+try:
+  import vertexai
+except ImportError:
+  vertexai = None  # checked in VertexAITextEmbeddings.__init__
+
+
+class VertexAITextEmbeddings(EmbeddingsManager):
+  def __init__(
+      self,
+      model_name: str,
+      *,
+      title: Optional[str] = None,
+      task_type: str = DEFAULT_TASK_TYPE,
+      project: Optional[str] = None,
+      location: Optional[str] = None,
+      credentials: Optional[Credentials] = None,
+      **kwargs):
+    """Configures Vertex AI text embeddings for Chunk-based RAG pipelines.
+
+    Raises ImportError when the vertexai package could not be imported.
+
+    Args:
+      model_name: Name of the Vertex AI text embedding model.
+      title: Optional title for the text content.
+      task_type: Embedding task type; defaults to DEFAULT_TASK_TYPE.
+      project: GCP project ID.
+      location: GCP location.
+      credentials: Optional GCP credentials.
+      **kwargs: Passed to EmbeddingsManager (e.g. ModelHandler inference_args).
+    """
+    if not vertexai:  # None when the guarded module-level import failed
+      raise ImportError(
+          "vertexai is required to use VertexAITextEmbeddings. "
+          "Please install it with `pip install google-cloud-aiplatform`")
+
+    super().__init__(type_adapter=create_rag_adapter(), **kwargs)
+    self.model_name = model_name
+    self.title = title
+    self.task_type = task_type
+    self.project = project
+    self.location = location
+    self.credentials = credentials
+
+  def get_model_handler(self):
+    """Returns a Vertex AI embedding handler built from this configuration."""
+    return _VertexAITextEmbeddingHandler(
+        model_name=self.model_name,
+        title=self.title,
+        task_type=self.task_type,
+        project=self.project,
+        location=self.location,
+        credentials=self.credentials,
+    )
+
+  def get_ptransform_for_processing(
+      self, **kwargs
+  ) -> beam.PTransform[beam.PCollection[Chunk], beam.PCollection[Chunk]]:
+    """Returns a Chunk-typed RunInference transform; **kwargs is unused."""
+    return RunInference(
+        model_handler=_TextEmbeddingHandler(self),
+        inference_args=self.inference_args).with_output_types(Chunk)
diff --git a/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py b/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py
new file mode 100644
index 000000000000..4e5ad8046a8a
--- /dev/null
+++ b/sdks/python/apache_beam/ml/rag/embeddings/vertex_ai_test.py
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for apache_beam.ml.rag.embeddings.vertex_ai."""
+
+import shutil
+import tempfile
+import unittest
+
+import apache_beam as beam
+from apache_beam.ml.rag.types import Chunk
+from apache_beam.ml.rag.types import Content
+from apache_beam.ml.rag.types import Embedding
+from apache_beam.ml.transforms.base import MLTransform
+from apache_beam.testing.test_pipeline import TestPipeline
+from apache_beam.testing.util import assert_that
+from apache_beam.testing.util import equal_to
+
+# pylint: disable=ungrouped-imports
+try:
+  import vertexai  # pylint: disable=unused-import
+  from apache_beam.ml.rag.embeddings.vertex_ai import VertexAITextEmbeddings
+  VERTEX_AI_AVAILABLE = True
+except ImportError:
+  VERTEX_AI_AVAILABLE = False
+
+
+def chunk_approximately_equals(expected, actual) -> bool:
+  """Match id, metadata and content; embeddings by length/float type only."""
+  if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
+    return False
+
+  return (
+      expected.id == actual.id and expected.metadata == actual.metadata and
+      expected.content == actual.content and
+      len(expected.embedding.dense_embedding) == len(
+          actual.embedding.dense_embedding) and
+      all(isinstance(x, float) for x in actual.embedding.dense_embedding))
+
+
+@unittest.skipIf(
+    not VERTEX_AI_AVAILABLE, "Vertex AI dependencies not available")
+class VertexAITextEmbeddingsTest(unittest.TestCase):
+  def setUp(self) -> None:
+    self.artifact_location = tempfile.mkdtemp(prefix='vertex_ai_')
+    self.test_chunks = [
+        Chunk(
+            content=Content(text="This is a test sentence."),
+            id="1",
+            metadata={
+                "source": "test.txt", "language": "en"
+            }),
+        Chunk(
+            content=Content(text="Another example."),
+            id="2",
+            metadata={
+                "source": "test.txt", "language": "en"
+            })
+    ]
+
+  def tearDown(self) -> None:
+    shutil.rmtree(self.artifact_location)  # remove the setUp temp dir
+
+  def test_embedding_pipeline(self) -> None:
+    # gecko@002 is expected to yield 768 dims; values are not compared
+    expected = [
+        Chunk(
+            id="1",
+            embedding=Embedding(dense_embedding=[0.0] * 768),
+            metadata={
+                "source": "test.txt", "language": "en"
+            },
+            content=Content(text="This is a test sentence.")),
+        Chunk(
+            id="2",
+            embedding=Embedding(dense_embedding=[0.0] * 768),
+            metadata={
+                "source": "test.txt", "language": "en"
+            },
+            content=Content(text="Another example."))
+    ]
+
+    embedder = VertexAITextEmbeddings(model_name="textembedding-gecko@002")
+
+    with TestPipeline() as p:
+      embeddings = (
+          p
+          | beam.Create(self.test_chunks)
+          | MLTransform(write_artifact_location=self.artifact_location).
+          with_transform(embedder))
+
+      assert_that(
+          embeddings, equal_to(expected, equals_fn=chunk_approximately_equals))
+
+
+if __name__ == '__main__':
+  unittest.main()