diff --git a/src/rail/estimation/algos/naive_stack.py b/src/rail/estimation/algos/naive_stack.py index 9d3c29d4..4d52c568 100644 --- a/src/rail/estimation/algos/naive_stack.py +++ b/src/rail/estimation/algos/naive_stack.py @@ -25,7 +25,7 @@ def run(self): self.add_data('model', np.array([None])) class NaiveStackSummarizer(PZSummarizer): - """Summarizer which simply histograms a point estimate + """Summarizer which stacks individual P(z) """ name = 'NaiveStackSummarizer' @@ -49,7 +49,7 @@ def run(self): # Initiallizing the stacking pdf's yvals = np.zeros((1, len(self.zgrid))) bvals = np.zeros((self.config.nsamples, len(self.zgrid))) - bootstrap_matrix = self.broadcast_bootstrap_matrix() + bootstrap_matrix = self._broadcast_bootstrap_matrix() first = True for s, e, test_data in iterator: @@ -57,7 +57,7 @@ def run(self): self._process_chunk(s, e, test_data, first, bootstrap_matrix, yvals, bvals) first = False if self.comm is not None: # pragma: no cover - bvals, yvals = self.join_histograms(bvals, yvals) + bvals, yvals = self._join_histograms(bvals, yvals) if self.rank == 0: sample_ens = qp.Ensemble(qp.interp, data=dict(xvals=self.zgrid, yvals=bvals)) @@ -77,24 +77,6 @@ def _process_chunk(self, start, end, data, first, bootstrap_matrix, yvals, bval bootstrap_draws = bootstrap_draws[mask] - start bvals[i] += np.sum(pdf_vals[bootstrap_draws], axis=0) - def broadcast_bootstrap_matrix(self): - rng = np.random.default_rng(seed=self.config.seed) - # Only one of the nodes needs to produce the bootstrap indices - ngal = self._input_length - print('i am the rank with number of galaxies',self.rank,ngal) - if self.rank == 0: - bootstrap_matrix = rng.integers(low=0, high=ngal, size=(ngal,self.config.nsamples)) - else: # pragma: no cover - bootstrap_matrix = None - if self.comm is not None: # pragma: no cover - self.comm.Barrier() - bootstrap_matrix = self.comm.bcast(bootstrap_matrix, root = 0) - return bootstrap_matrix - - def join_histograms(self, bvals, yvals): - bvals_r = self.comm.reduce(bvals) - yvals_r = self.comm.reduce(yvals) - return(bvals_r, yvals_r) diff --git a/src/rail/estimation/algos/point_est_hist.py b/src/rail/estimation/algos/point_est_hist.py index c20270f7..43ecae3a 100644 --- a/src/rail/estimation/algos/point_est_hist.py +++ b/src/rail/estimation/algos/point_est_hist.py @@ -46,23 +46,40 @@ def __init__(self, args, comm=None): self.zgrid = None self.bincents = None + def run(self): - rng = np.random.default_rng(seed=self.config.seed) - test_data = self.get_data('input') - npdf = test_data.npdf - zb = test_data.ancil['zmode'] - nsamp = self.config.nsamples + iterator = self.input_iterator('input') self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins + 1) self.bincents = 0.5 * (self.zgrid[1:] + self.zgrid[:-1]) - single_hist = np.histogram(test_data.ancil[self.config.point_estimate], bins=self.zgrid)[0] - qp_d = qp.Ensemble(qp.hist, + bootstrap_matrix = self._broadcast_bootstrap_matrix() + # Initiallizing the histograms + single_hist = np.zeros(self.config.nzbins) + hist_vals = np.zeros((self.config.nsamples, self.config.nzbins)) + + first = True + for s, e, test_data in iterator: + print(f"Process {self.rank} running estimator on chunk {s} - {e}") + self._process_chunk(s, e, test_data, first, bootstrap_matrix, single_hist, hist_vals) + first = False + if self.comm is not None: # pragma: no cover + hist_vals, single_hist = self._join_histograms(hist_vals, single_hist) + + if self.rank == 0: + sample_ens = qp.Ensemble(qp.hist, + data=dict(bins=self.zgrid, pdfs=np.atleast_2d(hist_vals))) + qp_d = qp.Ensemble(qp.hist, data=dict(bins=self.zgrid, pdfs=np.atleast_2d(single_hist))) - hist_vals = np.empty((nsamp, self.config.nzbins)) - for i in range(nsamp): - bootstrap_indeces = rng.integers(low=0, high=npdf, size=npdf) + self.add_data('output', sample_ens) + self.add_data('single_NZ', qp_d) + + def _process_chunk(self, start, end, test_data, first, bootstrap_matrix, single_hist, hist_vals): + zb = test_data.ancil[self.config.point_estimate] + single_hist += np.histogram(zb, bins=self.zgrid)[0] + for i in range(self.config.nsamples): + bootstrap_indeces = bootstrap_matrix[:, i] + # Neither all of the bootstrap_draws are in this chunk nor the index starts at "start" + mask = (bootstrap_indeces>=start) & (bootstrap_indeces