Skip to content

Commit

Permalink
Fixed issue with null values erroring out aggregators
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Feb 1, 2024
1 parent d7df15a commit f6cd391
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) (+ the Migration Guide),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.31.1] - 2024-02-01

### Fixed

- TypedIterator now skips aggregation when no values are provided

## [0.31.0] - 2024-01-31

## Changed
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tpcp"
version = "0.31.0"
version = "0.31.1"
description = "Pipeline and Dataset helpers for complex algorithm evaluation."
authors = [
"Arne Küderle <[email protected]>",
Expand Down
20 changes: 20 additions & 0 deletions tests/test_misc/test_typed_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,23 @@ def test_not_allowed_attr_error():

with pytest.raises(ValueError):
[next(iterator.iterate(data)) for _ in range(3)]

def test_agg_with_empty():
rt = make_dataclass("ResultType", ["result_1", "result_2", "result_3"])

iterator = TypedIterator[rt](
rt, aggregations=[("result_1", lambda i, r: sum(i)), ("result_2", lambda i, r: sum(r))]
)

data = [1, 2, 3]
for i, r in iterator.iterate(data):
r.result_1 = i - 1
# We Don't set result 2 -> it will remain an empty value and should skip agg
r.result_3 = i * 3

result_obj = iterator.results_

assert isinstance(result_obj, rt)
assert iterator.result_1_ == result_obj.result_1 == 6
assert iterator.result_2_ == [iterator.NULL_VALUE, iterator.NULL_VALUE, iterator.NULL_VALUE]
assert iterator.result_3_ == result_obj.result_3 == [3, 6, 9]
2 changes: 1 addition & 1 deletion tpcp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from tpcp._pipeline import OptimizablePipeline, Pipeline

__version__ = "0.31.0"
__version__ = "0.31.1"


__all__ = [
Expand Down
2 changes: 1 addition & 1 deletion tpcp/misc/_typed_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _agg_result(self, name):
values = [getattr(r, name) for r in self.raw_results_]
# if an aggregator is defined for the specific item, we apply it
aggregations = dict(self.aggregations)
if name in aggregations:
if name in aggregations and all(v != self.NULL_VALUE for v in values):
return aggregations[name](self.inputs_, values)
return values

Expand Down

0 comments on commit f6cd391

Please sign in to comment.