Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(document-search): Allow to create DocumentSearch instances from config #62

26 changes: 26 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import sys

from .base import Embeddings
from .litellm import LiteLLMEmbeddings
from .local import LocalEmbeddings

# Names exported by a star-import of this package.
# NOTE(review): `Embeddings` and `get_embeddings` are importable from here but are not
# listed below - confirm whether that omission is intentional.
__all__ = ["LiteLLMEmbeddings", "LocalEmbeddings"]

# This package's own module object; get_embeddings() resolves class names against it.
module = sys.modules[__name__]


def get_embeddings(embedder_config: dict) -> "Embeddings":
    """
    Initializes and returns an Embeddings object based on the provided embedder configuration.

    Args:
        embedder_config: A dictionary containing configuration details for the embedder.
            The required "type" key is either the name of an attribute of this module
            (for example a class imported here) or a full path in the
            "module.submodule:ClassName" form. The optional "config" key holds the
            keyword arguments passed to the class constructor.

    Returns:
        An instance of the specified Embeddings class, initialized with the provided config
        (if any) or default arguments.

    Raises:
        KeyError: If the "type" key is missing from the configuration.
        ImportError: If the module part of a full path cannot be imported.
        AttributeError: If the class cannot be found in the resolved module.
    """
    # Function-scope import keeps this block self-contained without touching file-level imports.
    from importlib import import_module

    embeddings_type = embedder_config["type"]
    config = embedder_config.get("config", {})

    if ":" in embeddings_type:
        # Full path: load the class from the module named before the colon.
        module_path, class_name = embeddings_type.split(":")
        embeddings_cls = getattr(import_module(module_path), class_name)
    else:
        # Bare name: look it up on this package's own module object.
        embeddings_cls = getattr(module, embeddings_type)

    return embeddings_cls(**config)
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from ragbits.core.embeddings import get_embeddings
from ragbits.core.embeddings.base import Embeddings
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element
from ragbits.document_search.ingestion.document_processor import DocumentProcessor
from ragbits.document_search.ingestion.providers.dummy import DummyProvider
from ragbits.document_search.retrieval.rephrasers import get_rephraser
from ragbits.document_search.retrieval.rephrasers.base import QueryRephraser
from ragbits.document_search.retrieval.rephrasers.noop import NoopQueryRephraser
from ragbits.document_search.retrieval.rerankers import get_reranker
from ragbits.document_search.retrieval.rerankers.base import Reranker
from ragbits.document_search.retrieval.rerankers.noop import NoopReranker
from ragbits.document_search.vector_store import get_vector_store
from ragbits.document_search.vector_store.base import VectorStore


