Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
konrad-czarnota-ds committed Oct 28, 2024
1 parent 4a847f2 commit 63abdd0
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from pathlib import Path

import pytest
from pydantic import BaseModel

from ragbits.core.config import CoreConfig
from ragbits.core.llms.base import LLMType
from ragbits.core.utils._pyproject import get_config_instance

projects_dir = Path(__file__).parent / "testprojects"
Expand Down Expand Up @@ -66,3 +69,30 @@ def test_get_config_instance_no_file():
)

assert config == OptionalHappyProjectConfig()


def test_get_config_instance_factories():
"""Test that default LLMs are loaded correctly"""
config = get_config_instance(
CoreConfig,
subproject="core",
current_dir=projects_dir / "factory_project",
)

assert config.default_llm_factories == {
LLMType.TEXT: "ragbits.core.llms.factory.simple_litellm_factory",
LLMType.VISION: "ragbits.core.llms.factory.simple_litellm_vision_factory",
LLMType.STRUCTURED_OUTPUT: "ragbits.core.llms.factory.simple_litellm_vision_factory",
}


def test_get_config_instance_bad_factories():
"""Test that non-existing LLM defined in pyproject raises error"""
with pytest.raises(ValueError) as err:
get_config_instance(
CoreConfig,
subproject="core",
current_dir=projects_dir / "bad_factory_project",
)

assert "Unsupported LLMType provided in default_llm_factories in pyproject.yaml" in str(err.value)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[project]
name = "bad_factory_project"

[tool.ragbits.core.default_llm_factories]
"ragbits.core.llms.base.LLMType.NON_EXISTING" = "ragbits.core.llms.factory.simple_litellm_factory"
"ragbits.core.llms.base.LLMType.VISION" = "ragbits.core.llms.factory.simple_litellm_vision_factory"
"ragbits.core.llms.base.LLMType.STRUCTURED_OUTPUT" = "ragbits.core.llms.factory.simple_litellm_vision_factory"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[project]
name = "factory_project"

[tool.ragbits.core.default_llm_factories]
"ragbits.core.llms.base.LLMType.TEXT" = "ragbits.core.llms.factory.simple_litellm_factory"
"ragbits.core.llms.base.LLMType.VISION" = "ragbits.core.llms.factory.simple_litellm_vision_factory"
"ragbits.core.llms.base.LLMType.STRUCTURED_OUTPUT" = "ragbits.core.llms.factory.simple_litellm_vision_factory"

0 comments on commit 63abdd0

Please sign in to comment.