Skip to content

Commit

Permalink
Check for NaN stack data and update WCT jobs to "T" (#387)
Browse files Browse the repository at this point in the history
* Update s04_stack2.py

* Update api.py

* Update s04_stack2.py: wct update

* Update tests.py: adapt test wct

* Update tests.py: test validate_stack_data

* Update tests.py

* Update tests.py: back to previous test

* Update tests.py

* Update tests.py

* Update tests.py

* Update tests.py

* Update tests.py

* Update tests.py

* Update tests.py
  • Loading branch information
LaureBrenot authored Jan 10, 2025
1 parent ffd3c7e commit 06725f4
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 26 deletions.
32 changes: 31 additions & 1 deletion msnoise/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,7 +1683,37 @@ def build_ref_datelist(session):
datelist = pd.date_range(start, end).map(lambda x: x.date())
return start, end, datelist.tolist()


def validate_stack_data(dataset, stack_type="reference"):
    """Check a stack dataset for presence and NaN content before processing.

    Parameters:
        dataset: xarray Dataset expected to carry a ``CCF`` data variable
        stack_type: label ("reference" or "moving") used in the messages

    Returns:
        (is_valid, message) tuple; ``message`` is "OK", a NaN-percentage
        warning, or an explanation of why the data is unusable.
    """
    # A missing or variable-less dataset is rejected outright.
    if dataset is None or not dataset.data_vars:
        return False, f"No data found for {stack_type} stack"

    if not hasattr(dataset, 'CCF'):
        return False, f"Missing CCF data in {stack_type} stack"

    ccf = dataset.CCF
    if ccf.size == 0:
        return False, f"Empty dataset in {stack_type} stack"

    values = ccf.values
    nan_total = np.isnan(values).sum()
    label = stack_type.capitalize()

    # All-NaN is fatal; partial NaN is allowed but flagged with a percentage.
    if nan_total == values.size:
        return False, f"{label} stack contains only NaN values"

    if nan_total > 0:
        percent_nan = (nan_total / values.size) * 100
        return True, f"Warning: {label} stack contains {percent_nan:.1f}% NaN values"

    return True, "OK"

def build_movstack_datelist(session):
"""
Creates a date array for the analyse period.
Expand Down
20 changes: 19 additions & 1 deletion msnoise/s04_stack2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
The corresponding configuration bits are ``ref_begin`` and ``ref_end``. In the
future, we plan on allowing multiple references to be defined.
Each reference and moving stack is validated to ensure data presence and
completeness, checking for empty datasets and NaN content before processing.
Only data for new/modified dates need to be exported. If any CC-job has been
marked "Done" within the last day and triggered the creation of STACK jobs,
the stacks will be calculated and a new MWCS job will be inserted in the
Expand Down Expand Up @@ -89,7 +92,7 @@
Once done, each job is marked "D"one in the database and, unless ``hpc`` is
``Y``, MWCS jobs are inserted/updated in the database.
``Y``, MWCS and WCT jobs are inserted/updated in the database.
Usage:
~~~~~~
Expand Down Expand Up @@ -216,6 +219,13 @@ def main(stype, interval=1.0, loglevel="INFO"):
else:
c = get_results(db, sta1, sta2, filterid, components, datelist, mov_stack=1, format="xarray", params=params)

is_valid, message = validate_stack_data(c, "reference")
if not is_valid:
logger.error(f"Invalid reference data for {sta1}:{sta2}-{components}-{filterid}: {message}")
continue
elif "Warning" in message:
logger.warning(f"{sta1}:{sta2}-{components}-{filterid}: {message}")

# dr = xr_save_ccf(sta1, sta2, components, filterid, 1, taxis, c)
dr = c
if not c.data_vars:
Expand Down Expand Up @@ -328,6 +338,13 @@ def main(stype, interval=1.0, loglevel="INFO"):
params=params).sortby('times')
dr = c.resample(times="1D").mean()

is_valid, message = validate_stack_data(c, "moving")
if not is_valid:
logger.error(f"Invalid moving stack data for {sta1}:{sta2}-{components}-{filterid}: {message}")
continue
elif "Warning" in message:
logger.warning(f"{sta1}:{sta2}-{components}-{filterid}: {message}")

if wienerfilt:
dr = wiener_filt(dr, wiener_M, wiener_N, wiener_gap_threshold)

Expand Down Expand Up @@ -368,3 +385,4 @@ def main(stype, interval=1.0, loglevel="INFO"):
if stype != "step" and not params.hpc:
for job in jobs:
update_job(db, job.day, job.pair, 'MWCS', 'T')
update_job(db, job.day, job.pair, 'WCT', 'T')
120 changes: 96 additions & 24 deletions msnoise/test/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,31 +443,103 @@ def test_031_instrument_response(setup_environment):
@pytest.mark.order(32)
def test_032_wct():
    """Seed WCT jobs from the finished STACK jobs, then run the WCT step."""
    from ..s08compute_wct import main as compute_wct_main
    session = connect()
    dbini = read_db_inifile()
    # Table names are optionally prefixed per the db.ini configuration.
    prefix = (dbini.prefix + '_') if dbini.prefix != '' else ''
    insert_sql = (
        f"INSERT INTO {prefix}jobs (pair, day, jobtype, flag) "
        f"SELECT pair, day, 'WCT', 'T' FROM {prefix}jobs "
        f"WHERE jobtype='STACK' AND flag='D';"
    )
    session.execute(text(insert_sql))
    session.commit()
    compute_wct_main()
    session.close()

@pytest.mark.order(100)
def test_100_plot_cctfime():
    """Render a ccftime plot for every station pair / filter and check that
    the expected PNG file was written."""
    from ..plots.ccftime import main as ccftime_main
    db = connect()
    for sta1, sta2 in get_station_pairs(db):
        for loc1 in sta1.locs():
            # The station ids are loop-invariant below: build each once
            # instead of re-formatting them in the innermost loop.
            id1 = "%s.%s.%s" % (sta1.net, sta1.sta, loc1)
            for loc2 in sta2.locs():
                id2 = "%s.%s.%s" % (sta2.net, sta2.sta, loc2)
                # `filt` rather than `filter`: avoid shadowing the builtin.
                for filt in get_filters(db):
                    ccftime_main(id1, id2, filt.ref, "ZZ", 1, show=False, outfile="?.png")
                    fn = 'ccftime %s-%s-%s-f%i-m%s_%s.png' % (id1, id2,
                                                              "ZZ", filt.ref, "1d", "1d")
                    assert os.path.isfile(fn)


@pytest.mark.order(33)
def test_033_validate_stack_data():
    """Exercise every branch of api.validate_stack_data."""
    from ..api import validate_stack_data
    import xarray as xr
    import numpy as np
    import pandas as pd

    taxis = np.linspace(-50, 50, 100)

    def make_ccf_dataset(values, dates):
        # Wrap a 2-D array into a Dataset carrying a 'CCF' variable.
        arr = xr.DataArray(values, coords=[dates, taxis], dims=['times', 'taxis'])
        return arr.to_dataset(name='CCF')

    # Dataset with no variables at all.
    is_valid, message = validate_stack_data(xr.Dataset(), "reference")
    assert not is_valid
    assert "No data found for reference stack" in message

    # Dataset that carries a variable, but not CCF.
    is_valid, message = validate_stack_data(xr.Dataset({"wrong_var": 1}), "reference")
    assert not is_valid
    assert "Missing CCF data in reference stack" in message

    # CCF present but holding zero samples.
    empty_times = pd.date_range('2020-01-01', periods=0)
    ds = make_ccf_dataset(np.random.random((0, len(taxis))), empty_times)
    is_valid, message = validate_stack_data(ds, "reference")
    assert not is_valid
    assert "Empty dataset in reference stack" in message

    times = pd.date_range('2020-01-01', periods=10)

    # Every value NaN -> rejected.
    ds = make_ccf_dataset(np.full((len(times), len(taxis)), np.nan), times)
    is_valid, message = validate_stack_data(ds, "reference")
    assert not is_valid
    assert "Reference stack contains only NaN values" in message

    # Half the rows NaN -> still valid, but with a percentage warning.
    data = np.random.random((len(times), len(taxis)))
    data[0:5, :] = np.nan
    is_valid, message = validate_stack_data(make_ccf_dataset(data, times), "reference")
    assert is_valid
    # We can still check if the warning message contains the percentage
    assert "50.0% NaN values" in message

    # Fully finite data passes cleanly.
    data = np.random.random((len(times), len(taxis)))
    is_valid, message = validate_stack_data(make_ccf_dataset(data, times), "reference")
    assert is_valid
    assert message == "OK"

@pytest.mark.order(34)
def test_034_stack_validation_handling():
    """Mimic the stack step's validate/log/skip flow on synthetic datasets."""
    from ..api import validate_stack_data
    import xarray as xr
    import numpy as np
    import pandas as pd

    # Minimal axes for the synthetic CCF data.
    times = pd.date_range('2020-01-01', periods=10)
    taxis = np.linspace(-50, 50, 100)

    # Variables mirroring those used by the real stacking code.
    pairs = [('STA1', 'STA2'), ('STA3', 'STA3')]
    filters = [type('Filter', (), {'ref': '1'})]
    components = 'ZZ'

    # NOTE(review): `logger` must exist at module scope for these log calls
    # to work — confirm tests.py defines it.
    for sta1, sta2 in pairs:
        for filt in filters:
            filterid = int(filt.ref)

            if sta1 == 'STA1':
                # Empty dataset: expected to take the error/skip path.
                dataset = xr.Dataset()
                is_valid, message = validate_stack_data(dataset, "reference")
                if not is_valid:
                    logger.error(f"Invalid reference data for {sta1}:{sta2}-{components}-{filterid}: {message}")
                    continue
            else:
                # Partially-NaN dataset: expected to take the warning path.
                values = np.random.random((len(times), len(taxis)))
                values[0:5, :] = np.nan
                arr = xr.DataArray(values, coords=[times, taxis], dims=['times', 'taxis'])
                dataset = arr.to_dataset(name='CCF')

            is_valid, message = validate_stack_data(dataset, "reference")
            if "Warning" in message:
                logger.warning(f"{sta1}:{sta2}-{components}-{filterid}: {message}")

@pytest.mark.order(100)
def test_100_plot_interferogram():
db = connect()
Expand Down

0 comments on commit 06725f4

Please sign in to comment.