Skip to content

Commit

Permalink
code
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejklimek committed Dec 27, 2024
1 parent 64a56c0 commit 9793347
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,19 +143,30 @@ async def search(self, query: str, config: SearchConfig | None = None) -> Sequen
@traceable
async def ingest(
    self,
    documents: str | Sequence[DocumentMeta | Document | Source],
    document_processor: BaseProvider | None = None,
) -> None:
    """Ingest documents into the search index.

    Args:
        documents: Either:
            - A URI string (e.g., "gcs://bucket/*") to specify source location(s), or
            - A sequence of documents or metadata of the documents to ingest
            URI format depends on the source type, for example:
            - "file:///path/to/files/*.txt"
            - "gcs://bucket/folder/*"
            - "huggingface://dataset/split/row"
        document_processor: The document processor to use. If not provided, the document processor will be
            determined based on the document metadata.

    Raises:
        ValueError: If ``documents`` is a URI string that is malformed or uses
            an unregistered protocol (raised by ``SourceResolver.resolve``).
    """
    if isinstance(documents, str):
        import inspect

        from ragbits.document_search.documents.source_resolver import SourceResolver

        sources = SourceResolver.resolve(documents)
        # Some handlers delegate to an ``async def`` lister (GCSSource.from_uri
        # returns the result of its async ``list_sources``), so the resolver can
        # hand back an awaitable instead of a ready sequence. Await it here,
        # otherwise a coroutine object would be passed on as the source list.
        if inspect.isawaitable(sources):
            sources = await sources
    else:
        sources = documents

    elements = await self.processing_strategy.process_documents(
        sources, self.document_processor_router, document_processor
    )
    # Drop previously indexed entries for the same sources before inserting the new ones.
    await self._remove_entries_with_same_sources(elements)
    await self.insert_elements(elements)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from collections.abc import Sequence
from typing import ClassVar

from ragbits.document_search.documents.sources import Source


class SourceResolver:
    """Registry for source URI protocols and their handlers.

    This class provides a mechanism to register and resolve different source
    protocols (like 'file://', 'gcs://', etc.) to their corresponding Source
    implementations.

    Example:
        >>> SourceResolver.register_protocol("gcs", GCSSource)
        >>> sources = SourceResolver.resolve("gcs://my-bucket/path/to/files/*")
    """

    # Shared, class-level mapping: protocol name -> handler class.
    _protocol_handlers: ClassVar["dict[str, type[Source]]"] = {}

    @classmethod
    def register_protocol(cls, protocol: str, source_class: "type[Source]") -> None:
        """Register a source class for a specific protocol.

        Registering the same protocol twice replaces the earlier handler.

        Args:
            protocol: The protocol identifier (e.g., 'file', 'gcs', 's3')
            source_class: The Source subclass that handles this protocol
        """
        cls._protocol_handlers[protocol] = source_class

    @classmethod
    def resolve(cls, uri: str) -> "Sequence[Source]":
        """Resolve a URI into a sequence of Source objects.

        The URI format should be: protocol://path
        For example:
            - file:///path/to/files/*
            - gcs://bucket/prefix/*
            - s3://bucket/prefix/*

        Only the first "://" is treated as the separator; everything after it
        is passed unchanged to the handler's ``from_uri``.

        Args:
            uri: The URI to resolve

        Returns:
            Whatever the registered handler's ``from_uri`` returns for the path.

        Raises:
            ValueError: If the URI format is invalid or the protocol is not supported
        """
        # partition() instead of split()+unpack: a missing separator is detected
        # explicitly, so the error is not chained to an unrelated unpacking failure.
        protocol, separator, path = uri.partition("://")
        if not separator:
            raise ValueError(f"Invalid URI format: {uri}. Expected format: protocol://path")

        if protocol not in cls._protocol_handlers:
            supported = ", ".join(sorted(cls._protocol_handlers))
            raise ValueError(
                f"Unsupported protocol: {protocol}. "
                f"Supported protocols are: {supported}"
            )

        handler_class = cls._protocol_handlers[protocol]
        return handler_class.from_uri(path)
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import tempfile
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import suppress
from pathlib import Path
from typing import Any, ClassVar
Expand Down Expand Up @@ -29,6 +30,7 @@ class Source(BaseModel, ABC):

