From 12a9b5e640273bfe3a1334ef577a83186220ffd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arne=20K=C3=BCderle?= Date: Fri, 26 Jul 2024 18:54:25 +0200 Subject: [PATCH] Made it possible to add default aggs to doc --- docs/modules/validate.rst | 23 +++++++++++++++++++++-- tpcp/validate/_scorer.py | 22 ++++++++++++++++------ 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/docs/modules/validate.rst b/docs/modules/validate.rst index 60a466a..6bec348 100644 --- a/docs/modules/validate.rst +++ b/docs/modules/validate.rst @@ -15,10 +15,29 @@ Classes :template: class_with_private.rst DatasetSplitter + +Scoring +------- +.. currentmodule:: tpcp.validate + + +.. autosummary:: + :toctree: generated/validate + :template: class_with_private.rst + Scorer Aggregator - MeanAggregator - NoAgg + FloatAggregator + +.. currentmodule:: tpcp.validate + +.. autosummary:: + :toctree: generated/validate + :template: function.rst + + mean_agg + no_agg + Functions --------- diff --git a/tpcp/validate/_scorer.py b/tpcp/validate/_scorer.py index dc36bcd..cfdd8e1 100644 --- a/tpcp/validate/_scorer.py +++ b/tpcp/validate/_scorer.py @@ -158,10 +158,6 @@ def aggregate(self, /, values: Sequence[float], datapoints: Sequence[Dataset]) - return float(vals) -mean_agg = FloatAggregator(np.mean) -mean_agg.__doc__ = """Aggregator that calculates the mean of the values.""" - - class _NoAgg(Aggregator[Any]): """Wrapper to wrap one or multiple output values of a scorer to prevent aggregation of these values. @@ -175,7 +171,7 @@ class _NoAgg(Aggregator[Any]): -------- >>> def score_func(pipe, dataset): ... ... - ... return {"score_val_1": score, "some_metadata": NoAgg(metadata)} + ... return {"score_val_1": score, "some_metadata": no_agg(metadata)} >>> my_scorer = Scorer(score_func) """ @@ -185,9 +181,23 @@ def aggregate(self, /, values: Sequence[Any], datapoints: Sequence[Dataset]) -> return NOTHING -no_agg = _NoAgg() +# We wrap the existing aggregators in functions, so that we can properly document them. +_no_agg = _NoAgg() + + +def no_agg(value: Any) -> _NoAgg: + return _no_agg(value) + + no_agg.__doc__ = _NoAgg.__doc__ +_mean_agg = FloatAggregator(np.mean) + + +def mean_agg(value: float) -> FloatAggregator: + """Calculate the mean of the values.""" + return _mean_agg(value) + class Scorer(Generic[PipelineT, DatasetT, T], BaseTpcpObject): """A scorer to score multiple data points of a dataset and average the results.