Skip to content

Commit

Permalink
feat(core): user should be able to configure default types of default…
Browse files Browse the repository at this point in the history
… LLMs
  • Loading branch information
konrad-czarnota-ds committed Oct 28, 2024
1 parent 5c43e87 commit 88e96e2
Show file tree
Hide file tree
Showing 10 changed files with 80 additions and 22 deletions.
8 changes: 5 additions & 3 deletions packages/ragbits-core/src/ragbits/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from rich import print as pprint

from ragbits.core.config import core_config
from ragbits.core.llms.base import LLMType
from ragbits.core.prompt.prompt import Prompt


Expand Down Expand Up @@ -38,7 +39,7 @@ def register(app: typer.Typer) -> None:
@prompts_app.command()
def lab(
file_pattern: str = core_config.prompt_path_pattern,
llm_factory: str | None = core_config.default_llm_factory,
llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
Expand Down Expand Up @@ -73,15 +74,16 @@ def render(prompt_path: str, payload: str | None = None) -> None:

@prompts_app.command(name="exec")
def execute(
prompt_path: str, payload: str | None = None, llm_factory: str | None = core_config.default_llm_factory
prompt_path: str,
payload: str | None = None,
llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Executes a prompt using the specified prompt class and LLM factory.
Raises:
ValueError: If `llm_factory` is not provided.
"""

from ragbits.core.llms.factory import get_llm_from_factory

prompt = _render(prompt_path=prompt_path, payload=payload)
Expand Down
9 changes: 7 additions & 2 deletions packages/ragbits-core/src/ragbits/core/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import BaseModel

from ragbits.core.llms.base import LLMType
from ragbits.core.utils._pyproject import get_config_instance


Expand All @@ -11,8 +12,12 @@ class CoreConfig(BaseModel):
# Pattern used to search for prompt files
prompt_path_pattern: str = "**/prompt_*.py"

# Path to a function that returns an LLM object, e.g. "my_project.llms.get_llm"
default_llm_factory: str | None = None
# Paths to functions that return LLM objects, keyed by LLM type, e.g. "my_project.llms.get_llm"
default_llm_factories: dict[LLMType, str | None] = {
LLMType.TEXT: None,
LLMType.VISION: None,
LLMType.STRUCTURED_OUTPUT: None,
}


core_config = get_config_instance(CoreConfig, subproject="core")
11 changes: 11 additions & 0 deletions packages/ragbits-core/src/ragbits/core/llms/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import enum
from abc import ABC, abstractmethod
from functools import cached_property
from typing import Generic, cast, overload
Expand All @@ -7,6 +8,16 @@
from .clients.base import LLMClient, LLMClientOptions, LLMOptions


class LLMType(enum.Enum):
    """
    Types of LLMs based on supported features.

    Members are used as keys of `CoreConfig.default_llm_factories`, and are
    referenced from pyproject.toml by their fully qualified name, e.g.
    "ragbits.core.llms.base.LLMType.TEXT".
    """

    TEXT = "text"
    VISION = "vision"
    STRUCTURED_OUTPUT = "structured_output"


class LLM(Generic[LLMClientOptions], ABC):
"""
Abstract class for interaction with Large Language Model.
Expand Down
29 changes: 23 additions & 6 deletions packages/ragbits-core/src/ragbits/core/llms/factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib

from ragbits.core.config import core_config
from ragbits.core.llms.base import LLM
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.llms.litellm import LiteLLM


Expand All @@ -21,28 +21,34 @@ def get_llm_from_factory(factory_path: str) -> LLM:
return function()


def has_default_llm() -> bool:
def has_default_llm(llm_type: LLMType = LLMType.TEXT) -> bool:
"""
Check if the default LLM factory is set in the configuration.
Returns:
bool: Whether the default LLM factory is set.
"""
return core_config.default_llm_factory is not None
default_factory = core_config.default_llm_factories.get(llm_type, None)
return default_factory is not None


def get_default_llm() -> LLM:
def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
"""
Get an instance of the default LLM using the factory function
specified in the configuration.
Args:
llm_type: type of the LLM to get, defaults to text
Returns:
LLM: An instance of the default LLM.
Raises:
ValueError: If the default LLM factory is not set.
ValueError: If the default LLM factory is not set or expected llm type is not defined in config
"""
factory = core_config.default_llm_factory
if llm_type not in core_config.default_llm_factories:
raise ValueError(f"Default LLM of type {llm_type} is not defined in pyproject.toml config.")
factory = core_config.default_llm_factories[llm_type]
if factory is None:
raise ValueError("Default LLM factory is not set")

Expand All @@ -58,3 +64,14 @@ def simple_litellm_factory() -> LLM:
LLM: An instance of the LiteLLM.
"""
return LiteLLM()


def simple_litellm_vision_factory() -> LLM:
    """
    Basic factory for a vision-capable LLM.

    Builds a LiteLLM instance pinned to a vision-enabled model, leaving every
    other option at its default. The API key is assumed to be available in the
    environment.

    Returns:
        LLM: An instance of the LiteLLM.
    """
    vision_llm: LLM = LiteLLM(model_name="gpt-4o-mini")
    return vision_llm
3 changes: 2 additions & 1 deletion packages/ragbits-core/src/ragbits/core/prompt/lab/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from ragbits.core.config import core_config
from ragbits.core.llms import LLM
from ragbits.core.llms.base import LLMType
from ragbits.core.llms.factory import get_llm_from_factory
from ragbits.core.prompt import Prompt
from ragbits.core.prompt.discovery import PromptDiscovery
Expand Down Expand Up @@ -137,7 +138,7 @@ def get_input_type_fields(obj: BaseModel | None) -> list[dict]:

def lab_app( # pylint: disable=missing-param-doc
file_pattern: str = core_config.prompt_path_pattern,
llm_factory: str | None = core_config.default_llm_factory,
llm_factory: str | None = core_config.default_llm_factories[LLMType.TEXT],
) -> None:
"""
Launches the interactive application for listing, rendering, and testing prompts
Expand Down
15 changes: 14 additions & 1 deletion packages/ragbits-core/src/ragbits/core/utils/_pyproject.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import enum
import importlib
from pathlib import Path
from typing import Any, TypeVar

Expand Down Expand Up @@ -64,7 +66,7 @@ def get_config_instance(
model: type[ConfigModelT], subproject: str | None = None, current_dir: Path | None = None
) -> ConfigModelT:
"""
Creates an instace of pydantic model loaded with the configuration from pyproject.toml.
Creates an instance of pydantic model loaded with the configuration from pyproject.toml.
Args:
model (Type[BaseModel]): The pydantic model to instantiate.
Expand All @@ -81,4 +83,15 @@ def get_config_instance(
config = get_ragbits_config(current_dir)
if subproject:
config = config.get(subproject, {})
# The table is optional: only translate its keys when the project actually
# defines it, otherwise the model's own default is used. Indexing
# unconditionally raised KeyError for every config without this table.
if "default_llm_factories" in config:
    config["default_llm_factories"] = {
        _resolve_enum_member(k): v for k, v in config["default_llm_factories"].items()
    }
return model(**config)


def _resolve_enum_member(enum_string: str) -> enum.Enum:
    """
    Resolve a fully qualified enum member path to the enum member itself.

    Args:
        enum_string: Dotted path in the form "<module>.<EnumClass>.<MEMBER>",
            e.g. "ragbits.core.llms.base.LLMType.TEXT".

    Returns:
        The enum member the path points to.

    Raises:
        ValueError: If the path is malformed, the class or member does not exist
            in the module, or the resolved attribute is not an enum member.
        ImportError: If the module part of the path cannot be imported.
    """
    try:
        module_name, class_name, member_name = enum_string.rsplit(".", 2)
    except ValueError as err:
        # Fewer than two dots: report the offending key instead of an opaque unpacking error.
        raise ValueError(
            f"Invalid key {enum_string!r} in default_llm_factories in pyproject.toml, "
            "expected '<module>.<EnumClass>.<MEMBER>'"
        ) from err

    module = importlib.import_module(module_name)
    try:
        # Both lookups are guarded so a misspelled class is reported like a misspelled member.
        member = getattr(getattr(module, class_name), member_name)
    except AttributeError as err:
        raise ValueError(
            f"Unsupported LLMType {enum_string!r} provided in default_llm_factories in pyproject.toml"
        ) from err

    # getattr also succeeds for non-member attributes (e.g. "__name__"); reject those.
    if not isinstance(member, enum.Enum):
        raise ValueError(
            f"Unsupported LLMType {enum_string!r} provided in default_llm_factories in pyproject.toml"
        )
    return member
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ def test_get_default_llm(monkeypatch: pytest.MonkeyPatch) -> None:
"""
Test the get_default_llm function.
"""
monkeypatch.setattr(core_config, "default_llm_factory", "factory.test_get_llm_from_factory.mock_llm_factory")
# get_default_llm() looks the factory up by LLMType member, so the patched
# mapping must be keyed by the enum, not by its string value.
from ragbits.core.llms.base import LLMType

monkeypatch.setattr(
    core_config,
    "default_llm_factories",
    {LLMType.TEXT: "factory.test_get_llm_from_factory.mock_llm_factory"},
)

llm = get_default_llm()

Check failure on line 16 in packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py

View workflow job for this annotation

GitHub Actions / JUnit Test Report

test_get_default_llm

ValueError: Default LLM of type LLMType.TEXT is not defined in pyproject.toml config.
Raw output
monkeypatch = <_pytest.monkeypatch.MonkeyPatch object at 0x7fa0ffde91e0>

    def test_get_default_llm(monkeypatch: pytest.MonkeyPatch) -> None:
        """
        Test the get_llm_from_factory function.
        """
        monkeypatch.setattr(
            core_config, "default_llm_factories", {"text": "factory.test_get_llm_from_factory.mock_llm_factory"}
        )
    
>       llm = get_default_llm()

packages/ragbits-core/tests/unit/llms/factory/test_get_default_llm.py:16: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

llm_type = <LLMType.TEXT: 'text'>

    def get_default_llm(llm_type: LLMType = LLMType.TEXT) -> LLM:
        """
        Get an instance of the default LLM using the factory function
        specified in the configuration.
    
        Args:
            llm_type: type of the LLM to get, defaults to text
    
        Returns:
            LLM: An instance of the default LLM.
    
        Raises:
            ValueError: If the default LLM factory is not set or expected llm type is not defined in config
        """
        if llm_type not in core_config.default_llm_factories:
>           raise ValueError(f"Default LLM of type {llm_type} is not defined in pyproject.toml config.")
E           ValueError: Default LLM of type LLMType.TEXT is not defined in pyproject.toml config.

packages/ragbits-core/src/ragbits/core/llms/factory.py:50: ValueError
assert isinstance(llm, LiteLLM)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_has_default_llm(monkeypatch: pytest.MonkeyPatch) -> None:
"""
Test the has_default_llm function when the default LLM factory is not set.
"""
monkeypatch.setattr(core_config, "default_llm_factory", None)
monkeypatch.setattr(core_config, "default_llm_factories", {})

assert has_default_llm() is False

Expand All @@ -17,6 +17,6 @@ def test_has_default_llm_false(monkeypatch: pytest.MonkeyPatch) -> None:
"""
Test the has_default_llm function when the default LLM factory is set.
"""
monkeypatch.setattr(core_config, "default_llm_factory", "my_project.llms.get_llm")
# has_default_llm() looks the factory up by LLMType member, so the patched
# mapping must be keyed by the enum, not by its string value.
from ragbits.core.llms.base import LLMType

monkeypatch.setattr(core_config, "default_llm_factories", {LLMType.TEXT: "my_project.llms.get_llm"})

assert has_default_llm() is True

Check failure on line 22 in packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py

View workflow job for this annotation

GitHub Actions / JUnit Test Report

test_has_default_llm.test_has_default_llm_false

assert False is True + where False = has_default_llm()
Raw output
monkeypatch = <_pytest.monkeypatch.MonkeyPatch object at 0x7fa031585510>

    def test_has_default_llm_false(monkeypatch: pytest.MonkeyPatch) -> None:
        """
        Test the has_default_llm function when the default LLM factory is set.
        """
        monkeypatch.setattr(core_config, "default_llm_factories", {"text": "my_project.llms.get_llm"})
    
>       assert has_default_llm() is True
E       assert False is True
E        +  where False = has_default_llm()

packages/ragbits-core/tests/unit/llms/factory/test_has_default_llm.py:22: AssertionError
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.documents.elements import ElementType

from ragbits.core.llms.base import LLM
from ragbits.core.llms.litellm import LiteLLM
from ragbits.core.llms.base import LLM, LLMType
from ragbits.core.llms.factory import get_default_llm
from ragbits.document_search.documents.document import DocumentMeta, DocumentType
from ragbits.document_search.documents.element import Element, ImageElement
from ragbits.document_search.ingestion.providers.unstructured.default import UnstructuredDefaultProvider
Expand All @@ -17,8 +17,6 @@
to_text_element,
)

DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL = "gpt-4o-mini"


class UnstructuredImageProvider(UnstructuredDefaultProvider):
"""
Expand Down Expand Up @@ -53,7 +51,8 @@ def __init__(
llm: llm to use
"""
super().__init__(partition_kwargs, chunking_kwargs, api_key, api_server, use_api)
self.image_summarizer = ImageDescriber(llm or LiteLLM(DEFAULT_LLM_IMAGE_SUMMARIZATION_MODEL))
self.image_describer: ImageDescriber | None = None
self._llm = llm

async def _chunk_and_convert(
self, elements: list[UnstructuredElement], document_meta: DocumentMeta, document_path: Path
Expand All @@ -79,7 +78,10 @@ async def _to_image_element(
)

img_bytes = crop_and_convert_to_bytes(image, top_x, top_y, bottom_x, bottom_y)
image_description = await self.image_summarizer.get_image_description(img_bytes)
if self.image_describer is None:
llm_to_use = self._llm if self._llm is not None else get_default_llm(LLMType.VISION)
self.image_describer = ImageDescriber(llm_to_use)
image_description = await self.image_describer.get_image_description(img_bytes)
return ImageElement(
description=image_description,
ocr_extracted_text=element.text,
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,8 @@ known-third-party = [
"requests", "scipy", "setuptools", "shapely", "skimage", "sklearn", "streamlit",
"torch", "torchvision", "tqdm", "typer"
]

[tool.ragbits.core.default_llm_factories]
"ragbits.core.llms.base.LLMType.TEXT" = "ragbits.core.llms.factory.simple_litellm_factory"
"ragbits.core.llms.base.LLMType.VISION" = "ragbits.core.llms.factory.simple_litellm_vision_factory"
"ragbits.core.llms.base.LLMType.STRUCTURED_OUTPUT" = "ragbits.core.llms.factory.simple_litellm_vision_factory"

0 comments on commit 88e96e2

Please sign in to comment.