Skip to content

Commit

Permalink
feat(document-search): allow to use local instance of unstructured (#74)
Browse files (browse the repository at this point in the history)
  • Loading branch information
konrad-czarnota-ds authored Oct 8, 2024
1 parent 2e95436 commit a774147
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 28 deletions.
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
import os
from io import BytesIO
from typing import Optional

from unstructured.chunking.basic import chunk_elements
from unstructured.documents.elements import Element as UnstructuredElement
from unstructured.partition.auto import partition
from unstructured.staging.base import elements_from_dicts
from unstructured_client import UnstructuredClient

Expand Down Expand Up @@ -56,6 +58,7 @@ def __init__(
chunking_kwargs: Optional[dict] = None,
api_key: Optional[str] = None,
api_server: Optional[str] = None,
use_api: bool = False,
) -> None:
"""Initialize the UnstructuredProvider.
Expand All @@ -72,6 +75,7 @@ def __init__(
self.chunking_kwargs = chunking_kwargs or DEFAULT_CHUNKING_KWARGS
self.api_key = api_key
self.api_server = api_server
self.use_api = use_api
self._client = None

@property
Expand Down Expand Up @@ -108,18 +112,27 @@ async def process(self, document_meta: DocumentMeta) -> list[Element]:
self.validate_document_type(document_meta.document_type)
document = await document_meta.fetch()

res = await self.client.general.partition_async(
request={
"partition_parameters": {
"files": {
"content": document.local_path.read_bytes(),
"file_name": document.local_path.name,
},
**self.partition_kwargs,
if self.use_api:
res = await self.client.general.partition_async(
request={
"partition_parameters": {
"files": {
"content": document.local_path.read_bytes(),
"file_name": document.local_path.name,
},
**self.partition_kwargs,
}
}
}
)
elements = chunk_elements(elements_from_dicts(res.elements), **self.chunking_kwargs)
)
elements = elements_from_dicts(res.elements)
else:
elements = partition(
file=BytesIO(document.local_path.read_bytes()),
metadata_filename=document.local_path.name,
**self.partition_kwargs,
)

elements = chunk_elements(elements, **self.chunking_kwargs)
return [_to_text_element(element, document_meta) for element in elements]


Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,12 +14,21 @@
from ..helpers import env_vars_not_set


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
@pytest.mark.parametrize(
"config",
[
{},
pytest.param(
{DocumentType.TXT: UnstructuredProvider(use_api=True)},
marks=pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
),
),
],
)
async def test_document_processor_processes_text_document_with_unstructured_provider():
document_processor = DocumentProcessorRouter.from_config()
async def test_document_processor_processes_text_document_with_unstructured_provider(config):
document_processor = DocumentProcessorRouter.from_config(config)
document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")

elements = await document_processor.get_provider(document_meta).process(document_meta)
Expand All @@ -43,28 +52,46 @@ async def test_document_processor_processes_md_document_with_unstructured_provid
assert elements[0].content == "Ragbits\n\nRepository for internal experiment with our upcoming LLM framework."


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
@pytest.mark.parametrize(
"use_api",
[
False,
pytest.param(
True,
marks=pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
),
),
],
)
async def test_unstructured_provider_document_with_default_partition_kwargs():
async def test_unstructured_provider_document_with_default_partition_kwargs(use_api):
document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
unstructured_provider = UnstructuredProvider()
unstructured_provider = UnstructuredProvider(use_api=use_api)
elements = await unstructured_provider.process(document_meta)

assert unstructured_provider.partition_kwargs == DEFAULT_PARTITION_KWARGS
assert len(elements) == 1
assert elements[0].content == "Name of Peppa's brother is George."


@pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
@pytest.mark.parametrize(
"use_api",
[
False,
pytest.param(
True,
marks=pytest.mark.skipif(
env_vars_not_set([UNSTRUCTURED_SERVER_URL_ENV, UNSTRUCTURED_API_KEY_ENV]),
reason="Unstructured API environment variables not set",
),
),
],
)
async def test_unstructured_provider_document_with_custom_partition_kwargs():
async def test_unstructured_provider_document_with_custom_partition_kwargs(use_api):
document_meta = DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
partition_kwargs = {"languages": ["pl"], "strategy": "fast"}
unstructured_provider = UnstructuredProvider(partition_kwargs=partition_kwargs)
unstructured_provider = UnstructuredProvider(use_api=use_api, partition_kwargs=partition_kwargs)
elements = await unstructured_provider.process(document_meta)

assert unstructured_provider.partition_kwargs == partition_kwargs
Expand Down
4 changes: 2 additions & 2 deletions packages/ragbits-document-search/tests/unit/test_providers.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_unsupported_provider_validates_supported_document_types_fails():
@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_api_key_not_set():
with pytest.raises(ValueError) as err:
await UnstructuredProvider().process(
await UnstructuredProvider(use_api=True).process(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
)

Expand All @@ -33,7 +33,7 @@ async def test_unstructured_provider_raises_value_error_when_api_key_not_set():
@patch.dict(os.environ, {}, clear=True)
async def test_unstructured_provider_raises_value_error_when_server_url_not_set():
with pytest.raises(ValueError) as err:
await UnstructuredProvider(api_key="api_key").process(
await UnstructuredProvider(api_key="api_key", use_api=True).process(
DocumentMeta.create_text_document_from_literal("Name of Peppa's brother is George.")
)

Expand Down

0 comments on commit a774147

Please sign in to comment.