Skip to content

Commit

Permalink
Merge pull request #44 from LSSTDESC/issue/43/move_open_model
Browse files Browse the repository at this point in the history
move open_model from init to run
  • Loading branch information
eacharles authored Aug 15, 2023
2 parents ff3d6c1 + 0e5e981 commit 03c07ba
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
16 changes: 7 additions & 9 deletions src/rail/estimation/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rail.estimation.informer import CatInformer
# for backwards compatibility


class CatEstimator(RailStage):
"""The base class for making photo-z posterior estimates from catalog-like inputs
(i.e., tables with fluxes in photometric bands among the set of columns)
Expand All @@ -37,9 +38,6 @@ def __init__(self, args, comm=None):
RailStage.__init__(self, args, comm=comm)
self._output_handle = None
self.model = None
if not isinstance(args, dict): #pragma: no cover
args = vars(args)
self.open_model(**args)

def open_model(self, **kwargs):
"""Load the mode and/or attach it to this Estimator
Expand Down Expand Up @@ -100,6 +98,9 @@ def estimate(self, input_data):
return self.get_handle('output')

def run(self):

self.open_model(**self.config)

iterator = self.input_iterator('input')
first = True
self._initialize_run()
Expand All @@ -120,14 +121,11 @@ def _finalize_run(self):
self._output_handle.finalize_write()

def _process_chunk(self, start, end, data, first):
raise NotImplementedError(f"{self.name}._process_chunk is not implemented") #pragma: no cover
raise NotImplementedError(f"{self.name}._process_chunk is not implemented") # pragma: no cover

def _do_chunk_output(self, qp_dstn, start, end, first):
if first:
self._output_handle = self.add_handle('output', data = qp_dstn)
self._output_handle.initialize_write(self._input_length, communicator = self.comm)
self._output_handle = self.add_handle('output', data=qp_dstn)
self._output_handle.initialize_write(self._input_length, communicator=self.comm)
self._output_handle.set_data(qp_dstn, partial=True)
self._output_handle.write_chunk(start, end)



12 changes: 5 additions & 7 deletions src/rail/estimation/summarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# for backwards compatibility


class CatSummarizer(RailStage):
class CatSummarizer(RailStage):
"""The base class for classes that go from catalog-like tables
to ensemble NZ estimates.
Expand Down Expand Up @@ -107,7 +107,7 @@ def summarize(self, input_data):
return self.get_handle('output')


class SZPZSummarizer(RailStage):
class SZPZSummarizer(RailStage):
"""The base class for classes that use two sets of data: a photometry sample with
spec-z values, and a photometry sample with unknown redshifts, e.g. minisom_som and
outputs a QP Ensemble with bootstrap realization of the N(z) distribution
Expand All @@ -124,9 +124,9 @@ def __init__(self, args, comm=None):
"""Initialize Estimator that can sample galaxy data."""
RailStage.__init__(self, args, comm=comm)
self.model = None
if not isinstance(args, dict): #pragma: no cover
args = vars(args)
self.open_model(**args)
# NOTE: open model removed from init, need to put an
# `open_model` call explicitly in the run method for
# each summarizer.

def open_model(self, **kwargs):
"""Load the mode and/or attach it to this Summarizer
Expand Down Expand Up @@ -188,5 +188,3 @@ def summarize(self, input_data, spec_data):
self.run()
self.finalize()
return self.get_handle('output')


0 comments on commit 03c07ba

Please sign in to comment.