From a95ba68dd423a6b74d5fb26d54a749a0fa3bf474 Mon Sep 17 00:00:00 2001 From: juacrumar Date: Sat, 23 Nov 2024 07:59:00 -0500 Subject: [PATCH] accept the possibility of composing variants ; this will allow the possibility of studying the effect of new theories -which might be differently architectured in some ways- in legacy datasets --- validphys2/src/validphys/commondataparser.py | 21 ++++++++++++--- validphys2/src/validphys/config.py | 5 +++- validphys2/src/validphys/core.py | 28 +++++++++++++++++--- validphys2/src/validphys/loader.py | 3 ++- validphys2/src/validphys/overfit_metric.py | 14 +++------- 5 files changed, 52 insertions(+), 19 deletions(-) diff --git a/validphys2/src/validphys/commondataparser.py b/validphys2/src/validphys/commondataparser.py index 9f96d678cf..cc9ba23ac0 100644 --- a/validphys2/src/validphys/commondataparser.py +++ b/validphys2/src/validphys/commondataparser.py @@ -454,8 +454,16 @@ def check(self): def apply_variant(self, variant_name): """Return a new instance of this class with the variant applied - This class also defines how the variant is applied to the commondata + This class also defines how the variant is applied to the commondata. + If more than a variant is being used, this function will be called recursively + until all variants are applied. """ + if not isinstance(variant_name, str): + observable = self + for single_variant in variant_name: + observable = observable.apply_variant(single_variant) + return observable + try: variant = self.variants[variant_name] except KeyError as e: @@ -471,7 +479,6 @@ def apply_variant(self, variant_name): # This section should only be used for the purposes of reproducibility # of legacy data, no new data should use these - if variant.experiment is not None: new_nnpdf_metadata = dict(self._parent.nnpdf_metadata.items()) new_nnpdf_metadata["experiment"] = variant.experiment @@ -479,7 +486,13 @@ def apply_variant(self, variant_name): variant_replacement["_parent"] = setmetadata_copy variant_replacement["plotting"] = dataclasses.replace(self.plotting) - return dataclasses.replace(self, applied_variant=variant_name, **variant_replacement) + # Keep track of applied variants: + if self.applied_variant is None: + varname = variant_name + else: + varname = f"{self.applied_variant}_{variant_name}" + + return dataclasses.replace(self, applied_variant=varname, **variant_replacement) @property def is_positivity(self): @@ -834,7 +847,7 @@ def parse_new_metadata(metadata_file, observable_name, variant=None): # Select one observable from the entire metadata metadata = set_metadata.select_observable(observable_name) - # And apply variant if given + # And apply variant or variants if given if variant is not None: metadata = metadata.apply_variant(variant) diff --git a/validphys2/src/validphys/config.py b/validphys2/src/validphys/config.py index 10e4f009f5..6565497d8d 100644 --- a/validphys2/src/validphys/config.py +++ b/validphys2/src/validphys/config.py @@ -460,14 +460,17 @@ def parse_dataset_input(self, dataset: Mapping): if variant is None or map_variant == "legacy_dw": variant = map_variant + if sysnum is not None: + log.warning("The key 'sys' is deprecated and will soon be removed") + return DataSetInput( name=name, - sys=sysnum, cfac=cfac, frac=frac, weight=weight, custom_group=custom_group, variant=variant, + sys=sysnum, ) def parse_use_fitcommondata(self, do_use: bool): diff --git a/validphys2/src/validphys/core.py b/validphys2/src/validphys/core.py index ceb0343feb..18b2daee21 100644 --- a/validphys2/src/validphys/core.py +++ b/validphys2/src/validphys/core.py @@ -368,17 +368,39 @@ def plot_kinlabels(self): class DataSetInput(TupleComp): """Represents whatever the user enters in the YAML to specify a - dataset.""" + dataset. + + name: str + name of the dataset_inputs + cfac: tuple + cfactors to apply to the final predictions (default: ()) + frac: float + fraction of the data to be used during training (default: 1.0) + weight: float + extra weight to apply to the dataset (default: 1.0) + variant: str or tuple[str] + variant or variants to apply (default: None) + sysnum: int + deprecated, systematic file to load for the dataset + """ - def __init__(self, *, name, sys, cfac, frac, weight, custom_group, variant): + def __init__(self, *, name, cfac, frac, weight, custom_group, variant, sys=None): self.name = name self.sys = sys self.cfac = cfac self.frac = frac self.weight = weight self.custom_group = custom_group + + # Parse the variant if introduced as a string + if isinstance(variant, str): + variant = (variant,) + + # Make sure that variant is not a list but, in case, a tuple + if isinstance(variant, list): + variant = tuple(variant) self.variant = variant - super().__init__(name, sys, cfac, frac, weight, custom_group, variant) + super().__init__(name, cfac, frac, weight, custom_group, variant, sys) def __str__(self): return self.name diff --git a/validphys2/src/validphys/loader.py b/validphys2/src/validphys/loader.py index e8312b6c10..531a8e1ebb 100644 --- a/validphys2/src/validphys/loader.py +++ b/validphys2/src/validphys/loader.py @@ -380,7 +380,7 @@ def check_commondata( self, setname, sysnum=None, use_fitcommondata=False, fit=None, variant=None ): """Prepare the commondata files to be loaded. - A commondata is defined by its name (``setname``) and the variant (``variant``) + A commondata is defined by its name (``setname``) and the variant(s) (``variant``) At the moment both old-format and new-format commondata can be utilized and loaded however old-format commondata are deprecated and will be removed in future relases. @@ -483,6 +483,7 @@ def get_commondata(self, setname, sysnum): """Get a Commondata from the set name and number.""" # TODO: check where this is used # as this might ignore cfactors or variants + raise Exception("Not used") cd = self.check_commondata(setname, sysnum) return cd.load() diff --git a/validphys2/src/validphys/overfit_metric.py b/validphys2/src/validphys/overfit_metric.py index 7c77d3827c..c9a69a1dcd 100644 --- a/validphys2/src/validphys/overfit_metric.py +++ b/validphys2/src/validphys/overfit_metric.py @@ -1,7 +1,7 @@ """ overfit_metric.py -This module contains the functions used to calculate the overfit metric and +This module contains the functions used to calculate the overfit metric and produce the corresponding tables and figures. """ @@ -59,7 +59,7 @@ def calculate_chi2s_per_replica( preds : list[pd.core.frame.DataFrame] List of pandas dataframes, each containing the predictions of the pdf replicas for a dataset_input - dataset_inputs : list[DatasetInput] + dataset_inputs : list[DataSetInput] groups_covmat_no_table : pdf.core.frame.DataFrame Returns @@ -112,10 +112,7 @@ def calculate_chi2s_per_replica( def array_expected_overfitting( - calculate_chi2s_per_replica, - replica_data, - number_of_resamples=1000, - resampling_fraction=0.95, + calculate_chi2s_per_replica, replica_data, number_of_resamples=1000, resampling_fraction=0.95 ): """Calculates the expected difference in chi2 between: 1. The chi2 of a PDF replica calculated using the corresponding pseudodata @@ -181,10 +178,7 @@ def plot_overfitting_histogram(fit, array_expected_overfitting): ax.hist(array_expected_overfitting, bins=50, density=True) ax.axvline(x=mean, color="black") ax.axvline(x=0, color="black", linestyle="--") - xrange = [ - array_expected_overfitting.min(), - array_expected_overfitting.max(), - ] + xrange = [array_expected_overfitting.min(), array_expected_overfitting.max()] xgrid = np.linspace(xrange[0], xrange[1], num=100) ax.plot(xgrid, stats.norm.pdf(xgrid, mean, std)) ax.set_xlabel(r"$\mathcal{R}_O$")