Skip to content

Commit

Permalink
Change "multivariate_plot" parameter name to "component_wise"
Browse files Browse the repository at this point in the history
  • Loading branch information
cnhwl committed Jan 6, 2025
1 parent c6b799e commit 78baeec
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 23 deletions.
6 changes: 3 additions & 3 deletions darts/ad/anomaly_model/anomaly_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def show_anomalies(
names_of_scorers: Union[str, Sequence[str]] = None,
title: str = None,
metric: Optional[Literal["AUC_ROC", "AUC_PR"]] = None,
multivariate_plot: bool = False,
component_wise: bool = False,
**score_kwargs,
):
"""Plot the results of the anomaly model.
Expand Down Expand Up @@ -284,7 +284,7 @@ def show_anomalies(
Default: "AUC_ROC".
score_kwargs
parameters for the `score()` method.
multivariate_plot
component_wise
If True, it will separately plot each component in multivariate series.
"""
series = _check_input(series, name="series", num_series_expected=1)[0]
Expand Down Expand Up @@ -313,7 +313,7 @@ def show_anomalies(
names_of_scorers=names_of_scorers,
title=title,
metric=metric,
multivariate_plot=multivariate_plot,
component_wise=component_wise,
)

@property
Expand Down
6 changes: 3 additions & 3 deletions darts/ad/anomaly_model/forecasting_am.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ def show_anomalies(
names_of_scorers: Union[str, Sequence[str]] = None,
title: str = None,
metric: Optional[Literal["AUC_ROC", "AUC_PR"]] = None,
multivariate_plot: bool = False,
component_wise: bool = False,
**score_kwargs,
):
"""Plot the results of the anomaly model.
Expand Down Expand Up @@ -507,7 +507,7 @@ def show_anomalies(
Optionally, the name of the metric function to use. Must be one of "AUC_ROC" (Area Under the
Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores).
Default: "AUC_ROC".
multivariate_plot
component_wise
If True, it will separately plot each component in multivariate series.
score_kwargs
parameters for the `score()` method.
Expand All @@ -530,7 +530,7 @@ def show_anomalies(
names_of_scorers=names_of_scorers,
title=title,
metric=metric,
multivariate_plot=multivariate_plot,
component_wise=component_wise,
**score_kwargs,
)

Expand Down
12 changes: 6 additions & 6 deletions darts/ad/scorers/scorers.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def show_anomalies_from_prediction(
anomalies: TimeSeries = None,
title: str = None,
metric: Optional[Literal["AUC_ROC", "AUC_PR"]] = None,
multivariate_plot: bool = False,
component_wise: bool = False,
):
"""Plot the results of the scorer.
Expand Down Expand Up @@ -209,7 +209,7 @@ def show_anomalies_from_prediction(
Optionally, the name of the metric function to use. Must be one of "AUC_ROC" (Area Under the
Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores).
Default: "AUC_ROC".
multivariate_plot
component_wise
If True, it will separately plot each component in multivariate series.
"""
series = _check_input(series, name="series", num_series_expected=1)[0]
Expand All @@ -233,7 +233,7 @@ def show_anomalies_from_prediction(
names_of_scorers=scorer_name,
title=title,
metric=metric,
multivariate_plot=multivariate_plot,
component_wise=component_wise,
)

@property
Expand Down Expand Up @@ -584,7 +584,7 @@ def show_anomalies(
scorer_name: str = None,
title: str = None,
metric: Optional[Literal["AUC_ROC", "AUC_PR"]] = None,
multivariate_plot: bool = False,
component_wise: bool = False,
):
"""Plot the results of the scorer.
Expand Down Expand Up @@ -615,7 +615,7 @@ def show_anomalies(
Optionally, the name of the metric function to use. Must be one of "AUC_ROC" (Area Under the
Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores).
Default: "AUC_ROC".
multivariate_plot
component_wise
If True, it will separately plot each component in multivariate series.
"""
series = _check_input(series, name="series", num_series_expected=1)[0]
Expand All @@ -640,7 +640,7 @@ def show_anomalies(
names_of_scorers=scorer_name,
title=title,
metric=metric,
multivariate_plot=multivariate_plot,
component_wise=component_wise,
)

@property
Expand Down
22 changes: 11 additions & 11 deletions darts/ad/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def show_anomalies_from_scores(
names_of_scorers: Union[str, Sequence[str]] = None,
title: str = None,
metric: Optional[Literal["AUC_ROC", "AUC_PR"]] = None,
multivariate_plot: bool = False,
component_wise: bool = False,
):
"""Plot the results generated by an anomaly model.
Expand Down Expand Up @@ -353,14 +353,14 @@ def show_anomalies_from_scores(
Receiver Operating Characteristic Curve) and "AUC_PR" (Average Precision from scores).
Only effective when `pred_scores` is not `None`.
Default: "AUC_ROC".
multivariate_plot
component_wise
If True, it will separately plot each component in multivariate series.
"""
series = _check_input(
series,
name="series",
num_series_expected=1,
check_multivariate=multivariate_plot,
check_multivariate=component_wise,
)[0]

if title is None and pred_scores is not None:
Expand Down Expand Up @@ -433,30 +433,30 @@ def show_anomalies_from_scores(
name="pred_series",
width_expected=series_width,
num_series_expected=1,
check_multivariate=multivariate_plot,
check_multivariate=component_wise,
)[0]

if anomalies is not None and multivariate_plot:
if anomalies is not None and component_wise:
anomalies = _check_input(
anomalies,
name="anomalies",
width_expected=series_width,
num_series_expected=1,
check_binary=True,
check_multivariate=multivariate_plot,
check_multivariate=component_wise,
)[0]

if pred_scores is not None and multivariate_plot:
if pred_scores is not None and component_wise:
for pred_score in pred_scores:
_ = _check_input(
pred_score,
name="pred_score",
width_expected=series_width,
num_series_expected=1,
check_multivariate=multivariate_plot,
check_multivariate=component_wise,
)[0]

plots_per_ts = nbr_plots * series_width if multivariate_plot else nbr_plots
plots_per_ts = nbr_plots * series_width if component_wise else nbr_plots
fig, axs = plt.subplots(
plots_per_ts,
figsize=(8, 4 * (plots_per_ts // nbr_plots) + 2 * (nbr_plots - 1)),
Expand All @@ -468,8 +468,8 @@ def show_anomalies_from_scores(
layout="constrained",
)

for i in range(series_width if multivariate_plot else 1):
if multivariate_plot:
for i in range(series_width if component_wise else 1):
if component_wise:
series_ = series[series.components[i]]
anomalies_ = (
anomalies[anomalies.components[i]] if anomalies is not None else None
Expand Down

0 comments on commit 78baeec

Please sign in to comment.