Skip to content

Commit

Permalink
Fix up use of chunk_size in var_inf to avoid overwriting config parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles committed Jul 3, 2024
1 parent 0222836 commit 7eaa6a3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
6 changes: 4 additions & 2 deletions src/rail/core/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,11 @@ def input_iterator(self, tag, **kwargs):
except Exception:
groupname = None

chunk_size = kwargs.get('chunk_size', self.config.chunk_size)

if handle.path and handle.path != "None": # pylint: disable=no-else-return
self._input_length = handle.size(groupname=groupname)
total_chunks_needed = ceil(self._input_length / self.config.chunk_size)
total_chunks_needed = ceil(self._input_length / chunk_size)
# If the number of processes is larger than we need, we remove some of them
if total_chunks_needed < self.size: # pragma: no cover
if self.comm:
Expand All @@ -361,7 +363,7 @@ def input_iterator(self, tag, **kwargs):
sys.exit()
kwcopy = dict(
groupname=groupname,
chunk_size=self.config.chunk_size,
chunk_size=chunk_size,
rank=self.rank,
parallel_size=self.size,
)
Expand Down
20 changes: 13 additions & 7 deletions src/rail/estimation/algos/var_inf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,25 @@ def __init__(self, args, comm=None):
PZSummarizer.__init__(self, args, comm=comm)
self.zgrid = None

def run(self):
# Redefining the chunk size so that all of the data is distributed at once in the
# nodes. This would fill all the memory if not enough nodes are allocated

input_data = self.get_handle("input", allow_missing=True)
def _setup_iterator(self):
    """Build the iterator over the "input" handle, choosing a chunk size so
    that the full input is distributed across all ranks in a single pass.

    Note: this can fill all available memory if too few nodes are allocated.

    Returns
    -------
    iterator
        Yields (start, end, data) chunks from the "input" handle.
    """
    input_handle = self.get_handle("input", allow_missing=True)
    # EAFP: older configurations may not define hdf5_groupname at all
    try:
        self.config.hdf5_groupname
    except Exception:
        self.config.hdf5_groupname = None
    input_length = input_handle.size(groupname=self.config.hdf5_groupname)
    # np.ceil returns a float; cast to int so chunk_size is a valid count.
    # Pass it as a keyword argument rather than overwriting
    # self.config.chunk_size, so the configured value is preserved.
    chunk_size = int(np.ceil(input_length / self.size))

    iterator = self.input_iterator("input", chunk_size=chunk_size)
    return iterator


def run(self):
# Redefining the chunk size so that all of the data is distributed at once in the
# nodes. This would fill all the memory if not enough nodes are allocated

iterator = self.input_iterator("input")
iterator = self._setup_iterator()
self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
first = True
for s, e, test_data in iterator:
Expand Down

0 comments on commit 7eaa6a3

Please sign in to comment.