Skip to content

Commit

Permalink
Fixes to masked summarizers
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles committed Jun 28, 2024
1 parent 05f9a17 commit 2d53de7
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/rail/estimation/algos/true_nz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import qp

from ceci.config import StageParameter as Param
from rail.core.common_params import SHARED_PARAMS
from rail.core.stage import RailStage
from rail.core.data import QPHandle, TableHandle
Expand All @@ -20,6 +21,9 @@ class TrueNZHistogrammer(RailStage):
zmax=SHARED_PARAMS,
nzbins=SHARED_PARAMS,
redshift_col=SHARED_PARAMS,
selected_bin=Param(int, -1, msg="Which tomography bin to consider"),
chunk_size=10000,
hdf5_groupname="",
)
inputs = [("input", TableHandle), ("tomography_bins", TableHandle)]
outputs = [("true_NZ", QPHandle)]
Expand All @@ -30,8 +34,12 @@ def __init__(self, args, comm=None):
self.bincents = None

def _setup_iterator(self):
itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')]

itrs = [
self.input_iterator('input'),
self.input_iterator('tomography_bins'),
]

for it in zip(*itrs):
first = True
mask = None
Expand All @@ -43,7 +51,7 @@ def _setup_iterator(self):
first = False
else:
if self.config.selected_bin < 0:
mask = np.ones(len(d))
mask = np.ones(e-s, dtype=bool)
else:
mask = d['class_id'] == self.config.selected_bin
yield start, end, pz_data, mask
Expand Down Expand Up @@ -76,5 +84,6 @@ def run(self):
def _process_chunk(
self, _start, _end, data, mask, _first, single_hist,
):
zb = data[self.config.redshift_col][mask]
squeeze_mask = np.squeeze(mask)
zb = data[self.config.redshift_col][squeeze_mask]
single_hist += np.histogram(zb, bins=self.zgrid)[0]

0 comments on commit 2d53de7

Please sign in to comment.