# Registry of all subclasses by their unique identifier
_registry: ClassVar[dict[str, type["Source"]]] = {}
protocol: ClassVar[str | None] = None

@classmethod
def class_identifier(cls) -> str:
Expand Down Expand Up @@ -64,9 +66,34 @@ async def fetch(self) -> Path:
"""

@classmethod
@abstractmethod
def from_uri(cls, path: str) -> Sequence["Source"]:
    """Build the Source instances described by the path portion of a URI.

    Subclasses decide whether, and how, wildcard patterns in ``path`` expand
    to several sources. As described by the implementations themselves:
        - LocalFileSource: glob patterns ('*', '**', etc.) via Path.glob
        - GCSSource: simple prefix matching with a single '*' at the end of the path
        - HuggingFaceSource: no pattern support

    Args:
        path: Everything after ``protocol://`` in the URI. Which patterns are
            accepted depends on the concrete source type.

    Returns:
        A sequence of Source objects matching ``path``.

    Raises:
        ValueError: If ``path`` uses a pattern the source type does not support.
    """

@classmethod
def __init_subclass__(cls, **kwargs: Any) -> None:
    """Record every new Source subclass in the registries.

    Each subclass is stored in ``Source._registry`` under its class
    identifier. Subclasses that declare a non-None ``protocol`` are also
    registered with ``SourceResolver`` so URIs using that protocol resolve
    to them.

    Args:
        **kwargs: Forwarded unchanged to the parent ``__init_subclass__``.
    """
    super().__init_subclass__(**kwargs)
    Source._registry[cls.class_identifier()] = cls

    scheme = cls.protocol
    if scheme is None:
        # No URI protocol declared: nothing to register with the resolver.
        return

    # Imported here rather than at module level: source_resolver itself
    # imports Source, so a top-level import would be circular.
    from ragbits.document_search.documents.source_resolver import SourceResolver

    SourceResolver.register_protocol(scheme, cls)


class SourceDiscriminator:
Expand Down Expand Up @@ -144,6 +171,27 @@ def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSou
"""
return [cls(path=file_path) for file_path in path.glob(file_pattern)]

@classmethod
def from_uri(cls, path: str) -> Sequence["LocalFileSource"]:
    """Create LocalFileSource instances from a URI path.

    Supports glob patterns via Path.glob:
        - '*' matches any number of characters except path separators
        - '**' matches any number of directories, including none
        - '?' matches exactly one character

    Wildcards may appear in any path component, not only the file name.
    The path is split at the first component containing '*' or '?': the
    components before it form the directory to search, and the rest is the
    glob pattern evaluated relative to that directory.

    Args:
        path: The path part of the URI (after file://). Can contain glob patterns.

    Returns:
        A sequence of LocalFileSource objects matching the pattern, or a
        single-element list for a path without wildcards (the file is not
        checked for existence here).
    """
    path_obj = Path(path)
    parts = path_obj.parts
    first_wildcard = next(
        (index for index, part in enumerate(parts) if "*" in part or "?" in part),
        None,
    )
    if first_wildcard is None:
        return [cls(path=path_obj)]

    # Globbing only the last component under ``path_obj.parent`` would leave
    # wildcards such as '**' inside the base directory, where Path.glob never
    # expands them. Anchor the search at the last literal directory instead.
    base_dir = Path(*parts[:first_wildcard])
    pattern = "/".join(parts[first_wildcard:])
    return cls.list_sources(base_dir, pattern)


