diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py
index e48410bb2..6602759e0 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py
@@ -161,7 +161,7 @@ async def ingest(
         """
         if isinstance(documents, str):
             from ragbits.document_search.documents.source_resolver import SourceResolver
-            sources = SourceResolver.resolve(documents)
+            sources = await SourceResolver.resolve(documents)
         else:
             sources = documents
diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/source_resolver.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/source_resolver.py
index 36c2a2097..106642c62 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/documents/source_resolver.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/source_resolver.py
@@ -12,7 +12,7 @@ class SourceResolver:
 
     Example:
         >>> SourceResolver.register_protocol("gcs", GCSSource)
-        >>> sources = SourceResolver.resolve("gcs://my-bucket/path/to/files/*")
+        >>> sources = await SourceResolver.resolve("gcs://my-bucket/path/to/files/*")
     """
 
     _protocol_handlers: ClassVar[dict[str, type[Source]]] = {}
@@ -27,14 +27,13 @@ def register_protocol(cls, protocol: str, source_class: type[Source]) -> None:
         cls._protocol_handlers[protocol] = source_class
 
     @classmethod
-    def resolve(cls, uri: str) -> Sequence[Source]:
+    async def resolve(cls, uri: str) -> Sequence[Source]:
         """Resolve a URI into a sequence of Source objects.
 
         The URI format should be: protocol://path
         For example:
         - file:///path/to/files/*
         - gcs://bucket/prefix/*
-        - s3://bucket/prefix/*
 
         Args:
             uri: The URI to resolve
@@ -58,4 +57,4 @@ def resolve(cls, uri: str) -> Sequence[Source]:
             )
 
         handler_class = cls._protocol_handlers[protocol]
-        return handler_class.from_uri(path)
+        return await handler_class.from_uri(path)
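
For reference, this is how the resolver is driven after these changes; a minimal sketch assuming the package is installed, with the protocol, bucket, and pattern taken from the docstring example above:

    import asyncio

    from ragbits.document_search.documents.source_resolver import SourceResolver
    from ragbits.document_search.documents.sources import GCSSource

    async def main() -> None:
        # Protocols are keyed by URI scheme; "gcs" is handled by GCSSource.
        SourceResolver.register_protocol("gcs", GCSSource)
        # resolve() is now a coroutine and must be awaited.
        sources = await SourceResolver.resolve("gcs://my-bucket/path/to/files/*")
        print([type(source).__name__ for source in sources])

    asyncio.run(main())
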
diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py
index da540981c..39d27ec35 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py
@@ -67,7 +67,7 @@ async def fetch(self) -> Path:
 
     @classmethod
     @abstractmethod
-    def from_uri(cls, path: str) -> Sequence["Source"]:
+    async def from_uri(cls, path: str) -> Sequence["Source"]:
         """Create Source instances from a URI path.
 
         The path can contain glob patterns (asterisks) to match multiple sources, but pattern support
@@ -132,6 +132,7 @@ class LocalFileSource(Source):
     """
 
     path: Path
+    protocol: ClassVar[str] = "file"
 
     @property
     def id(self) -> str:
@@ -172,25 +173,37 @@ def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSource"]:
         return [cls(path=file_path) for file_path in path.glob(file_pattern)]
 
     @classmethod
-    def from_uri(cls, path: str) -> Sequence["LocalFileSource"]:
+    async def from_uri(cls, path: str) -> Sequence["LocalFileSource"]:
         """Create LocalFileSource instances from a URI path.
 
         Supports full glob patterns via Path.glob:
-        - '*' matches any number of characters except path separators
-        - '**' matches any number of characters including path separators
+        - "**/*.txt" - all .txt files in any subdirectory
+        - "*.py" - all Python files in the current directory
+        - "**/*" - all files in any subdirectory
         - '?' matches exactly one character
 
         Args:
-            path: The path part of the URI (after file://). Can contain glob patterns.
+            path: The path part of the URI (after file://)
 
         Returns:
-            A sequence of LocalFileSource objects matching the pattern
+            A sequence of LocalFileSource objects
         """
-        path_obj = Path(path)
-        if "*" in path or "?" in path:
-            # If path contains wildcards, use list_sources with the parent directory
-            return cls.list_sources(path_obj.parent, path_obj.name)
-        return [cls(path=path_obj)]
+        path_obj = Path(path)
+        if not path_obj.is_absolute():
+            # Resolve relative paths against the current working directory
+            path_obj = Path.cwd() / path_obj
+
+        if "*" in str(path_obj) or "?" in str(path_obj):
+            # Glob from the deepest directory that contains no wildcard, so
+            # patterns such as "**/*.txt" can match across subdirectories
+            parts = path_obj.parts
+            first_wildcard = next(i for i, part in enumerate(parts) if "*" in part or "?" in part)
+            base_path = Path(*parts[:first_wildcard])
+            pattern = str(Path(*parts[first_wildcard:]))
+            return [cls(path=file_path) for file_path in base_path.glob(pattern)]
+
+        return [cls(path=path_obj)]
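
A quick usage sketch of the glob resolution above (the directory and files are illustrative and assumed to exist):

    import asyncio

    from ragbits.document_search.documents.sources import LocalFileSource

    async def main() -> None:
        # One LocalFileSource per matching file; "**" recurses into subdirectories.
        sources = await LocalFileSource.from_uri("/data/docs/**/*.txt")
        for source in sources:
            print(source.path)

    asyncio.run(main())
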
@@ -201,6 +211,28 @@ class GCSSource(Source):
 
     bucket: str
     object_name: str
     protocol: ClassVar[str] = "gcs"
+    _storage: Any | None = None  # Storage client for dependency injection
+
+    @classmethod
+    def set_storage(cls, storage: Any) -> None:
+        """Set the storage client for all instances.
+
+        Args:
+            storage: The storage client to use
+        """
+        cls._storage = storage
+
+    async def _get_storage(self) -> Any:
+        """Get the storage client.
+
+        Returns:
+            The storage client to use. If none was injected, creates a new one.
+        """
+        if self._storage is not None:
+            return self._storage
+
+        from gcloud.aio.storage import Storage
+        return Storage()
 
     @property
     def id(self) -> str:
@@ -234,8 +266,8 @@ async def fetch(self) -> Path:
         path = bucket_local_dir / self.object_name
 
         if not path.is_file():
-            async with Storage() as client:  # type: ignore
-                # TODO: Add error handling for download
+            storage = await self._get_storage()
+            async with storage as client:
                 content = await client.download(self.bucket, self.object_name)
                 Path(bucket_local_dir / self.object_name).parent.mkdir(parents=True, exist_ok=True)
                 with open(path, mode="wb+") as file_object:
@@ -259,7 +291,10 @@ async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
         Raises:
             ImportError: If the required 'gcloud-aio-storage' package is not installed
         """
-        async with Storage() as client:
+        # Create a temporary instance just to get the storage client
+        temp_source = cls(bucket=bucket, object_name=prefix)
+        storage = await temp_source._get_storage()
+        async with storage as client:
             objects = await client.list_objects(bucket, params={"prefix": prefix})
             sources = []
             for obj in objects["items"]:
@@ -267,7 +302,7 @@ async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
         return sources
 
     @classmethod
-    def from_uri(cls, path: str) -> Sequence["GCSSource"]:
+    async def from_uri(cls, path: str) -> Sequence["GCSSource"]:
         """Create GCSSource instances from a URI path.
 
         Supports simple prefix matching with '*' at the end of path.
@@ -303,7 +338,7 @@ def from_uri(cls, path: str) -> Sequence["GCSSource"]:
             )
             # Remove the trailing * for GCS prefix listing
             prefix = prefix[:-1]
-            return cls.list_sources(bucket=bucket, prefix=prefix)
+            return await cls.list_sources(bucket=bucket, prefix=prefix)
 
         return [cls(bucket=bucket, object_name=prefix)]
@@ -365,7 +400,7 @@ async def fetch(self) -> Path:
         return path
 
     @classmethod
-    def from_uri(cls, path: str) -> Sequence["HuggingFaceSource"]:
+    async def from_uri(cls, path: str) -> Sequence["HuggingFaceSource"]:
         """Create HuggingFaceSource instances from a URI path.
 
         Pattern matching is not supported. The path must be in the format:
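
The set_storage hook above is exercised with mocks in the tests below; presumably it can also inject a preconfigured client in application code. A sketch under that assumption (the service file, bucket, and object name are illustrative, and LOCAL_STORAGE_DIR must point at a writable directory):

    import asyncio

    from gcloud.aio.storage import Storage

    from ragbits.document_search.documents.sources import GCSSource

    async def main() -> None:
        # Inject a client authenticated with an explicit service account file.
        GCSSource.set_storage(Storage(service_file="service-account.json"))
        source = GCSSource(bucket="my-bucket", object_name="docs/report.pdf")
        local_path = await source.fetch()  # Downloads into LOCAL_STORAGE_DIR
        print(local_path)

    asyncio.run(main())
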
diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py
index 181b420be..81d9698c8 100644
--- a/packages/ragbits-document-search/tests/unit/test_document_search.py
+++ b/packages/ragbits-document-search/tests/unit/test_document_search.py
@@ -1,6 +1,8 @@
+import os
 import tempfile
 from collections.abc import Callable
 from pathlib import Path
+from unittest import mock
 from unittest.mock import AsyncMock
 
 import pytest
@@ -8,23 +10,48 @@
 from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
 from ragbits.document_search import DocumentSearch
 from ragbits.document_search._main import SearchConfig
-from ragbits.document_search.documents.document import Document, DocumentMeta, DocumentType
+from ragbits.document_search.documents.document import (
+    Document,
+    DocumentMeta,
+    DocumentType,
+)
 from ragbits.document_search.documents.element import TextElement
-from ragbits.document_search.documents.sources import LocalFileSource
+from ragbits.document_search.documents.sources import (
+    GCSSource,
+    HuggingFaceSource,
+    LocalFileSource,
+)
 from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
-from ragbits.document_search.ingestion.processor_strategies.batched import BatchedAsyncProcessing
+from ragbits.document_search.ingestion.processor_strategies.batched import (
+    BatchedAsyncProcessing,
+)
 from ragbits.document_search.ingestion.providers import BaseProvider
 from ragbits.document_search.ingestion.providers.dummy import DummyProvider
 
 CONFIG = {
     "embedder": {"type": "NoopEmbeddings"},
-    "vector_store": {"type": "ragbits.core.vector_stores.in_memory:InMemoryVectorStore"},
+    "vector_store": {
+        "type": "ragbits.core.vector_stores.in_memory:InMemoryVectorStore"
+    },
     "reranker": {"type": "NoopReranker"},
     "providers": {"txt": {"type": "DummyProvider"}},
     "processing_strategy": {"type": "SequentialProcessing"},
 }
 
 
+# This fixture is used automatically for every test due to autouse=True.
+# It ensures source protocols are registered before any test runs.
+@pytest.fixture(autouse=True)
+def setup_sources():
+    # Import sources to ensure protocols are registered
+    from ragbits.document_search.documents.sources import (
+        GCSSource,
+        HuggingFaceSource,
+        LocalFileSource,
+    )
+    yield
+
+
 @pytest.mark.parametrize(
     ("document", "expected"),
     [
@@ -202,3 +229,260 @@ async def test_document_search_with_batched():
 
     assert len(await vectore_store.list()) == 12
     assert len(results) == 12
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_uri_basic():
+    # Setup
+    with tempfile.TemporaryDirectory() as temp_dir:
+        test_file = Path(temp_dir) / "test.txt"
+        test_file.write_text("Test content")
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Test ingesting from URI
+        await document_search.ingest(f"file://{test_file}")
+
+        # Verify
+        results = await document_search.search("Test content")
+        assert len(results) == 1
+        assert results[0].content == "Test content"
+        assert isinstance(results[0].document_meta.source, LocalFileSource)
+        assert str(results[0].document_meta.source.path) == str(test_file)
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_uri_with_wildcard():
+    # Setup
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Create multiple test files
+        test_files = [
+            (Path(temp_dir) / "test1.txt", "First test content"),
+            (Path(temp_dir) / "test2.txt", "Second test content"),
+            (Path(temp_dir) / "other.txt", "Other content"),
+        ]
+        for path, content in test_files:
+            path.write_text(content)
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Test ingesting from URI with wildcard
+        await document_search.ingest(f"file://{temp_dir}/test*.txt")
+
+        # Verify only matching files were ingested
+        results = await document_search.search("test content")
+        assert len(results) == 2
+
+        contents = {result.content for result in results}
+        assert contents == {"First test content", "Second test content"}
+
+        # Verify sources are correct
+        sources = {str(result.document_meta.source.path) for result in results}
+        expected_sources = {str(test_files[0][0]), str(test_files[1][0])}
+        assert sources == expected_sources
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_gcs_uri_basic():
+    # Create mock storage client
+    storage_mock = mock.AsyncMock()
+    storage_mock.download = mock.AsyncMock(return_value=b"GCS test content")
+    storage_mock.list_objects = mock.AsyncMock(
+        return_value={
+            "items": [{"name": "folder/test1.txt"}, {"name": "folder/test2.txt"}]
+        }
+    )
+    storage_mock.__aenter__ = mock.AsyncMock(return_value=storage_mock)
+    storage_mock.__aexit__ = mock.AsyncMock()
+
+    # Create mock storage factory
+    mock_storage = mock.Mock()
+    mock_storage.return_value = storage_mock
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Set up local storage dir
+        os.environ["LOCAL_STORAGE_DIR"] = temp_dir
+
+        # Inject the mock storage
+        GCSSource.set_storage(mock_storage())
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Test single file
+        await document_search.ingest("gcs://test-bucket/folder/test1.txt")
+        results = await document_search.search("GCS test content")
+        assert len(results) == 1
+        assert isinstance(results[0].document_meta.source, GCSSource)
+        assert results[0].document_meta.source.bucket == "test-bucket"
+        assert results[0].document_meta.source.object_name == "folder/test1.txt"
+
+        # Clean up
+        GCSSource.set_storage(None)
+        del os.environ["LOCAL_STORAGE_DIR"]
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_gcs_uri_with_wildcard():
+    # Create mock storage client
+    storage_mock = mock.AsyncMock()
+    storage_mock.download = mock.AsyncMock(
+        side_effect=[b"GCS test content 1", b"GCS test content 2"]
+    )
+    storage_mock.list_objects = mock.AsyncMock(
+        return_value={
+            "items": [{"name": "folder/test1.txt"}, {"name": "folder/test2.txt"}]
+        }
+    )
+    storage_mock.__aenter__ = mock.AsyncMock(return_value=storage_mock)
+    storage_mock.__aexit__ = mock.AsyncMock()
+
+    # Create mock storage factory
+    mock_storage = mock.Mock()
+    mock_storage.return_value = storage_mock
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Set up local storage dir
+        os.environ["LOCAL_STORAGE_DIR"] = temp_dir
+
+        # Inject the mock storage
+        GCSSource.set_storage(mock_storage())
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Test wildcard ingestion
+        await document_search.ingest("gcs://test-bucket/folder/*")
+
+        # Verify both files were ingested
+        results = await document_search.search("GCS test content")
+        assert len(results) == 2
+
+        # Verify first file
+        assert isinstance(results[0].document_meta.source, GCSSource)
+        assert results[0].document_meta.source.bucket == "test-bucket"
+        assert results[0].document_meta.source.object_name == "folder/test1.txt"
+
+        # Verify second file
+        assert isinstance(results[1].document_meta.source, GCSSource)
+        assert results[1].document_meta.source.bucket == "test-bucket"
+        assert results[1].document_meta.source.object_name == "folder/test2.txt"
+
+        # Clean up
+        GCSSource.set_storage(None)
+        del os.environ["LOCAL_STORAGE_DIR"]
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_gcs_uri_invalid_pattern():
+    # Create mock storage client
+    storage_mock = mock.AsyncMock()
+    storage_mock.__aenter__ = mock.AsyncMock(return_value=storage_mock)
+    storage_mock.__aexit__ = mock.AsyncMock()
+
+    # Create mock storage factory
+    mock_storage = mock.Mock()
+    mock_storage.return_value = storage_mock
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Set up local storage dir
+        os.environ["LOCAL_STORAGE_DIR"] = temp_dir
+
+        # Inject the mock storage
+        GCSSource.set_storage(mock_storage())
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Test invalid patterns
+        with pytest.raises(
+            ValueError, match="GCSSource only supports '\\*' at the end of path"
+        ):
+            await document_search.ingest("gcs://test-bucket/folder/**.txt")
+
+        with pytest.raises(
+            ValueError, match="GCSSource only supports '\\*' at the end of path"
+        ):
+            await document_search.ingest("gcs://test-bucket/folder/test?.txt")
+
+        with pytest.raises(
+            ValueError, match="GCSSource only supports '\\*' at the end of path"
+        ):
+            await document_search.ingest("gcs://test-bucket/folder/test*file.txt")
+
+        # Test empty list response
+        storage_mock.list_objects = mock.AsyncMock(return_value={"items": []})
+        await document_search.ingest("gcs://test-bucket/folder/*")
+        results = await document_search.search("GCS test content")
+        assert len(results) == 0
+
+        # Clean up
+        GCSSource.set_storage(None)
+        del os.environ["LOCAL_STORAGE_DIR"]
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_huggingface_uri_basic():
+    # Create mock data
+    mock_data = [{
+        "content": "HuggingFace test content",
+        "source": "dataset_name/train/test.txt"  # Must be .txt for TextDocument
+    }]
+
+    # Create a simple dataset class that supports skip/take
+    class MockDataset:
+        def __init__(self, data):
+            self.data = data
+            self.current_index = 0
+
+        def skip(self, n):
+            self.current_index = n
+            return self
+
+        def take(self, n):
+            return self
+
+        def __iter__(self):
+            if self.current_index < len(self.data):
+                return iter(self.data[self.current_index:self.current_index + 1])
+            return iter([])
+
+    # Mock dataset loading and embeddings
+    dataset = MockDataset(mock_data)
+    embeddings_mock = AsyncMock()
+    embeddings_mock.embed_text.return_value = [[0.1, 0.1]]  # Non-zero embeddings
+
+    # Create providers dict with actual provider instance
+    providers = {
+        DocumentType.TXT: DummyProvider()
+    }
+
+    # Mock vector store to track operations
+    vector_store = InMemoryVectorStore()
+
+    # Create a temporary directory for storing test files
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Set the environment variable for local storage
+        os.environ["LOCAL_STORAGE_DIR"] = temp_dir
+        storage_dir = Path(temp_dir)
+
+        # Create the source directory and file
+        source_dir = storage_dir / "dataset_name/train"
+        source_dir.mkdir(parents=True, exist_ok=True)
+        source_file = source_dir / "test.txt"
+        with open(source_file, mode="w", encoding="utf-8") as file:
+            file.write("HuggingFace test content")
+
+        with mock.patch("ragbits.document_search.documents.sources.load_dataset", return_value=dataset), \
+                mock.patch("ragbits.document_search.documents.sources.get_local_storage_dir", return_value=storage_dir):
+            document_search = DocumentSearch(
+                embedder=embeddings_mock,
+                vector_store=vector_store,
+                document_processor_router=DocumentProcessorRouter.from_config(providers),
+            )
+
+            await document_search.ingest("huggingface://dataset_name/train/0")
+
+            results = await document_search.search("HuggingFace test content")
+            assert len(results) == 1
+            assert isinstance(results[0].document_meta.source, HuggingFaceSource)
+            assert results[0].document_meta.source.path == "dataset_name"
+            assert results[0].document_meta.source.split == "train"
+            assert results[0].document_meta.source.row == 0
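
For context, the MockDataset above imitates the streaming dataset that `datasets.load_dataset` returns, which exposes `skip` and `take` on its iterable. The real calls it stands in for look roughly like this (the dataset name is illustrative and network access is required):

    from datasets import load_dataset

    # Streaming avoids downloading the whole dataset; skip/take select one row.
    dataset = load_dataset("dataset_name", split="train", streaming=True)
    row = next(iter(dataset.skip(0).take(1)))
    print(row["content"])
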