Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(document-search): Allow to create DocumentSearch instances from config #62

27 changes: 27 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import sys

from ..utils import get_cls_from_config
from .base import Embeddings
from .litellm import LiteLLMEmbeddings
from .local import LocalEmbeddings
from .noop import NoopEmbeddings

# Public classes re-exported by this package. The base class is listed next to the
# concrete implementations so that star-imports expose it as well.
__all__ = ["Embeddings", "LiteLLMEmbeddings", "LocalEmbeddings", "NoopEmbeddings"]

# This package's own module object; get_embeddings looks embedder classes up on it by name.
module = sys.modules[__name__]


def get_embeddings(embedder_config: dict) -> Embeddings:
    """
    Initializes and returns an Embeddings object based on the provided embedder configuration.

    Args:
        embedder_config: A dictionary containing configuration details for the embedder.
            The required "type" entry names the class, either as a bare name available in
            this package or as a full path ("module.submodule:ClassName"). The optional
            "config" entry holds keyword arguments passed to the class constructor.

    Returns:
        An instance of the specified Embeddings class, initialized with the provided config
        (if any) or default arguments.

    Raises:
        KeyError: If the "type" entry is missing from the configuration.
    """
    # Resolve through the shared helper so that "module:ClassName" paths work for embedders
    # exactly as they do for vector stores, rephrasers and rerankers.
    embeddings_cls = get_cls_from_config(embedder_config["type"], module)
    config = embedder_config.get("config", {})

    return embeddings_cls(**config)
25 changes: 25 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/noop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from ragbits.core.embeddings.base import Embeddings


class NoopEmbeddings(Embeddings):
    """
    A no-op implementation of the Embeddings class.

    This class provides a simple embedding method that returns a fixed
    embedding vector for each input text. It's mainly useful for testing
    or as a placeholder when an actual embedding model is not required.
    """

    async def embed_text(self, data: list[str]) -> list[list[float]]:
        """
        Embeds a list of strings into a list of vectors.

        Args:
            data: A list of input text strings to embed.

        Returns:
            A list of embedding vectors, where each vector
            is a fixed value of [0.1, 0.1] for each input string.
            Every vector is a separate list object.
        """
        # Build a fresh list per input: `[[0.1, 0.1]] * len(data)` repeats one shared inner
        # list, so mutating any returned vector would silently change all of them.
        return [[0.1, 0.1] for _ in data]
26 changes: 26 additions & 0 deletions packages/ragbits-core/src/ragbits/core/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from importlib import import_module
from types import ModuleType
from typing import Any


def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any:
    """
    Resolves a class (or any other object) from its textual reference.

    Args:
        cls_path: Either a bare attribute name, which is looked up on the default module,
            or a full reference in the form "module.submodule:ClassName", which is imported
            from the named module.
        default_module: Module searched when cls_path carries no "module:" prefix.

    Returns:
        Any: The object found in the named module or in the default module.
    """
    # Bare name: nothing to import, look it up on the default module.
    if ":" not in cls_path:
        return getattr(default_module, cls_path)

    module_name, attribute_name = cls_path.split(":")
    return getattr(import_module(module_name), attribute_name)
28 changes: 28 additions & 0 deletions packages/ragbits-core/src/ragbits/core/vector_store/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import sys

from ..utils import get_cls_from_config
from .base import VectorStore
from .chromadb_store import ChromaDBStore
from .in_memory import InMemoryVectorStore

# Names exposed when this package is star-imported.
__all__ = ["InMemoryVectorStore", "VectorStore", "ChromaDBStore"]

# This package's own module object; get_vector_store passes it to get_cls_from_config as
# the default namespace for class names given without a "module:" prefix.
module = sys.modules[__name__]


def get_vector_store(vector_store_config: dict) -> VectorStore:
    """
    Builds a VectorStore from a configuration dictionary.

    Args:
        vector_store_config: Configuration with a required "type" entry (a class name
            available in this package, or a "module.submodule:ClassName" path) and an
            optional "config" entry holding constructor keyword arguments.

    Returns:
        The instantiated vector store, created with the given keyword arguments
        or with the class defaults when none are supplied.
    """
    store_type = vector_store_config["type"]
    store_cls = get_cls_from_config(store_type, module)
    init_kwargs = vector_store_config.get("config", {})

    return store_cls(**init_kwargs)
45 changes: 45 additions & 0 deletions packages/ragbits-document-search/examples/from_config_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "ragbits-document-search",
# "ragbits[litellm]",
# ]
# ///
import asyncio

from ragbits.document_search import DocumentSearch
from ragbits.document_search.documents.document import DocumentMeta

# Sample jokes to ingest, each wrapped in a text document.
documents = [
    DocumentMeta.create_text_document_from_literal(joke)
    for joke in (
        "RIP boiled water. You will be mist.",
        "Why doesn't James Bond fart in bed? Because it would blow his cover.",
        "Why programmers don't like to swim? Because they're scared of the floating points.",
    )
]

# Component configuration consumed by DocumentSearch.from_config. A bare "type" names a
# class available in the component's default package; a "module.path:ClassName" value is
# imported from that module. The module path has to be importable wherever the package is
# installed, so it starts at "ragbits" instead of the repository's "packages/" directory.
config = {
    "embedder": {"type": "LiteLLMEmbeddings"},
    "vector_store": {"type": "InMemoryVectorStore"},
    "reranker": {"type": "ragbits.document_search.retrieval.rerankers.noop:NoopReranker"},
}