class GCSSource(Source):
"""
Expand All @@ -152,6 +200,7 @@ class GCSSource(Source):

bucket: str
object_name: str
protocol: ClassVar[str] = "gcs"

@property
def id(self) -> str:
Expand Down Expand Up @@ -217,6 +266,47 @@ async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
sources.append(cls(bucket=bucket, object_name=obj["name"]))
return sources

@classmethod
def from_uri(cls, path: str) -> Sequence["GCSSource"]:
    """Create GCSSource instances from a URI path.

    Supports simple prefix matching with a single '*' at the end of path.
    For example:
        - "bucket/folder/*" - matches all files in the folder
        - "bucket/folder/prefix*" - matches all files starting with prefix
    More complex patterns like '**' or '?' are not supported, and the bucket
    segment itself must be a literal name.

    Args:
        path: The path part of the URI (after gcs://). Can end with '*' for pattern matching.

    Returns:
        A single-element list for a literal object path. For a wildcard path,
        the result of ``cls.list_sources`` for the bucket and prefix.

    Raises:
        ValueError: If an unsupported pattern is used or the bucket name is
            empty or contains a wildcard.
    """
    if "**" in path or "?" in path:
        raise ValueError(
            "GCSSource only supports '*' at the end of path. "
            "Patterns like '**' or '?' are not supported."
        )

    # Split into bucket and prefix; a path without '/' is a bare bucket name.
    bucket, _, prefix = path.partition("/")
    if not bucket or "*" in bucket:
        raise ValueError(
            "GCSSource requires a literal bucket name before the first '/'. "
            f"Invalid bucket: {bucket!r}"
        )

    if "*" in prefix:
        # Exactly one '*' is allowed and it must be the final character;
        # otherwise a literal '*' would end up in the listing prefix.
        if prefix.index("*") != len(prefix) - 1:
            raise ValueError(
                "GCSSource only supports '*' at the end of path. "
                f"Invalid pattern: {prefix}"
            )
        # Remove the trailing * for GCS prefix listing
        prefix = prefix[:-1]
        # NOTE(review): list_sources is declared ``async def`` on this class, so
        # this branch appears to return an awaitable rather than a ready
        # sequence; callers need to await it. Confirm and align the signature.
        return cls.list_sources(bucket=bucket, prefix=prefix)

    return [cls(bucket=bucket, object_name=prefix)]


class HuggingFaceSource(Source):
"""
Expand All @@ -226,6 +316,7 @@ class HuggingFaceSource(Source):
path: str
split: str = "train"
row: int
protocol: ClassVar[str] = "huggingface"

@property
def id(self) -> str:
Expand Down Expand Up @@ -273,6 +364,37 @@ async def fetch(self) -> Path:

return path

@classmethod
def from_uri(cls, path: str) -> Sequence["HuggingFaceSource"]:
    """Create HuggingFaceSource instances from a URI path.

    Pattern matching is not supported. The path must be in the format:
        huggingface://dataset_path/split/row

    The last two segments are taken as the split and the row; everything
    before them is the dataset path, so namespaced dataset ids such as
    "org/dataset/train/0" are accepted.

    Args:
        path: The path part of the URI (after huggingface://)

    Returns:
        A sequence containing a single HuggingFaceSource

    Raises:
        ValueError: If the path contains patterns or has invalid format
    """
    if "*" in path or "?" in path:
        raise ValueError(
            "HuggingFaceSource does not support patterns. "
            "Path must be in format: dataset_path/split/row"
        )

    try:
        # Split from the right so '/' inside the dataset path is preserved.
        dataset_path, split, row = path.rsplit("/", 2)
        row_index = int(row)
    except ValueError as err:
        raise ValueError(
            "Invalid HuggingFace path format. "
            "Expected: dataset_path/split/row"
        ) from err

    if not dataset_path or not split:
        raise ValueError(
            "Invalid HuggingFace path format. "
            "Expected: dataset_path/split/row"
        )

    return [cls(path=dataset_path, split=split, row=row_index)]


def get_local_storage_dir() -> Path:
"""
Expand Down

0 comments on commit 9793347

Please sign in to comment.