diff --git a/darts/tests/utils/tabularization/test_create_lagged_training_data.py b/darts/tests/utils/tabularization/test_create_lagged_training_data.py index b1efa3ea0e..731b15c271 100644 --- a/darts/tests/utils/tabularization/test_create_lagged_training_data.py +++ b/darts/tests/utils/tabularization/test_create_lagged_training_data.py @@ -2092,6 +2092,44 @@ def test_lagged_training_data_invalid_output_chunk_length_error(self): ) assert "`output_chunk_length` must be a positive `int`." == str(err.value) + def test_lagged_training_data_invalid_stride_error(self): + """ + Tests that `create_lagged_training_data` throws correct error + when `stride` is set to a non-`int` value (e.g. a + `float`) or a non-positive value (e.g. `0`). + """ + target = linear_timeseries(start=1, length=20, freq=1) + lags = [-1] + ocl = 2 + # Check error thrown by 'moving windows' method and by 'time intersection' method: + for use_moving_windows in (False, True): + with pytest.raises(ValueError) as err: + create_lagged_training_data( + target_series=target, + output_chunk_length=ocl, + lags=lags, + uses_static_covariates=False, + use_moving_windows=use_moving_windows, + output_chunk_shift=0, + stride=-1, + ) + assert "`stride` must be a positive integer greater than 0." == str( + err.value + ) + with pytest.raises(ValueError) as err: + create_lagged_training_data( + target_series=target, + output_chunk_length=ocl, + lags=lags, + uses_static_covariates=False, + use_moving_windows=use_moving_windows, + output_chunk_shift=0, + stride=1.1, + ) + assert "`stride` must be a positive integer greater than 0." == str( + err.value + ) + def test_lagged_training_data_no_lags_specified_error(self): """ Tests that `create_lagged_training_data` throws correct error diff --git a/darts/utils/data/tabularization.py b/darts/utils/data/tabularization.py index 26f57bc433..5e029fcdc3 100644 --- a/darts/utils/data/tabularization.py +++ b/darts/utils/data/tabularization.py @@ -1022,9 +1022,9 @@ def _create_lagged_data_by_moving_window( raise_log( ValueError("Must specify at least one series-lags pair."), logger=logger ) - if not isinstance(stride, int) and stride < 1: + if not (isinstance(stride, int) and stride > 0): raise_log( - ValueError("`stride` must be a positive integer greater than 1."), + ValueError("`stride` must be a positive integer greater than 0."), logger=logger, ) sample_weight_vals = _extract_sample_weight(sample_weight, target_series) @@ -1270,9 +1270,9 @@ def _create_lagged_data_by_intersecting_times( raise_log( ValueError("Must specify at least one series-lags pair."), logger=logger ) - if not isinstance(stride, int) and stride < 1: + if not (isinstance(stride, int) and stride > 0): raise_log( - ValueError("`stride` must be a positive integer greater than 1."), + ValueError("`stride` must be a positive integer greater than 0."), logger=logger, ) sample_weight_vals = _extract_sample_weight(sample_weight, target_series)