Skip to content

Commit

Permalink
Made it possible to add default aggs to doc
Browse files Browse the repository at this point in the history
  • Loading branch information
AKuederle committed Jul 26, 2024
1 parent f943385 commit 12a9b5e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
23 changes: 21 additions & 2 deletions docs/modules/validate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,29 @@ Classes
:template: class_with_private.rst

DatasetSplitter

Scoring
-------
.. currentmodule:: tpcp.validate


.. autosummary::
:toctree: generated/validate
:template: class_with_private.rst

Scorer
Aggregator
MeanAggregator
NoAgg
FloatAggregator

.. currentmodule:: tpcp.validate

.. autosummary::
:toctree: generated/validate
:template: function.rst

mean_agg
no_agg


Functions
---------
Expand Down
22 changes: 16 additions & 6 deletions tpcp/validate/_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,6 @@ def aggregate(self, /, values: Sequence[float], datapoints: Sequence[Dataset]) -
return float(vals)


mean_agg = FloatAggregator(np.mean)
mean_agg.__doc__ = """Aggregator that calculates the mean of the values."""


class _NoAgg(Aggregator[Any]):
"""Wrapper to wrap one or multiple output values of a scorer to prevent aggregation of these values.
Expand All @@ -175,7 +171,7 @@ class _NoAgg(Aggregator[Any]):
--------
>>> def score_func(pipe, dataset):
... ...
... return {"score_val_1": score, "some_metadata": NoAgg(metadata)}
... return {"score_val_1": score, "some_metadata": no_agg(metadata)}
>>> my_scorer = Scorer(score_func)
"""
Expand All @@ -185,9 +181,23 @@ def aggregate(self, /, values: Sequence[Any], datapoints: Sequence[Dataset]) ->
return NOTHING


no_agg = _NoAgg()
# We wrap the existing aggregators in functions, so that we can properly document them.
_no_agg = _NoAgg()


def no_agg(value: Any) -> _NoAgg:
return _no_agg(value)


no_agg.__doc__ = _NoAgg.__doc__

_mean_agg = FloatAggregator(np.mean)


def mean_agg(value: float) -> FloatAggregator:
"""Calculate the mean of the values."""
return _mean_agg(value)


class Scorer(Generic[PipelineT, DatasetT, T], BaseTpcpObject):
"""A scorer to score multiple data points of a dataset and average the results.
Expand Down

0 comments on commit 12a9b5e

Please sign in to comment.