🤖 Implement Chunked Processing and Retry Logic for Cleanup Tasks #1618

Draft · wants to merge 1 commit into base: main
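This draft wraps the cleanup task in a retry decorator and commits after every fixed-size chunk instead of once at the end. Below is a minimal sketch of the intended control flow, assuming the helpers added in `src/seer/anomaly_detection/utils.py` further down; `process_chunk` is a hypothetical stand-in for the real per-chunk work (deleting old points, recomputing matrix profiles), not a function in this PR.

```python
from seer.anomaly_detection.utils import chunks, safe_commit, with_retry
from seer.db import Session

CHUNK_SIZE = 100  # mirrors the constant added in tasks.py


def process_chunk(chunk):
    # Hypothetical stand-in for the real per-chunk work.
    return len(chunk)


@with_retry(max_retries=3)  # re-runs the whole task on a transient OperationalError
def chunked_cleanup(records):
    processed = 0
    with Session() as session:
        for chunk in chunks(records, CHUNK_SIZE):
            processed += process_chunk(chunk)
            # Commit per chunk so a transient failure only loses the in-flight chunk.
            safe_commit(session)
    return processed
```

Committing per chunk keeps each transaction small; if a commit fails, `safe_commit` rolls back and re-raises, and `with_retry` restarts the task with earlier chunks already persisted.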
95 changes: 52 additions & 43 deletions src/seer/anomaly_detection/tasks.py
@@ -12,11 +12,14 @@
from seer.anomaly_detection.models.external import AnomalyDetectionConfig
from seer.db import DbDynamicAlert, DbDynamicAlertTimeSeries, Session, TaskStatus
from seer.dependency_injection import inject, injected
from seer.anomaly_detection.utils import chunks, with_retry, safe_commit

logger = logging.getLogger(__name__)

CHUNK_SIZE = 100 # Process data in chunks of 100 records

@celery_app.task
@with_retry(max_retries=3)
@sentry_sdk.trace
def cleanup_timeseries(alert_id: int, date_threshold: float):
span = sentry_sdk.get_current_span()
@@ -45,23 +48,27 @@ def cleanup_timeseries(alert_id: int, date_threshold: float):
direction=alert.config["direction"],
expected_seasonality=alert.config["expected_seasonality"],
)
deleted_timeseries_points = delete_old_timeseries_points(alert, date_threshold)

# Delete old points in chunks
deleted_timeseries_points = 0
for chunk in chunks(list(alert.timeseries), CHUNK_SIZE):
deleted_count = delete_old_timeseries_points(chunk, date_threshold)
deleted_timeseries_points += deleted_count
safe_commit(session)

if len(alert.timeseries) > 0:
updated_timeseries_points = update_matrix_profiles(alert, config)
total_updated = 0
for chunk in chunks(list(alert.timeseries), CHUNK_SIZE):
updated_count = update_matrix_profiles_chunk(alert, chunk, config)
total_updated += updated_count
safe_commit(session)
else:
# Reset the window size to 0 if there are no timeseries points left
alert.anomaly_algo_data = {"window_size": 0}
logger.warn(f"Alert with id {alert_id} has empty timeseries data after pruning")
updated_timeseries_points = 0
session.commit()
logger.info(f"Deleted {deleted_timeseries_points} timeseries points")
logger.info(
f"Updated matrix profiles for {updated_timeseries_points} points in alertd id {alert_id}"
)

toggle_data_purge_flag(alert_id)



def delete_old_timeseries_points(alert: DbDynamicAlert, date_threshold: float):
deleted_count = 0
to_remove = []
@@ -73,41 +80,43 @@ def delete_old_timeseries_points(alert: DbDynamicAlert, date_threshold: float):
deleted_count += 1
return deleted_count

@inject
def update_matrix_profiles_chunk(
alert: DbDynamicAlert,
chunk: List[DbDynamicAlertTimeSeries],
anomaly_detection_config: AnomalyDetectionConfig,
algo_config: AlgoConfig = injected,
):
"""Process a chunk of timeseries data for matrix profile updates."""
try:
timeseries = TimeSeries(
timestamps=np.array([ts.timestamp.timestamp() for ts in chunk]),
values=np.array([ts.value for ts in chunk])
)

@inject
def update_matrix_profiles(
alert: DbDynamicAlert,
anomaly_detection_config: AnomalyDetectionConfig,
algo_config: AlgoConfig = injected,
):

timeseries = TimeSeries(
timestamps=np.array([timestep.timestamp.timestamp() for timestep in alert.timeseries]),
values=np.array([timestep.value for timestep in alert.timeseries]),
)

anomalies_suss = MPBatchAnomalyDetector()._compute_matrix_profile(
timeseries=timeseries, ad_config=anomaly_detection_config, algo_config=algo_config
)
anomalies_fixed = MPBatchAnomalyDetector()._compute_matrix_profile(
timeseries=timeseries,
ad_config=anomaly_detection_config,
algo_config=algo_config,
window_size=algo_config.mp_fixed_window_size,
)
anomalies = DbAlertDataAccessor().combine_anomalies(
anomalies_suss, anomalies_fixed, [True] * len(timeseries.timestamps)
)
detector = MPBatchAnomalyDetector()
anomalies_suss = detector._compute_matrix_profile(
timeseries=timeseries, ad_config=anomaly_detection_config, algo_config=algo_config
)
anomalies_fixed = detector._compute_matrix_profile(
timeseries=timeseries,
ad_config=anomaly_detection_config,
algo_config=algo_config,
window_size=algo_config.mp_fixed_window_size,
)

anomalies = DbAlertDataAccessor().combine_anomalies(
anomalies_suss, anomalies_fixed, [True] * len(timeseries.timestamps)
)

algo_data_map = dict(
zip(timeseries.timestamps, anomalies.get_anomaly_algo_data(len(timeseries.timestamps)))
)
updateed_timeseries_points = 0
for timestep in alert.timeseries:
timestep.anomaly_algo_data = algo_data_map[timestep.timestamp.timestamp()]
updateed_timeseries_points += 1
alert.anomaly_algo_data = {"window_size": anomalies.window_size}
return updateed_timeseries_points
algo_data = anomalies.get_anomaly_algo_data(len(timeseries.timestamps))
for ts, data in zip(chunk, algo_data):
ts.anomaly_algo_data = data
alert.anomaly_algo_data = {"window_size": anomalies.window_size}
return len(chunk)
except Exception as e:
logger.error(f"Error processing chunk: {str(e)}")
raise


def toggle_data_purge_flag(alert_id: int):
38 changes: 38 additions & 0 deletions src/seer/anomaly_detection/utils.py
@@ -0,0 +1,38 @@
import time
from functools import wraps
from typing import Iterator, List, TypeVar
from sqlalchemy.exc import OperationalError

T = TypeVar('T')

def chunks(lst: List[T], n: int) -> Iterator[List[T]]:
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]

def with_retry(max_retries: int = 3, backoff_base: int = 2):
"""Decorator that implements retry logic with exponential backoff."""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
last_exception = None
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except OperationalError as e:
last_exception = e
if attempt < max_retries - 1:
sleep_time = backoff_base ** attempt
time.sleep(sleep_time)
continue
raise last_exception
return wrapper
return decorator

def safe_commit(session):
"""Safely commit database changes with proper error handling."""
try:
session.commit()
except Exception:
session.rollback()
raise
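For illustration, here is a self-contained run of the three helpers together. `FlakySession`, `process`, and the record counts are hypothetical, used only to exercise the chunking, rollback, and retry mechanics against the code above; it assumes seer and SQLAlchemy are importable, not a real database.

```python
from sqlalchemy.exc import OperationalError

from seer.anomaly_detection.utils import chunks, safe_commit, with_retry


class FlakySession:
    """Toy session whose first commit fails, to exercise the retry path."""

    def __init__(self):
        self.commits = 0

    def commit(self):
        self.commits += 1
        if self.commits == 1:
            raise OperationalError("server closed the connection unexpectedly", None, None)

    def rollback(self):
        pass


@with_retry(max_retries=3, backoff_base=2)
def process(records, session):
    total = 0
    for chunk in chunks(records, 100):
        total += len(chunk)
        safe_commit(session)  # rolls back and re-raises on failure, triggering a retry
    return total


print(process(list(range(250)), FlakySession()))  # first attempt fails; the retry succeeds and prints 250
```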
83 changes: 82 additions & 1 deletion tests/seer/anomaly_detection/test_cleanup_tasks.py
@@ -1,18 +1,24 @@
import random
import unittest
from datetime import datetime, timedelta
from unittest.mock import patch, MagicMock

import numpy as np


from seer.anomaly_detection.accessors import DbAlertDataAccessor
from seer.anomaly_detection.detectors.anomaly_detectors import MPBatchAnomalyDetector
from seer.anomaly_detection.models import MPTimeSeriesAnomalies
from seer.anomaly_detection.models.external import AnomalyDetectionConfig, TimeSeriesPoint
from seer.anomaly_detection.models.timeseries import TimeSeries
from seer.anomaly_detection.tasks import cleanup_disabled_alerts, cleanup_timeseries
from seer.anomaly_detection.tasks import CHUNK_SIZE, cleanup_disabled_alerts, cleanup_timeseries
from sqlalchemy.exc import OperationalError
from seer.db import DbDynamicAlert, DbDynamicAlertTimeSeries, Session, TaskStatus




class TestCleanupTasks(unittest.TestCase):
def _save_alert(self, external_alert_id: int, num_old_points: int, num_new_points: int):
# Helper function to save an alert with a given number of old and new points
@@ -234,5 +240,80 @@ def test_cleanup_disabled_alerts(self):
.one_or_none()
)


assert alert is not None
assert len(alert.timeseries) == 500

def test_large_batch_update_with_connection_issues(self):
"""
Test to verify that the chunked processing handles database connection
issues gracefully when processing large batches of data.
"""
# Create a large dataset that would previously cause connection issues
NUM_OLD_POINTS = 1000
NUM_NEW_POINTS = 2000
external_alert_id, config, points, _ = self._save_alert(
999, NUM_OLD_POINTS, NUM_NEW_POINTS
)

# Verify initial data
with Session() as session:
alert = (
session.query(DbDynamicAlert)
.filter(DbDynamicAlert.external_alert_id == external_alert_id)
.one_or_none()
)
self.assertIsNotNone(alert)
self.assertEqual(len(alert.timeseries), NUM_OLD_POINTS + NUM_NEW_POINTS)

# Mock the Session commit to simulate connection issues
original_commit = Session.commit
commit_count = 0
failed_commits = set()

def mock_commit(session_self):
nonlocal commit_count
commit_count += 1

# Simulate connection issues for specific commits
if commit_count % 3 == 0 and commit_count not in failed_commits:
failed_commits.add(commit_count)
raise OperationalError(
"consuming input failed: server closed the connection unexpectedly",
None,
None
)
return original_commit(session_self)

date_threshold = (datetime.now() - timedelta(days=28)).timestamp()

# Patch the commit method and run the cleanup
with patch.object(Session, 'commit', mock_commit):
# Should not raise any exceptions due to retry mechanism
cleanup_timeseries(external_alert_id, date_threshold)

# Verify the results
with Session() as session:
alert = (
session.query(DbDynamicAlert)
.filter(DbDynamicAlert.external_alert_id == external_alert_id)
.one_or_none()
)
self.assertIsNotNone(alert)

# Verify only new points remain
self.assertEqual(len(alert.timeseries), NUM_NEW_POINTS)

# Verify all remaining points are newer than threshold
for ts in alert.timeseries:
self.assertTrue(ts.timestamp.timestamp() >= date_threshold)

# Verify all points have properly computed algo_data
for ts in alert.timeseries:
self.assertIsNotNone(ts.anomaly_algo_data)
self.assertTrue("mp_suss" in ts.anomaly_algo_data or "mp_fixed" in ts.anomaly_algo_data)

# Verify chunked processing by checking final stats
expected_chunks = -(-NUM_NEW_POINTS // CHUNK_SIZE) # Ceiling division
self.assertTrue(commit_count >= expected_chunks)
self.assertTrue(len(failed_commits) > 0) # Verify some commits actually failed
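A quick sanity check on the numbers this test leans on, assuming `CHUNK_SIZE = 100` from tasks.py and the `with_retry` defaults above (`backoff_base=2`); this is pure arithmetic, no seer imports required.

```python
CHUNK_SIZE = 100        # from tasks.py
NUM_NEW_POINTS = 2000   # from the test above

# Ceiling division: 2000 surviving points -> 20 per-chunk commits in the update loop.
expected_chunks = -(-NUM_NEW_POINTS // CHUNK_SIZE)
assert expected_chunks == 20

# Backoff schedule for with_retry(max_retries=3, backoff_base=2):
# attempt 0 fails -> sleep 2**0 = 1s, attempt 1 fails -> sleep 2**1 = 2s,
# attempt 2 fails -> the OperationalError is re-raised to the caller.
sleeps = [2 ** attempt for attempt in range(3 - 1)]
assert sleeps == [1, 2]
```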