Skip to content

Commit

Permalink
✨Introduce (Pydantic)JsonFileEntityLoader (#58)
Browse files Browse the repository at this point in the history
  • Loading branch information
hf-kklein authored Apr 3, 2023
1 parent 1b974f4 commit 73943e0
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
53 changes: 52 additions & 1 deletion src/bomf/loader/entityloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
entity loaders load entities into the target system
"""
import asyncio
import json
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Awaitable, Generic, List, Optional, TypeVar
from pathlib import Path
from typing import Awaitable, Callable, Generic, List, Optional, TypeVar

import attrs
from pydantic import BaseModel # pylint:disable=no-name-in-module

_TargetEntity = TypeVar("_TargetEntity")

Expand Down Expand Up @@ -136,3 +139,51 @@ async def load_entities(self, entities: List[_TargetEntity]) -> List[LoadingSumm
tasks: List[Awaitable[LoadingSummary]] = [self.load(entity) for entity in entities]
result = await asyncio.gather(*tasks)
return result


class JsonFileEntityLoader(EntityLoader[_TargetEntity], Generic[_TargetEntity]):
"""
an entity loader that produces a json file as result. This is specifically useful in unit tests
"""

async def verify(self, entity: _TargetEntity, id_in_target_system: Optional[str] = None) -> bool:
return True

def __init__(self, file_path: Path, list_encoder: Callable[[List[_TargetEntity]], List[dict]]):
"""provide a path to a json file (will be created if not exists and overwritten if exists)"""
self._file_path = file_path
self._list_encoder = list_encoder
self._entities: List[_TargetEntity] = []

async def load_entity(self, entity: _TargetEntity) -> Optional[EntityLoadingResult]:
self._entities.append(entity)
return None

async def load_entities(self, entities: List[_TargetEntity]) -> List[LoadingSummary]:
base_result = await super().load_entities(entities)
dict_list = self._list_encoder(self._entities)
with open(self._file_path, "w+", encoding="utf-8") as outfile:
json.dump(dict_list, outfile, ensure_ascii=False, indent=2)
return base_result


_PydanticTargetModel = TypeVar("_PydanticTargetModel", bound=BaseModel)


# pylint:disable=too-few-public-methods
class _ListOfPydanticModels(BaseModel, Generic[_PydanticTargetModel]):
# https://stackoverflow.com/a/58641115/10009545
# for the instantiation see the serialization unit test
__root__: List[_PydanticTargetModel]


class PydanticJsonFileEntityLoader(JsonFileEntityLoader[_PydanticTargetModel], Generic[_PydanticTargetModel]):
"""
A json file entity loader specifically for pydantic models
"""

def __init__(self, file_path: Path):
"""provide a file path"""
super().__init__(
file_path=file_path, list_encoder=lambda x: json.loads(_ListOfPydanticModels(__root__=x).json())
)
28 changes: 27 additions & 1 deletion unittests/test_entity_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import json
from pathlib import Path
from typing import Optional

from bomf.loader.entityloader import EntityLoader, EntityLoadingResult
from pydantic import BaseModel

from bomf.loader.entityloader import EntityLoader, EntityLoadingResult, PydanticJsonFileEntityLoader


class _ExampleEntity:
Expand Down Expand Up @@ -103,3 +107,25 @@ async def load_entity(self, entity: _ExampleEntity) -> Optional[EntityLoadingRes
assert result.loaded_at is None
assert result.verified_at is None
assert isinstance(result.loading_error, ValueError) is True


class MyPydanticClass(BaseModel):
foo: str
bar: int


class MyLoader(PydanticJsonFileEntityLoader[MyPydanticClass]):
"""entity loader fo my pydantic class"""


class TestJsonFileEntityLoader:
async def test_dumping_to_file(self, tmp_path):
my_entities = [MyPydanticClass(foo="asd", bar=123), MyPydanticClass(foo="qwe", bar=456)]
file_path = Path(tmp_path) / Path("foo.json")
my_loader = MyLoader(file_path)
await my_loader.load_entities(my_entities)
del my_loader
with open(file_path, "r", encoding="utf-8") as infile:
json_body = json.load(infile)
assert len(json_body) == 2
assert json_body == [{"foo": "asd", "bar": 123}, {"foo": "qwe", "bar": 456}]

0 comments on commit 73943e0

Please sign in to comment.