From 6c1bd66849ce8bceab25e6c38c0334b690b2eb22 Mon Sep 17 00:00:00 2001 From: Michael Jasper Martins Date: Tue, 21 Jan 2025 17:32:43 +0100 Subject: [PATCH] Added Warning for previously set rescale_parameters --- bilby/core/prior/joint.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/bilby/core/prior/joint.py b/bilby/core/prior/joint.py index 9d234934..cd2d0c62 100644 --- a/bilby/core/prior/joint.py +++ b/bilby/core/prior/joint.py @@ -332,18 +332,18 @@ def rescale(self, value, **kwargs): """ if value is None: samp = np.array(list(self._current_unit_cube_parameter_values.values())).T + if len(samp.shape) == 1: + samp = samp.reshape(1, self.num_vars) else: - for key, val in zip(self.names, value): - self.set_rescale(key, val) samp = np.asarray(value) - - if len(samp.shape) == 1: - samp = samp.reshape(1, self.num_vars) - - if len(samp.shape) != 2: - raise ValueError("Array is the wrong shape") - elif samp.shape[1] != self.num_vars: - raise ValueError("Array is the wrong shape") + if len(samp.shape) == 1: + samp = samp.reshape(1, self.num_vars) + if len(samp.shape) != 2: + raise ValueError("Array is the wrong shape") + elif samp.shape[1] != self.num_vars: + raise ValueError("Array is the wrong shape") + for key, val in zip(self.names, samp.T): + self.set_rescale(key, val) samp = self._rescale(samp, **kwargs) for i, key in enumerate(self.names): @@ -821,6 +821,18 @@ def rescale(self, val, **kwargs): the rescaled value once all parameters have been requested """ + if self.dist.get_rescaled(self.name) is not None: + import warnings + warnings.warn( + f"Rescale values for {self.name} in {self.dist} have already been set, " + "indicating that another rescale-operation is in progress.\n" + "Call dist.reset_rescale() on the joint prior distribution associated " + "with this prior after using dist.rescale() and make sure all parameters " + "necessary for rescaling have been requested in PriorDict.rescale().\n" + "Resetting now.", + RuntimeWarning + ) + self.dist.reset_rescale() self.dist.set_rescale(self.name, val) if self.dist.filled_rescale():