Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add langchain dep to ML tests. #33607

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 34 additions & 20 deletions sdks/python/apache_beam/ml/rag/chunking/langchain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,15 @@

"""Tests for apache_beam.ml.rag.chunking.langchain."""

import functools
import unittest

import apache_beam as beam
from apache_beam.ml.rag.types import Chunk
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import BeamAssertException
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.testing.util import is_not_empty

try:
from apache_beam.ml.rag.chunking.langchain import LangChainChunker
Expand All @@ -41,13 +43,10 @@
TRANSFORMERS_AVAILABLE = False


def chunk_equals(expected, actual):
"""Custom equality function for Chunk objects."""
if not isinstance(expected, Chunk) or not isinstance(actual, Chunk):
return False
return (
expected.content == actual.content and expected.index == actual.index and
expected.metadata == actual.metadata)
def assert_true(elements, assert_fn, error_message_fn):
if not assert_fn(elements):
raise BeamAssertException(error_message_fn(elements))
return True


@unittest.skipIf(not LANGCHAIN_AVAILABLE, 'langchain is not installed.')
Expand Down Expand Up @@ -83,9 +82,15 @@ def test_no_metadata_fields(self):
| provider.get_ptransform_for_processing())
chunks_count = chunks | beam.combiners.Count.Globally()

assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks')
assert_that(chunks_count, is_not_empty(), 'Has chunks')

assert_that(chunks, lambda x: all(c.metadata == {} for c in x))
assert_that(
chunks,
functools.partial(
assert_true,
assert_fn=lambda x: (all(c.metadata == {} for c in x)),
error_message_fn=lambda x: f"Expected empty metadata, actual {x}")
)

def test_multiple_metadata_fields(self):
"""Test chunking with multiple metadata fields."""
Expand All @@ -94,6 +99,7 @@ def test_multiple_metadata_fields(self):
document_field='content',
metadata_fields=['source', 'language'],
text_splitter=splitter)
expected_metadata = {'source': 'simple.txt', 'language': 'en'}

with TestPipeline() as p:
chunks = (
Expand All @@ -102,18 +108,20 @@ def test_multiple_metadata_fields(self):
| provider.get_ptransform_for_processing())
chunks_count = chunks | beam.combiners.Count.Globally()

assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks')
assert_that(chunks_count, is_not_empty(), 'Has chunks')
assert_that(
chunks,
lambda x: all(
c.metadata == {
'source': 'simple.txt', 'language': 'en'
} for c in x))
functools.partial(
assert_true,
assert_fn=lambda x: all(
c.metadata == expected_metadata for c in x),
error_message_fn=lambda x:
f"Expected metadata {expected_metadata}, actual {x}"))

def test_recursive_splitter_no_overlap(self):
"""Test RecursiveCharacterTextSplitter with no overlap."""
splitter = RecursiveCharacterTextSplitter(
chunk_size=30, chunk_overlap=0, separators=[". "])
chunk_size=30, chunk_overlap=0, separators=[".", " "])
provider = LangChainChunker(
document_field='content',
metadata_fields=['source'],
Expand All @@ -126,8 +134,14 @@ def test_recursive_splitter_no_overlap(self):
| provider.get_ptransform_for_processing())
chunks_count = chunks | beam.combiners.Count.Globally()

assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks')
assert_that(chunks, lambda x: all(len(c.content.text) <= 30 for c in x))
assert_that(chunks_count, is_not_empty(), 'Has chunks')
assert_that(
chunks,
functools.partial(
assert_true,
assert_fn=lambda x: all(len(c.content.text) <= 30 for c in x),
error_message_fn=lambda x: f"Expected len(chunk) <= 30, \
actual {[len(c.content.text) for c in x]}"))

@unittest.skipIf(not TRANSFORMERS_AVAILABLE, "transformers not available")
def test_huggingface_tokenizer_splitter(self):
Expand Down Expand Up @@ -155,13 +169,13 @@ def check_token_lengths(chunks):
# Verify each chunk's token length is within limits
num_tokens = len(tokenizer.encode(chunk.content.text))
if not num_tokens <= 10:
raise AssertionError(
raise BeamAssertException(
f"Chunk has {num_tokens} tokens, expected <= 10")
return True

chunks_count = chunks | beam.combiners.Count.Globally()

assert_that(chunks_count, lambda x: x[0] > 0, 'Has chunks')
assert_that(chunks_count, is_not_empty(), 'Has chunks')
assert_that(chunks, check_token_lengths)

def test_invalid_document_field(self):
Expand Down
2 changes: 2 additions & 0 deletions sdks/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ def get_portability_package_data():
'ml_test': [
'datatable',
'embeddings',
'langchain',
'onnxruntime',
'sentence-transformers',
'skl2onnx',
Expand All @@ -505,6 +506,7 @@ def get_portability_package_data():
'datatable',
'embeddings',
'onnxruntime',
'langchain',
'sentence-transformers',
'skl2onnx',
'pillow',
Expand Down
Loading