Skip to content

Commit

Permalink
fix handling of horizons for sarix and gbqr models
Browse files Browse the repository at this point in the history
  • Loading branch information
elray1 committed Nov 19, 2024
1 parent a22b225 commit 06bcc51
Show file tree
Hide file tree
Showing 4 changed files with 956 additions and 956 deletions.
2 changes: 1 addition & 1 deletion src/idmodels/gbqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _format_as_flusight_output(self, preds_df, ref_date):

preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days")
preds_df["reference_date"] = ref_date
preds_df["horizon"] = preds_df["horizon"] - 2
preds_df["horizon"] = (pd.to_timedelta(preds_df["target_end_date"].dt.date - ref_date).dt.days / 7).astype(int)
preds_df["target"] = "wk inc flu hosp"

preds_df["output_type"] = "quantile"
Expand Down
2 changes: 1 addition & 1 deletion src/idmodels/sarix.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def run(self, run_config):

preds_df["target_end_date"] = preds_df["wk_end_date"] + pd.to_timedelta(7*preds_df["horizon"], unit="days")
preds_df["reference_date"] = run_config.ref_date
preds_df["horizon"] = preds_df["horizon"] - 2
preds_df["horizon"] = (pd.to_timedelta(preds_df["target_end_date"].dt.date - run_config.ref_date).dt.days / 7).astype(int)
preds_df["output_type"] = "quantile"
preds_df["target"] = "wk inc flu hosp"
preds_df.drop(columns="wk_end_date", inplace=True)
Expand Down
Loading

0 comments on commit 06bcc51

Please sign in to comment.