From 53edf48d281928591df8561b0284858a85d2b329 Mon Sep 17 00:00:00 2001 From: Alan Konarski Date: Thu, 3 Oct 2024 12:01:27 +0200 Subject: [PATCH] Fix integration tests, update docstrings --- .../src/ragbits/document_search/_main.py | 6 +++--- .../tests/integration/test_file.md | 3 +++ .../tests/integration/test_unstructured.py | 18 +++++++++--------- 3 files changed, 15 insertions(+), 12 deletions(-) create mode 100644 packages/ragbits-document-search/tests/integration/test_file.md diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py index b62a4836e..ae593a1ef 100644 --- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py +++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py @@ -85,7 +85,7 @@ async def ingest_document( Ingest a document. Args: - document: The document or document metadata to ingest. + document: The document or metadata of the document to ingest. document_processor: The document processor to use. If not provided, the document processor will be determined based on the document metadata. """ @@ -98,10 +98,10 @@ async def ingest_document( async def insert_elements(self, elements: list[Element]) -> None: """ - Insert an elements into the vector store. + Insert Elements into the vector store. Args: - elements: The element to insert. + elements: The list of Elements to insert. """ vectors = await self.embedder.embed_text([element.get_key() for element in elements]) entries = [element.to_vector_db_entry(vector) for element, vector in zip(elements, vectors)] diff --git a/packages/ragbits-document-search/tests/integration/test_file.md b/packages/ragbits-document-search/tests/integration/test_file.md new file mode 100644 index 000000000..80554b5b3 --- /dev/null +++ b/packages/ragbits-document-search/tests/integration/test_file.md @@ -0,0 +1,3 @@ +# Ragbits + +Repository for internal experiment with our upcoming LLM framework. diff --git a/packages/ragbits-document-search/tests/integration/test_unstructured.py b/packages/ragbits-document-search/tests/integration/test_unstructured.py index a48c1f49b..b7827b857 100644 --- a/packages/ragbits-document-search/tests/integration/test_unstructured.py +++ b/packages/ragbits-document-search/tests/integration/test_unstructured.py @@ -3,7 +3,7 @@ import pytest from ragbits.document_search.documents.document import DocumentMeta, DocumentType -from ragbits.document_search.ingestion.document_processor import DocumentProcessor +from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter from ragbits.document_search.ingestion.providers.unstructured import ( DEFAULT_PARTITION_KWARGS, UNSTRUCTURED_API_KEY_ENV, @@ -19,14 +19,14 @@ reason="Unstructured API environment variables not set", ) async def test_document_processor_processes_text_document_with_unstructured_provider(): - document_processor = DocumentProcessor.from_config() + document_processor = DocumentProcessorRouter.from_config() document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.") - elements = await document_processor.process(document_meta) + elements = await document_processor.get_provider(document_meta).process(document_meta) assert isinstance(document_processor._providers[DocumentType.TXT], UnstructuredProvider) assert len(elements) == 1 - assert elements[0].content == "Name of Peppa's brother is George" + assert elements[0].content == "Name of Peppa's brother is George." @pytest.mark.skipif( @@ -34,13 +34,13 @@ async def test_document_processor_processes_text_document_with_unstructured_prov reason="Unstructured API environment variables not set", ) async def test_document_processor_processes_md_document_with_unstructured_provider(): - document_processor = DocumentProcessor.from_config() - document_meta = DocumentMeta.from_local_path(Path(__file__).parent.parent.parent.parent.parent / "README.md") + document_processor = DocumentProcessorRouter.from_config() + document_meta = DocumentMeta.from_local_path(Path(__file__).parent / "test_file.md") - elements = await document_processor.process(document_meta) + elements = await document_processor.get_provider(document_meta).process(document_meta) - assert len(elements) > 0 - assert elements[0].content == "Ragbits" + assert len(elements) == 1 + assert elements[0].content == "Ragbits\n\nRepository for internal experiment with our upcoming LLM framework." @pytest.mark.skipif(