Skip to content

Commit

Permalink
fix: missing test and small bug
Browse files Browse the repository at this point in the history
  • Loading branch information
madtoinou committed Dec 20, 2024
1 parent 69fbc47 commit 911b8c5
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2092,6 +2092,44 @@ def test_lagged_training_data_invalid_output_chunk_length_error(self):
)
assert "`output_chunk_length` must be a positive `int`." == str(err.value)

def test_lagged_training_data_invalid_stride_error(self):
"""
Tests that `create_lagged_training_data` throws correct error
when `stride` is set to a non-`int` value (e.g. a
`float`) or a non-positive value (e.g. `0`).
"""
target = linear_timeseries(start=1, length=20, freq=1)
lags = [-1]
ocl = 2
# Check error thrown by 'moving windows' method and by 'time intersection' method:
for use_moving_windows in (False, True):
with pytest.raises(ValueError) as err:
create_lagged_training_data(
target_series=target,
output_chunk_length=ocl,
lags=lags,
uses_static_covariates=False,
use_moving_windows=use_moving_windows,
output_chunk_shift=0,
stride=-1,
)
assert "`stride` must be a positive integer greater than 0." == str(
err.value
)
with pytest.raises(ValueError) as err:
create_lagged_training_data(
target_series=target,
output_chunk_length=ocl,
lags=lags,
uses_static_covariates=False,
use_moving_windows=use_moving_windows,
output_chunk_shift=0,
stride=1.1,
)
assert "`stride` must be a positive integer greater than 0." == str(
err.value
)

def test_lagged_training_data_no_lags_specified_error(self):
"""
Tests that `create_lagged_training_data` throws correct error
Expand Down
8 changes: 4 additions & 4 deletions darts/utils/data/tabularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -1022,9 +1022,9 @@ def _create_lagged_data_by_moving_window(
raise_log(
ValueError("Must specify at least one series-lags pair."), logger=logger
)
if not isinstance(stride, int) and stride < 1:
if not (isinstance(stride, int) and stride > 0):
raise_log(
ValueError("`stride` must be a positive integer greater than 1."),
ValueError("`stride` must be a positive integer greater than 0."),
logger=logger,
)
sample_weight_vals = _extract_sample_weight(sample_weight, target_series)
Expand Down Expand Up @@ -1270,9 +1270,9 @@ def _create_lagged_data_by_intersecting_times(
raise_log(
ValueError("Must specify at least one series-lags pair."), logger=logger
)
if not isinstance(stride, int) and stride < 1:
if not (isinstance(stride, int) and stride > 0):
raise_log(
ValueError("`stride` must be a positive integer greater than 1."),
ValueError("`stride` must be a positive integer greater than 0."),
logger=logger,
)
sample_weight_vals = _extract_sample_weight(sample_weight, target_series)
Expand Down

0 comments on commit 911b8c5

Please sign in to comment.