Skip to content

Commit

Permalink
feat: added fit/predict_kwargs to historical_forecasts, backtest and …
Browse files Browse the repository at this point in the history
…gridsearch
  • Loading branch information
madtoinou committed Nov 3, 2023
1 parent aa8e341 commit a797b13
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 46 deletions.
108 changes: 83 additions & 25 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,7 +589,8 @@ def historical_forecasts(
show_warnings: bool = True,
predict_likelihood_parameters: bool = False,
enable_optimization: bool = True,
num_loader_workers: int = 0,
fit_kwargs: Optional[Dict[str, Any]] = None,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
Expand Down Expand Up @@ -696,9 +697,12 @@ def historical_forecasts(
Default: ``False``
enable_optimization
Whether to use the optimized version of historical_forecasts when supported and available.
num_loader_workers
Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
for the inference/prediction dataset loaders (if any).
fit_kwargs
Additional arguments passed to the model `fit()` method, for example `max_samples_per_ts`,
`n_jobs_multiouput_wrapper` or `num_loader_workers`.
predict_kwargs
Additional arguments passed to the model `predict()` method, for example `num_samples`,
`predict_likelihood_parameters` or `num_loader_workers`.
Returns
-------
Expand Down Expand Up @@ -809,6 +813,55 @@ def retrain_func(
logger,
)

if fit_kwargs is None:
fit_kwargs = dict()
if predict_kwargs is None:
predict_kwargs = dict()

# sanity checks of the arguments directly exposed by historical_forecasts
if "predict_likelihood_parameters" not in predict_kwargs:
predict_kwargs[
"predict_likelihood_parameters"
] = predict_likelihood_parameters
elif (
predict_kwargs["predict_likelihood_parameters"]
!= predict_likelihood_parameters
):
logger.warning(
"`predict_likelihood_parameters` was provided with contradictory values, "
"retaining the value passed with `predict_kwargs`."
)
if "num_samples" not in predict_kwargs:
predict_kwargs["num_samples"] = num_samples
elif predict_kwargs["num_samples"] != num_samples:
logger.warning(
"`num_samples` was provided with contradictory values, "
"retaining the value passed with `predict_kwargs`."
)

# fit/predict_kwargs cannot be used to pass arguments used by historical_forecast logic
forbiden_args = ["series", "past_covariates", "future_covariates"]
fit_invalid_args = forbiden_args + [
"val_series",
"val_past_covariates",
"val_future_covariates",
]
fit_invalid_args = set(fit_invalid_args).intersection(set(fit_kwargs.keys()))
if len(fit_invalid_args) > 0:
raise_log(
f"`fit_kwargs` cannot contain the following parameters : {list(fit_invalid_args)}.",
logger,
)
predict_invalid_args = forbiden_args + ["n", "trainer"]
predict_invalid_args = set(predict_invalid_args).intersection(
set(predict_kwargs.keys())
)
if len(predict_invalid_args) > 0:
raise_log(
f"`predict_kwargs` cannot contain the following parameters : {list(predict_invalid_args)}.",
logger,
)

series = series2seq(series)
past_covariates = series2seq(past_covariates)
future_covariates = series2seq(future_covariates)
Expand All @@ -826,7 +879,6 @@ def retrain_func(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=num_samples,
start=start,
start_format=start_format,
forecast_horizon=forecast_horizon,
Expand All @@ -835,8 +887,7 @@ def retrain_func(
last_points_only=last_points_only,
verbose=verbose,
show_warnings=show_warnings,
predict_likelihood_parameters=predict_likelihood_parameters,
num_loader_workers=num_loader_workers,
predict_kwargs=predict_kwargs,
)

if len(series) == 1:
Expand Down Expand Up @@ -977,6 +1028,7 @@ def retrain_func(
series=train_series,
past_covariates=past_covariates_,
future_covariates=future_covariates_,
**fit_kwargs,
)
else:
# untrained model was not trained on the first trainable timestamp
Expand Down Expand Up @@ -1024,10 +1076,8 @@ def retrain_func(
series=train_series,
past_covariates=past_covariates_,
future_covariates=future_covariates_,
num_samples=num_samples,
verbose=verbose,
predict_likelihood_parameters=predict_likelihood_parameters,
num_loader_workers=num_loader_workers,
**predict_kwargs,
)
if forecast_components is None:
forecast_components = forecast.columns
Expand Down Expand Up @@ -1085,7 +1135,8 @@ def backtest(
reduction: Union[Callable[[np.ndarray], float], None] = np.mean,
verbose: bool = False,
show_warnings: bool = True,
num_loader_workers: int = 0,
fit_kwargs: Optional[Dict[str, Any]] = None,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[float, List[float], Sequence[float], List[Sequence[float]]]:
"""Compute error values that the model would have produced when
used on (potentially multiple) `series`.
Expand Down Expand Up @@ -1195,9 +1246,12 @@ def backtest(
Whether to print progress.
show_warnings
Whether to show warnings related to parameters `start`, and `train_length`.
num_loader_workers
Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
for the inference/prediction dataset loaders (if any).
fit_kwargs
Additional arguments passed to the model `fit()` method, for example `max_samples_per_ts`,
`n_jobs_multiouput_wrapper` or `num_loader_workers`.
predict_kwargs
Additional arguments passed to the model `predict()` method, for example `num_samples`,
`predict_likelihood_parameters` or `num_loader_workers`.
Returns
-------
Expand All @@ -1221,7 +1275,8 @@ def backtest(
last_points_only=last_points_only,
verbose=verbose,
show_warnings=show_warnings,
num_loader_workers=num_loader_workers,
fit_kwargs=fit_kwargs,
predict_kwargs=predict_kwargs,
)
else:
forecasts = historical_forecasts
Expand Down Expand Up @@ -1275,7 +1330,7 @@ def gridsearch(
verbose=False,
n_jobs: int = 1,
n_random_samples: Optional[Union[int, float]] = None,
num_loader_workers: int = 0,
predict_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple["ForecastingModel", Dict[str, Any], float]:
"""
Find the best hyper-parameters among a given set using a grid search.
Expand Down Expand Up @@ -1389,9 +1444,9 @@ def gridsearch(
must be between `0` and the total number of parameter combinations.
If a float, `n_random_samples` is the ratio of parameter combinations selected from the full grid and must
be between `0` and `1`. Defaults to `None`, for which random selection will be ignored.
num_loader_workers
Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances,
for the inference/prediction dataset loaders (if any).
predict_kwargs
Additional arguments passed to the model `predict()` method, for example `predict_likelihood_parameters` or
`num_loader_workers`.
Returns
-------
Expand Down Expand Up @@ -1424,10 +1479,13 @@ def gridsearch(
logger,
)

# TODO: here too I'd say we can leave these checks to the models
# if covariates is not None:
# raise_if_not(series.has_same_time_as(covariates), 'The provided series and covariates must have the '
# 'same time axes.')
if predict_kwargs is None:
predict_kwargs = dict()
raise_if(
"num_samples" in predict_kwargs,
"`num_samples = 1` cannot be modified using `predict_kwargs`.",
logger,
)

# compute all hyperparameter combinations from selection
params_cross_product = list(product(*parameters.values()))
Expand Down Expand Up @@ -1475,7 +1533,7 @@ def _evaluate_combination(param_combination) -> float:
last_points_only=last_points_only,
verbose=verbose,
show_warnings=show_warnings,
num_loader_workers=num_loader_workers,
predict_kwargs=predict_kwargs,
)
else: # split mode
model._fit_wrapper(series, past_covariates, future_covariates)
Expand All @@ -1486,7 +1544,7 @@ def _evaluate_combination(param_combination) -> float:
future_covariates,
num_samples=1,
verbose=verbose,
num_loader_workers=num_loader_workers,
**predict_kwargs,
)
error = metric(val_series, pred)

Expand Down
10 changes: 3 additions & 7 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,6 @@ def _optimized_historical_forecasts(
series: Optional[Sequence[TimeSeries]],
past_covariates: Optional[Sequence[TimeSeries]] = None,
future_covariates: Optional[Sequence[TimeSeries]] = None,
num_samples: int = 1,
start: Optional[Union[pd.Timestamp, float, int]] = None,
start_format: Literal["position", "value"] = "value",
forecast_horizon: int = 1,
Expand All @@ -1102,8 +1101,7 @@ def _optimized_historical_forecasts(
last_points_only: bool = True,
verbose: bool = False,
show_warnings: bool = True,
predict_likelihood_parameters: bool = False,
num_loader_workers: int = 0,
predict_kwargs: Dict[str, Any] = {},
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
Expand Down Expand Up @@ -1131,29 +1129,27 @@ def _optimized_historical_forecasts(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=num_samples,
start=start,
start_format=start_format,
forecast_horizon=forecast_horizon,
stride=stride,
overlap_end=overlap_end,
show_warnings=show_warnings,
predict_likelihood_parameters=predict_likelihood_parameters,
**predict_kwargs,
)
else:
return _optimized_historical_forecasts_all_points(
model=self,
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=num_samples,
start=start,
start_format=start_format,
forecast_horizon=forecast_horizon,
stride=stride,
overlap_end=overlap_end,
show_warnings=show_warnings,
predict_likelihood_parameters=predict_likelihood_parameters,
**predict_kwargs,
)


Expand Down
8 changes: 2 additions & 6 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2027,7 +2027,6 @@ def _optimized_historical_forecasts(
series: Optional[Sequence[TimeSeries]],
past_covariates: Optional[Sequence[TimeSeries]] = None,
future_covariates: Optional[Sequence[TimeSeries]] = None,
num_samples: int = 1,
start: Optional[Union[pd.Timestamp, float, int]] = None,
start_format: Literal["position", "value"] = "value",
forecast_horizon: int = 1,
Expand All @@ -2036,8 +2035,7 @@ def _optimized_historical_forecasts(
last_points_only: bool = True,
verbose: bool = False,
show_warnings: bool = True,
predict_likelihood_parameters: bool = False,
num_loader_workers: int = 0,
predict_kwargs: Dict[str, Any] = {},
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
Expand All @@ -2058,17 +2056,15 @@ def _optimized_historical_forecasts(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=num_samples,
start=start,
start_format=start_format,
forecast_horizon=forecast_horizon,
stride=stride,
overlap_end=overlap_end,
last_points_only=last_points_only,
show_warnings=show_warnings,
predict_likelihood_parameters=predict_likelihood_parameters,
verbose=verbose,
num_loader_workers=num_loader_workers,
predict_kwargs=predict_kwargs,
)
return forecasts_list

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, Sequence, Union
from typing import Any, Dict, List, Optional, Sequence, Union

try:
from typing import Literal
Expand All @@ -24,17 +24,15 @@ def _optimized_historical_forecasts(
series: Sequence[TimeSeries],
past_covariates: Optional[Sequence[TimeSeries]] = None,
future_covariates: Optional[Sequence[TimeSeries]] = None,
num_samples: int = 1,
start: Optional[Union[pd.Timestamp, float, int]] = None,
start_format: Literal["position", "value"] = "value",
forecast_horizon: int = 1,
stride: int = 1,
overlap_end: bool = False,
last_points_only: bool = True,
show_warnings: bool = True,
predict_likelihood_parameters: bool = False,
verbose: bool = False,
num_loader_workers: int = 0,
predict_kwargs: Dict[str, Any] = {},
) -> Union[
TimeSeries, List[TimeSeries], Sequence[TimeSeries], Sequence[List[TimeSeries]]
]:
Expand Down Expand Up @@ -98,9 +96,7 @@ def _optimized_historical_forecasts(
series,
past_covariates,
future_covariates,
num_samples=num_samples,
predict_likelihood_parameters=predict_likelihood_parameters,
num_loader_workers=num_loader_workers,
**predict_kwargs,
)

dataset = model._build_inference_dataset(
Expand All @@ -117,7 +113,7 @@ def _optimized_historical_forecasts(
dataset,
trainer=None,
verbose=verbose,
predict_likelihood_parameters=predict_likelihood_parameters,
**predict_kwargs,
)

# torch models return list of time series in order of historical forecasts: we reorder per time series
Expand Down

0 comments on commit a797b13

Please sign in to comment.