Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update target timeseries minimum collection date #253

Merged
merged 3 commits into from
Jan 8, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 50 additions & 27 deletions src/get_target_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,6 @@ def set_sequence_as_of(ctx, param, value):
return value


def set_collection_min_date(ctx, param, value):
    """Click callback that resolves the ``collection_min_date`` option.

    When no value is supplied, the default is the nowcast date minus 31 days.
    The result is always moved to 23:59:59 and tagged as UTC.

    Args:
        ctx: Click context; ``ctx.params`` is read for ``nowcast_date``.
        param: The click parameter being processed (unused).
        value: The parsed option value, or ``None`` when it was not supplied.

    Returns:
        A timezone-aware (UTC) datetime at 23:59:59 on the resolved date.
    """
    if value is None:
        nowcast_date = ctx.params.get("nowcast_date")
        if nowcast_date is None:
            # Without this guard the arithmetic below fails with an opaque
            # TypeError when nowcast_date has not been processed by click yet
            # (or was never supplied); report a usage error instead.
            ctx.fail(
                "collection_min_date cannot be defaulted because "
                "nowcast_date is not set"
            )
        value = nowcast_date - timedelta(days=31)
    # NOTE(review): normalisation is applied to supplied values as well as the
    # default, mirroring normalize_date — confirm against the original layout.
    return value.replace(hour=23, minute=59, second=59, tzinfo=timezone.utc)


