Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/hist fc start stride #2560

Merged
merged 11 commits into from
Nov 2, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co

**Improved**

- Improvements to `ForecastingModel`: Improved `start` handling for historical forecasts, backtest, residuals, and gridsearch. If `start` is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of `stride` ahead of start. Raises a ValueError, if no valid start point exists. This guarantees that all historical forecasts are `n * stride` points away from start, and will simplify many downstream tasks. [#2560](https://github.com/unit8co/darts/issues/2560) by [Dennis Bader](https://github.com/dennisbader).

**Fixed**

- Fixed a bug when using `darts.utils.data.tabularization.create_lagged_component_names()` with target `lags=None`, that did not return any lagged target label component names. [#2576](https://github.com/unit8co/darts/pull/2576) by [Dennis Bader](https://github.com/dennisbader).
Expand Down
29 changes: 20 additions & 9 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,11 +706,12 @@ def historical_forecasts(
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
- the first trainable point (given `train_length`) otherwise

Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists.
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
shifted by `output_chunk_shift` points into the future.
Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
shifted by `output_chunk_shift` points into the future.
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
(default behavior with ``None``) and start at the first trainable/predictable point.
(default behavior with ``None``) and start at the first trainable/predictable point.
start_format
Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a
`pd.RangeIndex`.
Expand Down Expand Up @@ -1018,6 +1019,7 @@ def retrain_func(
historical_forecasts_time_index=historical_forecasts_time_index,
start=start,
start_format=start_format,
stride=stride,
show_warnings=show_warnings,
)

Expand Down Expand Up @@ -1267,9 +1269,12 @@ def backtest(
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
- the first trainable point (given `train_length`) otherwise

Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists.
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
shifted by `output_chunk_shift` points into the future.
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
(default behavior with ``None``) and start at the first trainable/predictable point.
(default behavior with ``None``) and start at the first trainable/predictable point.
start_format
Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a
`pd.RangeIndex`.
Expand Down Expand Up @@ -1628,9 +1633,12 @@ def gridsearch(
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
- the first trainable point (given `train_length`) otherwise

Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists.
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
shifted by `output_chunk_shift` points into the future.
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
(default behavior with ``None``) and start at the first trainable/predictable point.
(default behavior with ``None``) and start at the first trainable/predictable point.
start_format
Only used in expanding window mode. Defines the `start` format. Only effective when `start` is an integer
and `series` is indexed with a `pd.RangeIndex`.
Expand Down Expand Up @@ -1924,9 +1932,12 @@ def residuals(
or `retrain` is a Callable and the first trainable point is earlier than the first predictable point.
- the first trainable point (given `train_length`) otherwise

Note: Raises a ValueError if `start` yields a time outside the time index of `series`.
Note: If `start` is not within the trainable / forecastable points, uses the closest valid start point that
is a round multiple of `stride` ahead of `start`. Raises a `ValueError`, if no valid start point exists.
Note: If the model uses a shifted output (`output_chunk_shift > 0`), then the first predicted point is also
shifted by `output_chunk_shift` points into the future.
Note: If `start` is outside the possible historical forecasting times, will ignore the parameter
(default behavior with ``None``) and start at the first trainable/predictable point.
(default behavior with ``None``) and start at the first trainable/predictable point.
start_format
Defines the `start` format. Only effective when `start` is an integer and `series` is indexed with a
`pd.RangeIndex`.
Expand Down
28 changes: 19 additions & 9 deletions darts/tests/models/forecasting/test_backtesting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import logging
import random
from itertools import product

Expand Down Expand Up @@ -733,8 +734,7 @@ def test_backtest_multiple_series(self):
assert round(abs(error[0] - expected[0]), 4) == 0
assert round(abs(error[1] - expected[1]), 4) == 0

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
def test_backtest_regression(self):
def test_backtest_regression(self, caplog):
np.random.seed(4)

gaussian_series = gt(mean=2, length=50)
Expand Down Expand Up @@ -804,13 +804,26 @@ def test_backtest_regression(self):
assert score > 0.9

# Using a too small start value
with pytest.raises(ValueError):
RandomForest(lags=12).backtest(series=target, start=0, forecast_horizon=3)
warning_expected = (
"`start` position `{0}` corresponding to time `{1}` is before the first "
"predictable/trainable historical forecasting point for series at index: 0. Using the first historical "
"forecasting point `2000-01-15 00:00:00` that lies a round-multiple of `stride=1` ahead of `start`. "
"To hide these warnings, set `show_warnings=False`."
)
caplog.clear()
with caplog.at_level(logging.WARNING):
_ = RandomForest(lags=12).backtest(
series=target, start=0, forecast_horizon=3
)
assert warning_expected.format(0, target.start_time()) in caplog.text
caplog.clear()

with pytest.raises(ValueError):
RandomForest(lags=12).backtest(
with caplog.at_level(logging.WARNING):
_ = RandomForest(lags=12).backtest(
series=target, start=0.01, forecast_horizon=3
)
assert warning_expected.format(0.01, target.start_time()) in caplog.text
caplog.clear()

# Using RandomForest's start default value
score = RandomForest(lags=12, random_state=0).backtest(
Expand Down Expand Up @@ -939,7 +952,6 @@ def test_gridsearch_metric_score(self):

assert score == recalculated_score, "The metric scores should match"

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
def test_gridsearch_random_search(self):
np.random.seed(1)

Expand All @@ -958,7 +970,6 @@ def test_gridsearch_random_search(self):
assert isinstance(result[2], float)
assert min(param_range) <= result[1]["lags"] <= max(param_range)

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
def test_gridsearch_n_random_samples_bad_arguments(self):
dummy_series = get_dummy_series(ts_length=50)

Expand All @@ -981,7 +992,6 @@ def test_gridsearch_n_random_samples_bad_arguments(self):
params, dummy_series, forecast_horizon=1, n_random_samples=1.5
)

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
def test_gridsearch_n_random_samples(self):
np.random.seed(1)

Expand Down
Loading
Loading