make SourceDataProviders fully async (#67)
hf-kklein authored May 4, 2023
1 parent edf5e0b commit 5b67a48
Showing 5 changed files with 24 additions and 24 deletions.
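
Both SourceDataProvider methods are coroutines after this commit, so every call site has to await them. Below is a minimal sketch of the new calling convention, using the ListBasedSourceDataProvider that appears in the diff; the standalone asyncio.run wrapper is only for illustration, not part of the commit.

    import asyncio
    from typing import List

    from bomf.provider import ListBasedSourceDataProvider


    async def main() -> None:
        provider = ListBasedSourceDataProvider(["foo", "bar", "baz"], key_selector=lambda x: x)
        data: List[str] = await provider.get_data()   # was a plain synchronous call before this commit
        entry: str = await provider.get_entry("bar")  # same for single-entry lookups
        print(data, entry)


    asyncio.run(main())

The file-by-file diff follows.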
2 changes: 1 addition & 1 deletion src/bomf/filter/sourcedataproviderfilter.py
@@ -77,7 +77,7 @@ async def apply(
However, in general, you have to specify how the data can be indexed using a key_selector which is not None.
If you provide both a JsonFileSourceDataProvider AND a key_selector, the explicit key_selector will be used.
"""
- survivors: List[Candidate] = await self._filter.apply(source_data_provider.get_data())
+ survivors: List[Candidate] = await self._filter.apply(await source_data_provider.get_data())
key_selector_to_be_used: Callable[[Candidate], KeyTyp]
if key_selector is not None:
key_selector_to_be_used = key_selector
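
Since get_data() now returns a coroutine, the filter has to await it before handing the candidate list to the wrapped filter. The nested-await pattern in isolation, with hypothetical stand-ins for the concrete filter and provider objects (neither name exists in this repository):

    from typing import Any, List

    # Hypothetical helper, only to illustrate the nested awaits from the hunk above:
    # first await the provider coroutine, then await the filter applied to its result.
    async def preselect(candidate_filter: Any, source_data_provider: Any) -> List[Any]:
        return await candidate_filter.apply(await source_data_provider.get_data())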
12 changes: 6 additions & 6 deletions src/bomf/provider/__init__.py
@@ -27,14 +27,14 @@ class SourceDataProvider(ABC, Generic[SourceDataModel, KeyTyp]):
"""

@abstractmethod
- def get_data(self) -> List[SourceDataModel]:
+ async def get_data(self) -> List[SourceDataModel]:
"""
Returns all available entities from the source data model.
They will be filtered in a SourceDataModel Filter ("Preselect")
"""

@abstractmethod
- def get_entry(self, key: KeyTyp) -> SourceDataModel:
+ async def get_entry(self, key: KeyTyp) -> SourceDataModel:
"""
returns the source data model which has key as key.
raises an error if the key is unknown
@@ -66,10 +66,10 @@ def __init__(self, source_data_models: List[SourceDataModel], key_selector: Call
)
self.key_selector = key_selector

- def get_entry(self, key: KeyTyp) -> SourceDataModel:
+ async def get_entry(self, key: KeyTyp) -> SourceDataModel:
return self._models_dict[key]

- def get_data(self) -> List[SourceDataModel]:
+ async def get_data(self) -> List[SourceDataModel]:
return self._models


@@ -97,8 +97,8 @@ def __init__(
}
self.key_selector = key_selector

- def get_data(self) -> List[SourceDataModel]:
+ async def get_data(self) -> List[SourceDataModel]:
return self._source_data_models

- def get_entry(self, key: KeyTyp) -> SourceDataModel:
+ async def get_entry(self, key: KeyTyp) -> SourceDataModel:
return self._key_to_data_model_mapping[key]
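
With get_data and get_entry declared as coroutines on the ABC, subclasses can await real I/O instead of blocking. A hypothetical provider (not part of this commit) that offloads a blocking JSON read to a worker thread, assuming SourceDataProvider is importable from bomf.provider and the file has the same {"data": [{"myKey": ...}, ...]} shape as the test fixture used further down:

    import asyncio
    import json
    from pathlib import Path
    from typing import Dict, List

    from bomf.provider import SourceDataProvider


    # Hypothetical example, not part of this commit: an async provider that keeps
    # the event loop free by reading the file in a worker thread.
    class AsyncJsonDictProvider(SourceDataProvider[Dict[str, str], str]):
        def __init__(self, file_path: Path):
            self._file_path = file_path

        async def get_data(self) -> List[Dict[str, str]]:
            raw = await asyncio.to_thread(self._file_path.read_text, "utf-8")
            return json.loads(raw)["data"]

        async def get_entry(self, key: str) -> Dict[str, str]:
            for entry in await self.get_data():
                if entry["myKey"] == key:
                    return entry
            raise KeyError(key)  # contract above: unknown keys raise an error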
2 changes: 1 addition & 1 deletion unittests/test_filter.py
@@ -148,7 +148,7 @@ async def test_source_data_provider_filter(
caplog.set_level(logging.DEBUG, logger=self.__module__)
filtered_provider = await sdp_filter.apply(my_provider)
assert isinstance(filtered_provider, SourceDataProvider)
- actual = filtered_provider.get_data()
+ actual = await filtered_provider.get_data()
assert actual == survivors
assert "There are 4 candidates and 4 aggregates" in caplog.messages
assert "There are 2 filtered aggregates left" in caplog.messages
6 changes: 3 additions & 3 deletions unittests/test_migration.py
@@ -35,10 +35,10 @@ def get_id(self) -> str:


class _MySourceDataProvider(SourceDataProvider[_MySourceDataModel, _MyKeyTyp]):
- def get_entry(self, key: KeyTyp) -> _MySourceDataModel:
+ async def get_entry(self, key: KeyTyp) -> _MySourceDataModel:
raise NotImplementedError("Not relevant for the test")

- def get_data(self) -> List[_MySourceDataModel]:
+ async def get_data(self) -> List[_MySourceDataModel]:
return [
{"foo": "bar"},
{"FOO": "BAR"},
@@ -104,7 +104,7 @@ class TestMigrationStrategy:

async def test_happy_path(self):
# here's some pre-processing, you can read some data, you can create relations, whatever
- raw_data = _MySourceDataProvider().get_data()
+ raw_data = await _MySourceDataProvider().get_data()
survivors = await _MyFilter().apply(raw_data)
to_bo4e_mapper = _MyToBo4eMapper(what_ever_you_like=survivors)
strategy = MyMigrationStrategy(
26 changes: 13 additions & 13 deletions unittests/test_source_data_provider.py
@@ -12,48 +12,48 @@ class LegacyDataSystemDataProvider(SourceDataProvider):
a dummy for access to a legacy system from which we want to migrate data
"""

- def get_entry(self, key: KeyTyp) -> str:
+ async def get_entry(self, key: KeyTyp) -> str:
raise NotImplementedError("Not relevant for this test")

- def get_data(self) -> List[str]:
+ async def get_data(self) -> List[str]:
return ["foo", "bar", "baz"]


class TestSourceDataProvider:
- def test_provider(self):
+ async def test_provider(self):
# this is a pretty dumb test
provider_under_test = LegacyDataSystemDataProvider()
- assert isinstance(provider_under_test.get_data(), list)
+ assert isinstance(await provider_under_test.get_data(), list)

@pytest.mark.datafiles("./unittests/example_source_data.json")
- def test_json_file_provider(self, datafiles):
+ async def test_json_file_provider(self, datafiles):
file_path = datafiles / Path("example_source_data.json")
example_json_data_provider = JsonFileSourceDataProvider(
file_path,
data_selector=lambda d: d["data"], # type:ignore[call-overload]
key_selector=lambda d: d["myKey"], # type:ignore[index]
)
- assert example_json_data_provider.get_data() == [
+ assert await example_json_data_provider.get_data() == [
{"myKey": "hello", "asd": "fgh"},
{"myKey": "world", "qwe": "rtz"},
]
- assert example_json_data_provider.get_entry("world") == {"myKey": "world", "qwe": "rtz"}
+ assert await example_json_data_provider.get_entry("world") == {"myKey": "world", "qwe": "rtz"}
with pytest.raises(KeyError):
- _ = example_json_data_provider.get_entry("something unknown")
+ _ = await example_json_data_provider.get_entry("something unknown")


class TestListBasedSourceDataProvider:
- def test_list_based_provider(self, caplog):
+ async def test_list_based_provider(self, caplog):
caplog.set_level(logging.DEBUG, logger=ListBasedSourceDataProvider.__module__)
my_provider = ListBasedSourceDataProvider(["foo", "bar", "baz"], key_selector=lambda x: x)
- assert len(my_provider.get_data()) == 3
- assert my_provider.get_entry("bar") == "bar"
+ assert len(await my_provider.get_data()) == 3
+ assert await my_provider.get_entry("bar") == "bar"
assert "Read 3 records from ['foo', 'bar', 'baz']" in caplog.messages

- def test_list_based_provider_key_warning(self, caplog):
+ async def test_list_based_provider_key_warning(self, caplog):
caplog.set_level(logging.WARNING, logger=ListBasedSourceDataProvider.__module__)
my_provider = ListBasedSourceDataProvider(["fooy", "fooz" "bar", "baz"], key_selector=lambda x: x[0:3])
- assert len(my_provider.get_data()) == 3
+ assert len(await my_provider.get_data()) == 3
assert (
"There are 2>1 entries for the key 'foo'. You might miss entries because the key is not unique."
in caplog.messages
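
The test methods are coroutines now as well, so they need an async-aware test runner. The repository presumably relies on pytest-asyncio (or a comparable plugin); the asyncio_mode setting and the explicit marker in the sketch below are assumptions for illustration, not something this diff shows.

    # Hypothetical pytest-asyncio usage, assumed rather than shown by this commit.
    # Either enable auto mode once in pyproject.toml ...
    #
    #   [tool.pytest.ini_options]
    #   asyncio_mode = "auto"
    #
    # ... or mark each coroutine test explicitly:
    import pytest

    from bomf.provider import ListBasedSourceDataProvider


    @pytest.mark.asyncio
    async def test_list_based_provider_standalone():
        provider = ListBasedSourceDataProvider(["foo", "bar", "baz"], key_selector=lambda x: x)
        assert await provider.get_entry("baz") == "baz"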
