From 32ee455388cc5e4862dc1d21c4f79546e9569f6b Mon Sep 17 00:00:00 2001 From: Evan Ray Date: Tue, 19 Nov 2024 19:59:42 -0500 Subject: [PATCH] add option for location filter --- src/idmodels/gbqr.py | 2 ++ src/idmodels/sarix.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/idmodels/gbqr.py b/src/idmodels/gbqr.py index d81bf97..b838d33 100644 --- a/src/idmodels/gbqr.py +++ b/src/idmodels/gbqr.py @@ -37,6 +37,8 @@ def run(self, run_config): flusurvnet_kwargs=flusurvnet_kwargs, sources=self.model_config.sources, power_transform=self.model_config.power_transform) + if run_config.locations is not None: + df = df.loc[df["location"].isin(run_config.locations)] # augment data with features and target values df, feat_names = create_features_and_targets( diff --git a/src/idmodels/sarix.py b/src/idmodels/sarix.py index 6bd885e..10ec04a 100644 --- a/src/idmodels/sarix.py +++ b/src/idmodels/sarix.py @@ -17,6 +17,8 @@ def run(self, run_config): df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date}, sources=self.model_config.sources, power_transform=self.model_config.power_transform) + if run_config.locations is not None: + df = df.loc[df["location"].isin(run_config.locations)] # season week relative to christmas df = df.merge(