Skip to content

Commit

Permalink
update target timeseries minimum collection date (#253)
Browse files Browse the repository at this point in the history
* change collection_min default to tree_as_of -  90 days

* change collect_min default in oracle output

* Update src/get_target_data.py

Co-authored-by: Evan Ray <[email protected]>

---------

Co-authored-by: Evan Ray <[email protected]>
  • Loading branch information
bsweger and elray1 authored Jan 8, 2025
1 parent 4586387 commit 1920829
Showing 1 changed file with 50 additions and 27 deletions.
77 changes: 50 additions & 27 deletions src/get_target_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,6 @@ def set_sequence_as_of(ctx, param, value):
return value


def set_collection_min_date(ctx, param, value):
"""Set the collection_min_date default value to nowcast date minus 31 days."""
if value is None:
nowcast_date = ctx.params.get("nowcast_date")
value = nowcast_date + timedelta(days=-31)
value = value.replace(hour=23, minute=59, second=59, tzinfo=timezone.utc)
return value


def set_collection_max_date(ctx, param, value):
"""Set the collection_max_date default value to nowcast date plus 10 days."""
if value is None:
Expand Down Expand Up @@ -169,7 +160,7 @@ def set_collection_max_date(ctx, param, value):
type=click.DateTime(formats=["%Y-%m-%d"]),
required=False,
default=None,
callback=set_collection_min_date,
callback=normalize_date,
help="Assign clades to sequences collected on or after this UTC date (YYYY-MM-DD). Default is the nowcast date minus 31 days.",
)
@click.option(
Expand All @@ -195,8 +186,6 @@ def main(
collection_max_date: datetime,
target_data_dir: Path,
) -> tuple[Path, Path]:
"""Get clade counts and save to S3 bucket."""

# Date for retrieving sequences cannot be in the future
if sequence_as_of > datetime.now(tz=timezone.utc):
logger.info(
Expand Down Expand Up @@ -232,6 +221,9 @@ def main(
else:
tree_as_of = datetime.fromisoformat(modeled_clades["meta"]["created_at"])

if collection_min_date is None:
collection_min_date = tree_as_of - timedelta(days=90)

assignments = assign_clades(
nowcast_date,
sequence_as_of,
Expand Down Expand Up @@ -343,6 +335,12 @@ def create_target_data(

oracle_output = (
time_series_all.select(["location", "target_date", "clade", "observation"])
# for oracle output, include only sequence collection dates that are >=
# nowcast_date - 31 days
.filter(
pl.col("target_date")
>= datetime.fromisoformat(nowcast_string) - timedelta(days=31)
)
.with_columns(pl.lit(nowcast_string).alias("nowcast_date"))
.rename({"observation": "oracle_value"})
)
Expand Down Expand Up @@ -417,10 +415,10 @@ def mock_command():
result = normalize_date(ctx, param, datetime(2024, 11, 11, 11, 11, 11))
assert result == datetime(2024, 11, 11, 23, 59, 59, tzinfo=timezone.utc)

# default collection_min_date is 31 days before the nowcast date
# if collection_min_date is provided, it should be set to end of day UTC
param = Option(["--collection-min-date"])
result = set_collection_min_date(ctx, param, None)
assert result == datetime(2024, 9, 1, 23, 59, 59, tzinfo=timezone.utc)
result = normalize_date(ctx, param, datetime(2024, 11, 11, 11, 11, 11))
assert result == datetime(2024, 11, 11, 23, 59, 59, tzinfo=timezone.utc)

# default collection_max_date is 10 days after the nowcast date
param = Option(["--collection-max-date"])
Expand Down Expand Up @@ -480,18 +478,18 @@ def test_target_data():
"clade_nextstrain": ["AA", "BB", "CC", "DD", "BB"],
"count": [2, 3, 4, 5, 6],
}
test_assignments: Clade = Clade(
test_assignments = Clade(
{"tree_as_of": datetime(2024, 8, 1, 14, 30, 40)},
pl.LazyFrame(),
pl.LazyFrame(test_summary),
) # type: ignore
)
test_clade_list = ["AA", "BB", "other"]
test_min_date = datetime(2024, 11, 30, tzinfo=timezone.utc)
test_max_date = datetime(2024, 12, 4, tzinfo=timezone.utc)
time_series, oracle = create_target_data(
test_assignments,
test_clade_list,
"2024-9-11",
"2024-09-11",
"2024-12-17",
test_min_date,
test_max_date,
Expand All @@ -518,7 +516,7 @@ def test_target_data():
assert ts.get_column("target_date").min() == date(2024, 11, 30)
assert ts.get_column("target_date").max() == date(2024, 12, 4)
assert ts.get_column("observation").sum() == 20
assert ts.get_column("nowcast_date").unique().to_list() == ["2024-9-11"]
assert ts.get_column("nowcast_date").unique().to_list() == ["2024-09-11"]
assert ts.get_column("sequence_as_of").unique().to_list() == ["2024-12-17"]
assert ts.get_column("tree_as_of").unique().to_list() == ["2024-08-01"]

Expand All @@ -541,7 +539,8 @@ def test_target_data():
def test_target_data_integration(caplog, tmp_path):
"""
If the modeled-clades file doesn't have meta.created_at, tree_as_of should default to
nowcast_date - two days.
nowcast_date - two days. Additionally, when collection_min_date isn't provided,
it should default to tree_as_of - 90 days.
"""
caplog.set_level(logging.INFO)

Expand All @@ -563,15 +562,25 @@ def test_target_data_integration(caplog, tmp_path):
ts = pl.read_parquet(result.return_value[0])

# sequence date should default to nowcast_date + 90 days
assert "sequence_as_of=2024-12-10" in caplog.text
assert "sequence_as_of=2024-12-10" in caplog.text.lower()
# tasks.json for 2024-09-11 doesn't have a meta.created_at field, so tree_as_of = nowcast_date - 2 days
assert "tree_as_of=2024-09-09" in caplog.text
assert "tree_as_of=2024-09-09" in caplog.text.lower()

# time series target dates should be limited to dates that match collection date min/max options
# (2024-08-11 to 2024-09-21 is 43 days, but both have 11:59:59 timestamps, so we'd expect
# 42 days in the time series)
# number of unique dates in the time series target should be the number of
# days between collection_min_date (tree_as_of - 90 days) and the
# collection_max_date (nowcast_date + 10 days), inclusive
nowcast_datetime = datetime.fromisoformat(nowcast_date).replace(
hour=11, minute=59, second=59, tzinfo=timezone.utc
)
tree_as_of_datetime = datetime.fromisoformat("2024-09-09").replace(
hour=11, minute=59, second=59, tzinfo=timezone.utc
)
expected_num_days = (
(nowcast_datetime + timedelta(days=10))
- (tree_as_of_datetime - timedelta(days=90))
).days + 1
target_dates = ts["target_date"].unique().to_list()
assert len(target_dates) == 42
assert len(target_dates) == expected_num_days

modeled_clades_path = Path("auxiliary-data/modeled-clades") / f"{nowcast_date}.json"
modeled_clades_json = json.loads(modeled_clades_path.read_text(encoding="utf-8"))
Expand All @@ -596,7 +605,21 @@ def test_target_data_integration(caplog, tmp_path):
len(target_dates) * len(state_list) * len(modeled_clades) == ts.height

oracle = pl.read_parquet(result.return_value[1])
assert oracle.height == ts.height

oracle_min_date = oracle["target_date"].min()
oracle_max_date = oracle["target_date"].max()
assert (
oracle_min_date
== (datetime.strptime(nowcast_date, "%Y-%m-%d") - timedelta(days=31)).date()
)
assert (
oracle_max_date
== (datetime.strptime(nowcast_date, "%Y-%m-%d") + timedelta(days=10)).date()
)

# oracle series rows should = number of oracle target dates * total locations * total clades
expected_num_days = (oracle_max_date - oracle_min_date).days + 1
assert oracle.height == expected_num_days * len(state_list) * len(modeled_clades)

oracle_clades = oracle["clade"].unique().to_list()
assert len(modeled_clades) == len(oracle_clades)
Expand Down

0 comments on commit 1920829

Please sign in to comment.