Expand Down Expand Up @@ -42,6 +46,26 @@ def __init__(
self.query_rephraser = query_rephraser or NoopQueryRephraser()
self.reranker = reranker or NoopReranker()

@classmethod
def from_config(cls, config: dict) -> "DocumentSearch":
    """
    Creates and returns an instance of the class from the given configuration.

    Args:
        config: A dictionary containing the configuration for initializing the instance.
            The "embedder" and "vector_store" keys are required. The "rephraser" and
            "reranker" keys are optional; when absent, None is handed to the respective
            factory function.

    Returns:
        DocumentSearch: An initialized instance of the class this method was called on.

    Raises:
        KeyError: If the "embedder" or "vector_store" key is missing from the configuration.
    """
    embedder = get_embeddings(config["embedder"])
    query_rephraser = get_rephraser(config.get("rephraser"))
    reranker = get_reranker(config.get("reranker"))
    vector_store = get_vector_store(config["vector_store"])

    # Construct through `cls` so that a subclass calling from_config() gets an instance of itself.
    return cls(embedder, vector_store, query_rephraser, reranker)
micpst marked this conversation as resolved.
Show resolved Hide resolved

async def search(self, query: str) -> list[Element]:
"""
Search for the most relevant chunks for a query.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import sys
from typing import Optional

from ...utils import get_cls_from_config
from .base import QueryRephraser
from .noop import NoopQueryRephraser

# Names exported by a star-import of this package.
__all__ = ["NoopQueryRephraser", "QueryRephraser"]

# This package's own module object; get_rephraser() passes it on as the default module
# for class lookup.
module = sys.modules[__name__]


def get_rephraser(rephraser_config: Optional[dict]) -> QueryRephraser:
    """
    Builds a QueryRephraser from an optional configuration dictionary.

    Args:
        rephraser_config: Configuration for the QueryRephraser, or None. The "type" key
            names the class to instantiate and the optional "config" key holds its
            constructor keyword arguments.

    Returns:
        A NoopQueryRephraser when no configuration is given, otherwise an instance of the
        configured class created with the "config" keyword arguments (if any).
    """
    if rephraser_config is not None:
        factory = get_cls_from_config(rephraser_config["type"], module)
        return factory(**rephraser_config.get("config", {}))

    # No configuration supplied: use the no-op rephraser.
    return NoopQueryRephraser()
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import sys
from typing import Optional

from ...utils import get_cls_from_config
from .base import Reranker
from .noop import NoopReranker

# Names exported by a star-import of this package.
__all__ = ["NoopReranker", "Reranker"]

# This package's own module object; get_reranker() passes it on as the default module
# for class lookup.
module = sys.modules[__name__]


def get_reranker(reranker_config: Optional[dict]) -> Reranker:
    """
    Builds a Reranker from an optional configuration dictionary.

    Args:
        reranker_config: Configuration for the Reranker, or None. The "type" key names
            the class to instantiate and the optional "config" key holds its constructor
            keyword arguments.

    Returns:
        A NoopReranker when no configuration is given, otherwise an instance of the
        configured class created with the "config" keyword arguments (if any).
    """
    # Guard clause: without a configuration there is nothing to resolve.
    if reranker_config is None:
        return NoopReranker()

    type_name = reranker_config["type"]
    factory = get_cls_from_config(type_name, module)
    init_kwargs = reranker_config.get("config", {})
    return factory(**init_kwargs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from importlib import import_module
from types import ModuleType
from typing import Any


def get_cls_from_config(cls_path: str, default_module: ModuleType) -> Any:
    """
    Resolves a type string to the object it names.

    Args:
        cls_path: Either a bare attribute name, looked up on the default module, or a
            full path in the "module.submodule:ClassName" form, in which case the module
            part is imported and the name after the colon is looked up on it.
        default_module: The module searched when the type string contains no colon.

    Returns:
        Any: The object retrieved from the specified or default module.
    """
    # Bare name: resolve directly against the default module.
    if ":" not in cls_path:
        return getattr(default_module, cls_path)

    # Full path: import the module named before the colon, then fetch the attribute.
    module_name, attr_name = cls_path.split(":")
    return getattr(import_module(module_name), attr_name)
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import sys

from ..utils import get_cls_from_config
from .base import VectorStore
from .chromadb_store import ChromaDBStore
from .in_memory import InMemoryVectorStore

# Names exported by a star-import of this package.
__all__ = ["InMemoryVectorStore", "VectorStore", "ChromaDBStore"]

# This package's own module object; get_vector_store() passes it on as the default module
# for class lookup.
module = sys.modules[__name__]


def get_vector_store(vector_store_config: dict) -> VectorStore:
    """
    Builds a VectorStore from a configuration dictionary.

    Args:
        vector_store_config: Configuration for the VectorStore. The "type" key names the
            class to instantiate and the optional "config" key holds its constructor
            keyword arguments.

    Returns:
        An instance of the configured VectorStore class, created with the "config" keyword
        arguments (if any) or with default arguments.
    """
    store_type = vector_store_config["type"]
    store_factory = get_cls_from_config(store_type, module)
    return store_factory(**vector_store_config.get("config", {}))
Loading