diff --git a/src/rail/core/stage.py b/src/rail/core/stage.py
index b11f8a87..5bdfd9f9 100644
--- a/src/rail/core/stage.py
+++ b/src/rail/core/stage.py
@@ -344,9 +344,11 @@ def input_iterator(self, tag, **kwargs):
         except Exception:
             groupname = None
 
+        chunk_size = kwargs.get('chunk_size', self.config.chunk_size)
+
         if handle.path and handle.path != "None":  # pylint: disable=no-else-return
             self._input_length = handle.size(groupname=groupname)
-            total_chunks_needed = ceil(self._input_length / self.config.chunk_size)
+            total_chunks_needed = ceil(self._input_length / chunk_size)
             # If the number of processes is larger than we need, we remove some of them
             if total_chunks_needed < self.size:  # pragma: no cover
                 if self.comm:
@@ -361,7 +363,7 @@ def input_iterator(self, tag, **kwargs):
                     sys.exit()
             kwcopy = dict(
                 groupname=groupname,
-                chunk_size=self.config.chunk_size,
+                chunk_size=chunk_size,
                 rank=self.rank,
                 parallel_size=self.size,
             )
diff --git a/src/rail/estimation/algos/var_inf.py b/src/rail/estimation/algos/var_inf.py
index b4abf2f2..19f73688 100644
--- a/src/rail/estimation/algos/var_inf.py
+++ b/src/rail/estimation/algos/var_inf.py
@@ -67,19 +67,25 @@ def __init__(self, args, comm=None):
         PZSummarizer.__init__(self, args, comm=comm)
         self.zgrid = None
 
-    def run(self):
-        # Redefining the chunk size so that all of the data is distributed at once in the
-        # nodes. This would fill all the memory if not enough nodes are allocated
-        input_data = self.get_handle("input", allow_missing=True)
+    def _setup_iterator(self):
+        input_handle = self.get_handle("input", allow_missing=True)
         try:
             self.config.hdf5_groupname
         except Exception:
             self.config.hdf5_groupname = None
-        input_length = input_data.size(groupname=self.config.hdf5_groupname)
-        self.config.chunk_size = np.ceil(input_length / self.size)
+        input_length = input_handle.size(groupname=self.config.hdf5_groupname)
+        chunk_size = int(np.ceil(input_length / self.size))
+
+        iterator = self.input_iterator("input", chunk_size=chunk_size)
+        return iterator
+
+
+    def run(self):
+        # Redefining the chunk size so that all of the data is distributed at once across
+        # the nodes. This would fill all the memory if not enough nodes are allocated
 
-        iterator = self.input_iterator("input")
+        iterator = self._setup_iterator()
         self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
         first = True
         for s, e, test_data in iterator:
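
For context on the two hunks above: `input_iterator` now accepts an optional `chunk_size` keyword that falls back to `self.config.chunk_size`, and the new `_setup_iterator` helper in var_inf.py uses that override to hand each MPI rank one chunk of the input, instead of mutating `self.config.chunk_size` as a side effect the way the old `run()` did (the added `int(...)` also matters, since `np.ceil` returns a float, not an int). Below is a minimal, self-contained sketch of the fallback pattern; `plan_chunks` and its arguments are illustrative stand-ins, not RAIL API:

    from math import ceil

    def plan_chunks(input_length, config_chunk_size, **kwargs):
        # Mirrors kwargs.get('chunk_size', self.config.chunk_size) in the diff:
        # an explicit override wins; otherwise the configured value is used.
        chunk_size = kwargs.get("chunk_size", config_chunk_size)
        return ceil(input_length / chunk_size)

    print(plan_chunks(1000, 100))                  # 10 chunks at the configured size
    print(plan_chunks(1000, 100, chunk_size=250))  # 4 chunks with an explicit override

Passing the override per call keeps the stage configuration read-only during a run, so other stages (or a second run) still see the user's configured chunk size.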