diff --git a/src/bomf/loader/entityloader.py b/src/bomf/loader/entityloader.py index a10d727..c06d391 100644 --- a/src/bomf/loader/entityloader.py +++ b/src/bomf/loader/entityloader.py @@ -2,11 +2,14 @@ entity loaders load entities into the target system """ import asyncio +import json from abc import ABC, abstractmethod from datetime import datetime -from typing import Awaitable, Generic, List, Optional, TypeVar +from pathlib import Path +from typing import Awaitable, Callable, Generic, List, Optional, TypeVar import attrs +from pydantic import BaseModel # pylint:disable=no-name-in-module _TargetEntity = TypeVar("_TargetEntity") @@ -136,3 +139,51 @@ async def load_entities(self, entities: List[_TargetEntity]) -> List[LoadingSumm tasks: List[Awaitable[LoadingSummary]] = [self.load(entity) for entity in entities] result = await asyncio.gather(*tasks) return result + + +class JsonFileEntityLoader(EntityLoader[_TargetEntity], Generic[_TargetEntity]): + """ + an entity loader that produces a json file as result. This is specifically useful in unit tests + """ + + async def verify(self, entity: _TargetEntity, id_in_target_system: Optional[str] = None) -> bool: + return True + + def __init__(self, file_path: Path, list_encoder: Callable[[List[_TargetEntity]], List[dict]]): + """provide a path to a json file (will be created if not exists and overwritten if exists)""" + self._file_path = file_path + self._list_encoder = list_encoder + self._entities: List[_TargetEntity] = [] + + async def load_entity(self, entity: _TargetEntity) -> Optional[EntityLoadingResult]: + self._entities.append(entity) + return None + + async def load_entities(self, entities: List[_TargetEntity]) -> List[LoadingSummary]: + base_result = await super().load_entities(entities) + dict_list = self._list_encoder(self._entities) + with open(self._file_path, "w+", encoding="utf-8") as outfile: + json.dump(dict_list, outfile, ensure_ascii=False, indent=2) + return base_result + + +_PydanticTargetModel = TypeVar("_PydanticTargetModel", bound=BaseModel) + + +# pylint:disable=too-few-public-methods +class _ListOfPydanticModels(BaseModel, Generic[_PydanticTargetModel]): + # https://stackoverflow.com/a/58641115/10009545 + # for the instantiation see the serialization unit test + __root__: List[_PydanticTargetModel] + + +class PydanticJsonFileEntityLoader(JsonFileEntityLoader[_PydanticTargetModel], Generic[_PydanticTargetModel]): + """ + A json file entity loader specifically for pydantic models + """ + + def __init__(self, file_path: Path): + """provide a file path""" + super().__init__( + file_path=file_path, list_encoder=lambda x: json.loads(_ListOfPydanticModels(__root__=x).json()) + ) diff --git a/unittests/test_entity_loader.py b/unittests/test_entity_loader.py index b084950..30e3d02 100644 --- a/unittests/test_entity_loader.py +++ b/unittests/test_entity_loader.py @@ -1,6 +1,10 @@ +import json +from pathlib import Path from typing import Optional -from bomf.loader.entityloader import EntityLoader, EntityLoadingResult +from pydantic import BaseModel + +from bomf.loader.entityloader import EntityLoader, EntityLoadingResult, PydanticJsonFileEntityLoader class _ExampleEntity: @@ -103,3 +107,25 @@ async def load_entity(self, entity: _ExampleEntity) -> Optional[EntityLoadingRes assert result.loaded_at is None assert result.verified_at is None assert isinstance(result.loading_error, ValueError) is True + + +class MyPydanticClass(BaseModel): + foo: str + bar: int + + +class MyLoader(PydanticJsonFileEntityLoader[MyPydanticClass]): + """entity loader fo my pydantic class""" + + +class TestJsonFileEntityLoader: + async def test_dumping_to_file(self, tmp_path): + my_entities = [MyPydanticClass(foo="asd", bar=123), MyPydanticClass(foo="qwe", bar=456)] + file_path = Path(tmp_path) / Path("foo.json") + my_loader = MyLoader(file_path) + await my_loader.load_entities(my_entities) + del my_loader + with open(file_path, "r", encoding="utf-8") as infile: + json_body = json.load(infile) + assert len(json_body) == 2 + assert json_body == [{"foo": "asd", "bar": 123}, {"foo": "qwe", "bar": 456}]