Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Jul 26, 2024
1 parent a6b5546 commit f943385
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 15 deletions.
5 changes: 5 additions & 0 deletions examples/validation/_03_custom_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def run(self, datapoint: ECGExampleData):
# We will use a similar score function as we used in the QRS detection example.
# It returns the precision, recall and f1 score of the QRS detection for each datapoint.


def score(pipeline: MyPipeline, datapoint: ECGExampleData):
# We use the `safe_run` wrapper instead of just run. This is always a good idea.
# We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
Expand Down Expand Up @@ -151,6 +152,7 @@ def score(pipeline: MyPipeline, datapoint: ECGExampleData):
#
# Below, only the F1-score will be aggregated by the median aggregator.


def score(pipeline: MyPipeline, datapoint: ECGExampleData):
# We use the `safe_run` wrapper instead of just run. This is always a good idea.
# We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
Expand Down Expand Up @@ -235,6 +237,7 @@ def mean_and_std(vals: Sequence[float]):
# Note, that the actual aggregation is an instance of our custom class, NOT the class itself.
from tpcp.validate import Aggregator


class SingleValuePrecisionRecallF1(Aggregator[np.ndarray]):
def aggregate(self, /, values: Sequence[np.ndarray], **_) -> dict[str, float]:
print("SingleValuePrecisionRecallF1 Aggregator called")
Expand Down Expand Up @@ -281,6 +284,7 @@ def score(pipeline: MyPipeline, datapoint: ECGExampleData):
# `return_raw_scores` class variable to False for our specific usecase.
single_value_precision_recall_f1_agg_no_raw = SingleValuePrecisionRecallF1(return_raw_scores=False)


def score(pipeline: MyPipeline, datapoint: ECGExampleData):
# We use the `safe_run` wrapper instead of just run. This is always a good idea.
# We don't need to clone the pipeline here, as GridSearch will already clone the pipeline internally and `run`
Expand Down Expand Up @@ -323,6 +327,7 @@ def score(pipeline: MyPipeline, datapoint: ECGExampleData):
# Hence, we recommend to use these examples as a starting point to implement your own custom aggregators.
from typing import Callable, Union


class SingleValueAggregator(Aggregator[np.ndarray]):
def __init__(
self, func: Callable[[Sequence[np.ndarray]], Union[float, dict[str, float]]], *, return_raw_scores: bool = True
Expand Down
18 changes: 7 additions & 11 deletions tests/test_pipelines/test_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tpcp import Pipeline
from tpcp.exceptions import ScorerFailedError, ValidationError
from tpcp.validate import Scorer
from tpcp.validate._scorer import Aggregator, _passthrough_scoring, _validate_scorer, no_agg, FloatAggregator
from tpcp.validate._scorer import Aggregator, FloatAggregator, _passthrough_scoring, _validate_scorer, no_agg


class TestScorerCalls:
Expand Down Expand Up @@ -423,20 +423,24 @@ def callback(step, scores, **_):
def _return_1(x):
return 1


def _return_2(x):
return 2


def _return_3(x):
return 3


def _return_4(x):
return 4


def _return_5(x):
return 5

class TestWeirdScoringStuff:

class TestWeirdScoringStuff:
class DummyPipeline(Pipeline):
def __init__(self, values):
self.values = values
Expand All @@ -448,7 +452,6 @@ def get_value(self, dp):
_funcs = [_return_1, _return_2, _return_3, _return_4, _return_5]

def test_different_config_considered_different(self):

def score_func(pipeline, data_point):
return FloatAggregator(self._funcs[pipeline.get_value(data_point)])(1)

Expand All @@ -460,9 +463,8 @@ def score_func(pipeline, data_point):
assert "Based on the first value encountered" in str(e)

def test_same_config_considered_same(self):

def score_func(pipeline, data_point):
return FloatAggregator(self._funcs[3], return_raw_scores=False)(1)
return FloatAggregator(self._funcs[3], return_raw_scores=False)(1)

scorer = Scorer(score_func)

Expand All @@ -471,7 +473,6 @@ def score_func(pipeline, data_point):
assert agg_val == 4

def test_with_multiprocessing(self):

def score_func(pipeline, data_point):
return FloatAggregator(self._funcs[3], return_raw_scores=False)(1)

Expand All @@ -480,8 +481,3 @@ def score_func(pipeline, data_point):
agg_val, _ = scorer(self.DummyPipeline(list(range(5))), DummyDataset())

assert agg_val == 4





13 changes: 11 additions & 2 deletions tpcp/validate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
"""Module for all helper methods to evaluate algorithms."""
from tpcp.validate._cross_val_helper import DatasetSplitter
from tpcp.validate._scorer import Aggregator, Scorer, mean_agg, no_agg, FloatAggregator
from tpcp.validate._scorer import Aggregator, FloatAggregator, Scorer, mean_agg, no_agg
from tpcp.validate._validate import cross_validate, validate

__all__ = ["Scorer", "no_agg", "Aggregator", "mean_agg", "cross_validate", "validate", "DatasetSplitter", "FloatAggregator"]
__all__ = [
"Scorer",
"no_agg",
"Aggregator",
"mean_agg",
"cross_validate",
"validate",
"DatasetSplitter",
"FloatAggregator",
]
3 changes: 1 addition & 2 deletions tpcp/validate/_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ def aggregate(self, /, values: Sequence[float], datapoints: Sequence[Dataset]) -
vals = self.func(values)
except TypeError as e:
raise ValidationError(
f"Applying the float aggregation function {self.func} failed. "
f"\n\n{values}"
f"Applying the float aggregation function {self.func} failed. " f"\n\n{values}"
) from e

if isinstance(vals, dict):
Expand Down

0 comments on commit f943385

Please sign in to comment.