diff --git a/src/idmodels/gbqr.py b/src/idmodels/gbqr.py index d81bf97..b838d33 100644 --- a/src/idmodels/gbqr.py +++ b/src/idmodels/gbqr.py @@ -37,6 +37,8 @@ def run(self, run_config): flusurvnet_kwargs=flusurvnet_kwargs, sources=self.model_config.sources, power_transform=self.model_config.power_transform) + if run_config.locations is not None: + df = df.loc[df["location"].isin(run_config.locations)] # augment data with features and target values df, feat_names = create_features_and_targets( diff --git a/src/idmodels/sarix.py b/src/idmodels/sarix.py index 5a94a87..8775791 100644 --- a/src/idmodels/sarix.py +++ b/src/idmodels/sarix.py @@ -17,6 +17,8 @@ def run(self, run_config): df = fdl.load_data(nhsn_kwargs={"as_of": run_config.ref_date}, sources=self.model_config.sources, power_transform=self.model_config.power_transform) + if run_config.locations is not None: + df = df.loc[df["location"].isin(run_config.locations)] # season week relative to christmas df = df.merge(