Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT] stop segmenters changing state in predict #2526

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions aeon/segmentation/_fluss.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class FLUSSSegmenter(BaseSegmenter):
"""

_tags = {
"fit_is_empty": True,
"fit_is_empty": False,
"python_dependencies": "stumpy",
}

Expand All @@ -54,6 +54,15 @@ def __init__(self, period_length=10, n_regimes=2, exclusion_factor=5):
self.exclusion_factor = exclusion_factor
super().__init__(n_segments=n_regimes, axis=1)

def _fit(self, X, y=None):
    """Validate ``n_regimes`` and compute the segmentation of ``X``.

    Parameters
    ----------
    X : np.ndarray
        Series to segment; it is squeezed before being passed to
        ``_run_fluss``.
    y : ignored
        Not used by this method.

    Raises
    ------
    ValueError
        If ``n_regimes`` is smaller than 2.
    """
    if self.n_regimes < 2:
        message = "The number of regimes must be set to an integer greater than 1"
        raise ValueError(message)

    series = X.squeeze()
    # Results are stored on the estimator so that predict can return them
    # without modifying state itself.
    change_points, profiles, scores = self._run_fluss(series)
    self.found_cps = change_points
    self.profiles = profiles
    self.scores = scores


def _predict(self, X: np.ndarray):
"""Create annotations on test/deployment data.

Expand All @@ -67,13 +76,6 @@ def _predict(self, X: np.ndarray):
list
List of change points found in X.
"""
if self.n_regimes < 2:
raise ValueError(
"The number of regimes must be set to an integer greater than 1"
)

X = X.squeeze()
self.found_cps, self.profiles, self.scores = self._run_fluss(X)
return self.found_cps

def predict_scores(self, X):
Expand Down
6 changes: 2 additions & 4 deletions aeon/segmentation/_ggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,10 @@ def _predict(self, X):
dimension of X. The numerical values represent distinct segments
labels for each of the data points.
"""
self.change_points_ = self.ggs.find_change_points(X)
change_points_ = self.ggs.find_change_points(X)

labels = np.zeros(X.shape[0], dtype=np.int32)
for i, (start, stop) in enumerate(
zip(self.change_points_[:-1], self.change_points_[1:])
):
for i, (start, stop) in enumerate(zip(change_points_[:-1], change_points_[1:])):
labels[start:stop] = i
return labels

Expand Down
17 changes: 6 additions & 11 deletions aeon/segmentation/_igts.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class _IGTS:
"""
Information Gain based Temporal Segmentation (IGTS).

GTS is a n unsupervised method for segmenting multivariate time series
IGTS is an unsupervised method for segmenting multivariate time series
into non-overlapping segments by locating change points for which
the information gain is maximized.

Expand All @@ -127,18 +127,13 @@ class _IGTS:

Parameters
----------
k_max: int, default=10
k_max : int, default=10
Maximum number of change points to find. The number of segments is thus k+1.
step: : int, default=5
step : int, default=5
Step size, or stride for selecting candidate locations of change points.
For example, `step=5` would produce candidates [0, 5, 10, ...]. Has the same
meaning as `step` in `range` function.

Attributes
----------
intermediate_results_: list of `ChangePointResult`
Intermediate segmentation results for each k value, where k=1, 2, ..., k_max

Notes
-----
Based on the work from [1]_.
Expand Down Expand Up @@ -366,9 +361,9 @@ def _predict(self, X, y=None) -> np.ndarray:
The numerical values represent distinct segment labels for each of the
data points.
"""
self.change_points_ = self._igts.find_change_points(X)
self.intermediate_results_ = self._igts.intermediate_results_
return self.to_clusters(self.change_points_[1:-1], X.shape[0])
change_points_ = self._igts.find_change_points(X)
# self.intermediate_results_ = self._igts.intermediate_results_
return self.to_clusters(change_points_[1:-1], X.shape[0])

def __repr__(self) -> str:
"""Return a string representation of the estimator."""
Expand Down
2 changes: 0 additions & 2 deletions aeon/segmentation/tests/test_igts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,3 @@ def test_InformationGainSegmenter(multivariate_mean_shift):
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3], dtype=int
),
)
assert igts.change_points_ == [0, 5, 10, 15, 20]
assert len(igts.intermediate_results_) == 3
5 changes: 1 addition & 4 deletions aeon/testing/testing_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,9 @@
"SAST": ["check_fit_deterministic"],
"RSAST": ["check_fit_deterministic"],
# missed in legacy testing, changes state in predict/transform
"FLUSSSegmenter": ["check_non_state_changing_method"],
"InformationGainSegmenter": ["check_non_state_changing_method"],
"GreedyGaussianSegmenter": ["check_non_state_changing_method"],
"ClaSPSegmenter": ["check_non_state_changing_method"],
"HMMSegmenter": ["check_non_state_changing_method"],
"RSTSF": ["check_non_state_changing_method"],
"ClaSPSegmenter": ["check_non_state_changing_method"],
# Keeps length during predict to avoid recomputing means and std of data in fit
# if the next predict call uses the same query length parameter.
"QuerySearch": ["check_non_state_changing_method"],
Expand Down
Loading