Skip to content

Commit

Permalink
Add tests for document ingestion from URIs:
Browse files Browse the repository at this point in the history
  - Basic file URI ingestion
  - Wildcard pattern matching
  • Loading branch information
maciejklimek committed Dec 27, 2024
1 parent 9793347 commit 6be3c3f
Show file tree
Hide file tree
Showing 4 changed files with 344 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def ingest(
"""
if isinstance(documents, str):
from ragbits.document_search.documents.source_resolver import SourceResolver
sources = SourceResolver.resolve(documents)
sources = await SourceResolver.resolve(documents)
else:
sources = documents

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class SourceResolver:
Example:
>>> SourceResolver.register_protocol("gcs", GCSSource)
>>> sources = SourceResolver.resolve("gcs://my-bucket/path/to/files/*")
>>> sources = await SourceResolver.resolve("gcs://my-bucket/path/to/files/*")
"""
_protocol_handlers: ClassVar[dict[str, type[Source]]] = {}

Expand All @@ -27,14 +27,13 @@ def register_protocol(cls, protocol: str, source_class: type[Source]) -> None:
cls._protocol_handlers[protocol] = source_class

@classmethod
def resolve(cls, uri: str) -> Sequence[Source]:
async def resolve(cls, uri: str) -> Sequence[Source]:
"""Resolve a URI into a sequence of Source objects.
The URI format should be: protocol://path
For example:
- file:///path/to/files/*
- gcs://bucket/prefix/*
- s3://bucket/prefix/*
Args:
uri: The URI to resolve
Expand All @@ -58,4 +57,4 @@ def resolve(cls, uri: str) -> Sequence[Source]:
)

handler_class = cls._protocol_handlers[protocol]
return handler_class.from_uri(path)
return await handler_class.from_uri(path)
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def fetch(self) -> Path:

@classmethod
@abstractmethod
def from_uri(cls, path: str) -> Sequence["Source"]:
async def from_uri(cls, path: str) -> Sequence["Source"]:
"""Create Source instances from a URI path.
The path can contain glob patterns (asterisks) to match multiple sources, but pattern support
Expand Down Expand Up @@ -132,6 +132,7 @@ class LocalFileSource(Source):
"""

path: Path
protocol: ClassVar[str] = "file"

@property
def id(self) -> str:
Expand Down Expand Up @@ -172,25 +173,34 @@ def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSou
return [cls(path=file_path) for file_path in path.glob(file_pattern)]

@classmethod
def from_uri(cls, path: str) -> Sequence["LocalFileSource"]:
async def from_uri(cls, path: str) -> Sequence["LocalFileSource"]:
"""Create LocalFileSource instances from a URI path.
Supports full glob patterns via Path.glob:
- '*' matches any number of characters except path separators
- '**' matches any number of characters including path separators
- "**/*.txt" - all .txt files in any subdirectory
- "*.py" - all Python files in the current directory
- "**/*" - all files in any subdirectory
- '?' matches exactly one character
Args:
path: The path part of the URI (after file://). Can contain glob patterns.
path: The path part of the URI (after file://)
Returns:
A sequence of LocalFileSource objects matching the pattern
A sequence of LocalFileSource objects
"""
path_obj = Path(path)
if "*" in path or "?" in path:
# If path contains wildcards, use list_sources with the parent directory
return cls.list_sources(path_obj.parent, path_obj.name)
return [cls(path=path_obj)]
# Handle absolute paths
path = Path(path)
if not path.is_absolute():
# For relative paths, use current directory as base
path = Path.cwd() / path

if "*" in str(path):
# If path contains wildcards, use its parent as base
base_path = path.parent
pattern = path.name
return [cls(path=file_path) for file_path in base_path.glob(pattern)]

return [cls(path=path)]


class GCSSource(Source):
Expand All @@ -201,6 +211,28 @@ class GCSSource(Source):
bucket: str
object_name: str
protocol: ClassVar[str] = "gcs"
_storage: Any | None = None # Storage client for dependency injection

@classmethod
def set_storage(cls, storage: Any) -> None:
"""Set the storage client for all instances.
Args:
storage: The storage client to use
"""
cls._storage = storage

async def _get_storage(self) -> Any:
"""Get the storage client.
Returns:
The storage client to use. If none was injected, creates a new one.
"""
if self._storage is not None:
return self._storage

from gcloud.aio.storage import Storage
return Storage()

@property
def id(self) -> str:
Expand Down Expand Up @@ -234,8 +266,8 @@ async def fetch(self) -> Path:
path = bucket_local_dir / self.object_name

if not path.is_file():
async with Storage() as client: # type: ignore
# TODO: Add error handling for download
storage = await self._get_storage()
async with storage as client:
content = await client.download(self.bucket, self.object_name)
Path(bucket_local_dir / self.object_name).parent.mkdir(parents=True, exist_ok=True)
with open(path, mode="wb+") as file_object:
Expand All @@ -259,15 +291,18 @@ async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
Raises:
ImportError: If the required 'gcloud-aio-storage' package is not installed
"""
async with Storage() as client:
# Create a temporary instance just to get the storage client
temp_source = cls(bucket=bucket, object_name=prefix)
storage = await temp_source._get_storage()
async with storage as client:
objects = await client.list_objects(bucket, params={"prefix": prefix})
sources = []
for obj in objects["items"]:
sources.append(cls(bucket=bucket, object_name=obj["name"]))
return sources

@classmethod
def from_uri(cls, path: str) -> Sequence["GCSSource"]:
async def from_uri(cls, path: str) -> Sequence["GCSSource"]:
"""Create GCSSource instances from a URI path.
Supports simple prefix matching with '*' at the end of path.
Expand Down Expand Up @@ -303,7 +338,7 @@ def from_uri(cls, path: str) -> Sequence["GCSSource"]:
)
# Remove the trailing * for GCS prefix listing
prefix = prefix[:-1]
return cls.list_sources(bucket=bucket, prefix=prefix)
return await cls.list_sources(bucket=bucket, prefix=prefix)

return [cls(bucket=bucket, object_name=prefix)]

Expand Down Expand Up @@ -365,7 +400,7 @@ async def fetch(self) -> Path:
return path

@classmethod
def from_uri(cls, path: str) -> Sequence["HuggingFaceSource"]:
async def from_uri(cls, path: str) -> Sequence["HuggingFaceSource"]:
"""Create HuggingFaceSource instances from a URI path.
Pattern matching is not supported. The path must be in the format:
Expand Down
Loading

0 comments on commit 6be3c3f

Please sign in to comment.