async def main():
    """Ingest the sample documents and print the search results for a test query."""
    search_engine = DocumentSearch.from_config(config)

    for doc in documents:
        await search_engine.ingest_document(doc)

    print(await search_engine.search("I'm boiling my water and I need a joke"))


# Run the example only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@

from pydantic import BaseModel, Field

from ragbits.core.embeddings import get_embeddings
from ragbits.core.embeddings.base import Embeddings
from ragbits.core.vector_store import get_vector_store
from ragbits.core.vector_store.base import VectorStore
from ragbits.document_search.documents.document import Document, DocumentMeta
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
from ragbits.document_search.ingestion.providers.base import BaseProvider
from ragbits.document_search.retrieval.rephrasers import get_rephraser
from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser
from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser
from ragbits.document_search.retrieval.rerankers import get_reranker
from ragbits.document_search.retrieval.rerankers.base import Reranker
from ragbits.document_search.retrieval.rerankers.noop import NoopReranker

Expand Down Expand Up @@ -58,6 +62,25 @@ def __init__(
self.reranker = reranker or NoopReranker()
self.document_processor_router = document_processor_router or DocumentProcessorRouter.from_config()

@classmethod
def from_config(cls, config: dict) -> "DocumentSearch":
    """
    Builds a DocumentSearch from a configuration dictionary.

    Args:
        config: Component configuration. The "embedder" and "vector_store" entries are
            required; "rephraser" and "reranker" are optional and are handed to their
            factories as None when absent.

    Returns:
        DocumentSearch: The instance assembled from the configured components.
    """
    embedder = get_embeddings(config["embedder"])
    rephraser = get_rephraser(config.get("rephraser"))
    ranker = get_reranker(config.get("reranker"))
    store = get_vector_store(config["vector_store"])

    return cls(embedder, store, rephraser, ranker)

async def search(self, query: str, search_config: SearchConfig = SearchConfig()) -> list[Element]:
"""
Search for the most relevant chunks for a query.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import sys
from typing import Optional

from ragbits.core.utils import get_cls_from_config

from .base import QueryRephraser
from .noop import NoopQueryRephraser

# Names exposed when this package is star-imported.
__all__ = ["NoopQueryRephraser", "QueryRephraser"]

# This package's own module object; get_rephraser passes it to get_cls_from_config as
# the default namespace for class names given without a "module:" prefix.
module = sys.modules[__name__]


def get_rephraser(rephraser_config: Optional[dict]) -> QueryRephraser:
    """
    Builds a QueryRephraser from an optional configuration dictionary.

    Args:
        rephraser_config: Configuration with a required "type" entry and an optional
            "config" entry of constructor keyword arguments, or None when no rephraser
            was configured.

    Returns:
        The instantiated rephraser, or a NoopQueryRephraser when the configuration is None.
    """
    if rephraser_config is not None:
        rephraser_cls = get_cls_from_config(rephraser_config["type"], module)
        init_kwargs = rephraser_config.get("config", {})
        return rephraser_cls(**init_kwargs)

    # Nothing configured: fall back to the rephraser that does nothing.
    return NoopQueryRephraser()
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import sys
from typing import Optional

from ragbits.core.utils import get_cls_from_config

from .base import Reranker
from .noop import NoopReranker

# Names exposed when this package is star-imported.
__all__ = ["NoopReranker", "Reranker"]

# This package's own module object; get_reranker passes it to get_cls_from_config as
# the default namespace for class names given without a "module:" prefix.
module = sys.modules[__name__]


def get_reranker(reranker_config: Optional[dict]) -> Reranker:
    """
    Builds a Reranker from an optional configuration dictionary.

    Args:
        reranker_config: Configuration with a required "type" entry and an optional
            "config" entry of constructor keyword arguments, or None when no reranker
            was configured.

    Returns:
        The instantiated reranker, or a NoopReranker when the configuration is None.
    """
    if reranker_config is not None:
        reranker_cls = get_cls_from_config(reranker_config["type"], module)
        init_kwargs = reranker_config.get("config", {})
        return reranker_cls(**init_kwargs)

    # Nothing configured: fall back to the reranker that does nothing.
    return NoopReranker()
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,30 @@
from ragbits.document_search.documents.element import TextElement
from ragbits.document_search.ingestion.providers.dummy import DummyProvider

# Configuration for the from_config test. It covers both ways of naming a class: bare names
# for the embedder and reranker, and a full "module.path:ClassName" reference for the vector
# store. The module path starts at the importable "ragbits" package rather than at the
# repository's "packages/" directory, so the test does not depend on the working directory.
CONFIG = {
    "embedder": {"type": "NoopEmbeddings"},
    "vector_store": {"type": "ragbits.core.vector_store.in_memory:InMemoryVectorStore"},
    "reranker": {"type": "NoopReranker"},
}


@pytest.mark.parametrize(
    "document",
    [
        DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George"),
    ],
)
async def test_document_search_from_config(document):
    """A DocumentSearch built from CONFIG returns the ingested document as the top hit."""
    search = DocumentSearch.from_config(CONFIG)
    await search.ingest_document(document, document_processor=DummyProvider())

    hits = await search.search("Peppa's brother")
    top_hit = hits[0]

    assert isinstance(top_hit, TextElement)
    assert top_hit.content == "Name of Peppa's brother is George"


@pytest.mark.parametrize(
"document",
Expand Down
Loading