def set_collection_max_date(ctx, param, value):
"""Set the collection_max_date default value to nowcast date plus 10 days."""
if value is None:
Expand Down Expand Up @@ -169,7 +160,7 @@ def set_collection_max_date(ctx, param, value):
type=click.DateTime(formats=["%Y-%m-%d"]),
required=False,
default=None,
callback=set_collection_min_date,
callback=normalize_date,
help="Assign clades to sequences collected on or after this UTC date (YYYY-MM-DD). Default is the nowcast date minus 31 days.",
)
@click.option(
Expand All @@ -195,8 +186,6 @@ def main(
collection_max_date: datetime,
target_data_dir: Path,
) -> tuple[Path, Path]:
"""Get clade counts and save to S3 bucket."""

# Date for retrieving sequences cannot be in the future
if sequence_as_of > datetime.now(tz=timezone.utc):
logger.info(
Expand Down Expand Up @@ -232,6 +221,9 @@ def main(
else:
tree_as_of = datetime.fromisoformat(modeled_clades["meta"]["created_at"])

if collection_min_date is None:
collection_min_date = tree_as_of - timedelta(days=90)

assignments = assign_clades(
nowcast_date,
sequence_as_of,
Expand Down Expand Up @@ -343,6 +335,12 @@ def create_target_data(

oracle_output = (
time_series_all.select(["location", "target_date", "clade", "observation"])
# for oracle output, include only sequence collection dates that are >=
# nowcast_date - 31 days
.filter(
pl.col("target_date")
>= datetime.fromisoformat(nowcast_string) - timedelta(days=31)
)
.with_columns(pl.lit(nowcast_string).alias("nowcast_date"))
.rename({"observation": "oracle_value"})
)
Expand Down Expand Up @@ -417,10 +415,10 @@ def mock_command():
result = normalize_date(ctx, param, datetime(2024, 11, 11, 11, 11, 11))
assert result == datetime(2024, 11, 11, 23, 59, 59, tzinfo=timezone.utc)

# default collection_min_date is 31 days before the nowcast date
# if collection_min_date is provided, it should be set to end of day UTC
param = Option(["--collection-min-date"])
result = set_collection_min_date(ctx, param, None)
assert result == datetime(2024, 9, 1, 23, 59, 59, tzinfo=timezone.utc)
result = normalize_date(ctx, param, datetime(2024, 11, 11, 11, 11, 11))
assert result == datetime(2024, 11, 11, 23, 59, 59, tzinfo=timezone.utc)

# default collection_max_date is 10 days after the nowcast date
param = Option(["--collection-max-date"])
Expand Down Expand Up @@ -480,18 +478,18 @@ def test_target_data():
"clade_nextstrain": ["AA", "BB", "CC", "DD", "BB"],
"count": [2, 3, 4, 5, 6],
}
test_assignments: Clade = Clade(
test_assignments = Clade(
{"tree_as_of": datetime(2024, 8, 1, 14, 30, 40)},
pl.LazyFrame(),
pl.LazyFrame(test_summary),
) # type: ignore
)
test_clade_list = ["AA", "BB", "other"]
test_min_date = datetime(2024, 11, 30, tzinfo=timezone.utc)
test_max_date = datetime(2024, 12, 4, tzinfo=timezone.utc)
time_series, oracle = create_target_data(
test_assignments,
test_clade_list,
"2024-9-11",
"2024-09-11",
"2024-12-17",
test_min_date,
test_max_date,
Expand All @@ -518,7 +516,7 @@ def test_target_data():
assert ts.get_column("target_date").min() == date(2024, 11, 30)
assert ts.get_column("target_date").max() == date(2024, 12, 4)
assert ts.get_column("observation").sum() == 20
assert ts.get_column("nowcast_date").unique().to_list() == ["2024-9-11"]
assert ts.get_column("nowcast_date").unique().to_list() == ["2024-09-11"]
assert ts.get_column("sequence_as_of").unique().to_list() == ["2024-12-17"]
assert ts.get_column("tree_as_of").unique().to_list() == ["2024-08-01"]

Expand All @@ -541,7 +539,8 @@ def test_target_data():
def test_target_data_integration(caplog, tmp_path):
"""
If the modeled-clades file doesn't have meta.created_at, tree_as_of should default to
nowcast_date - two days.
nowcast_date - two days. Additionally, when collection_min_date isn't provided,
it should default to tree_as_of - 90 days.
"""
caplog.set_level(logging.INFO)

Expand All @@ -563,15 +562,25 @@ def test_target_data_integration(caplog, tmp_path):
ts = pl.read_parquet(result.return_value[0])

# sequence date should default to nowcast_date + 90 days
assert "sequence_as_of=2024-12-10" in caplog.text
assert "sequence_as_of=2024-12-10" in caplog.text.lower()
# tasks.json for 2024-09-11 doesn't have a meta.created_at field, so tree_as_of = nowcast_date - 2 days
assert "tree_as_of=2024-09-09" in caplog.text
assert "tree_as_of=2024-09-09" in caplog.text.lower()

# time series target dates should be limited to dates that match collection date min/max options
# (2024-08-11 to 2024-09-21 is 43 days, but both have 11:59:59 timestamps, so we'd expect
# 42 days in the time series)
# number of unique dates in the time series target should be the number of
# days between collection_min_date (tree_as_of - 90 days) and the
# collection_max_date (nowcast_date + 10 days), inclusive
nowcast_datetime = datetime.fromisoformat(nowcast_date).replace(
hour=11, minute=59, second=59, tzinfo=timezone.utc
)
tree_as_of_datetime = datetime.fromisoformat("2024-09-09").replace(
hour=11, minute=59, second=59, tzinfo=timezone.utc
)
expected_num_days = (
(nowcast_datetime + timedelta(days=10))
- (tree_as_of_datetime - timedelta(days=90))
).days + 1
target_dates = ts["target_date"].unique().to_list()
assert len(target_dates) == 42
assert len(target_dates) == expected_num_days

modeled_clades_path = Path("auxiliary-data/modeled-clades") / f"{nowcast_date}.json"
modeled_clades_json = json.loads(modeled_clades_path.read_text(encoding="utf-8"))
Expand All @@ -596,7 +605,21 @@ def test_target_data_integration(caplog, tmp_path):
len(target_dates) * len(state_list) * len(modeled_clades) == ts.height

oracle = pl.read_parquet(result.return_value[1])
assert oracle.height == ts.height

oracle_min_date = oracle["target_date"].min()
oracle_max_date = oracle["target_date"].max()
assert (
oracle_min_date
== (datetime.strptime(nowcast_date, "%Y-%m-%d") - timedelta(days=31)).date()
)
assert (
oracle_max_date
== (datetime.strptime(nowcast_date, "%Y-%m-%d") + timedelta(days=10)).date()
)

# oracle row count should equal the number of oracle target dates * total locations * total clades
expected_num_days = (oracle_max_date - oracle_min_date).days + 1
assert oracle.height == expected_num_days * len(state_list) * len(modeled_clades)

oracle_clades = oracle["clade"].unique().to_list()
assert len(modeled_clades) == len(oracle_clades)
Expand Down
Loading