Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix up use of chunk_size in var_inf to avoid overwriting config param… #126

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/rail/core/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,11 @@ def input_iterator(self, tag, **kwargs):
except Exception:
groupname = None

chunk_size = kwargs.get('chunk_size', self.config.chunk_size)

if handle.path and handle.path != "None": # pylint: disable=no-else-return
self._input_length = handle.size(groupname=groupname)
total_chunks_needed = ceil(self._input_length / self.config.chunk_size)
total_chunks_needed = ceil(self._input_length / chunk_size)
# If the number of processes is larger than we need, we remove some of them
if total_chunks_needed < self.size: # pragma: no cover
if self.comm:
Expand All @@ -361,7 +363,7 @@ def input_iterator(self, tag, **kwargs):
sys.exit()
kwcopy = dict(
groupname=groupname,
chunk_size=self.config.chunk_size,
chunk_size=chunk_size,
rank=self.rank,
parallel_size=self.size,
)
Expand Down
20 changes: 13 additions & 7 deletions src/rail/estimation/algos/var_inf.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,19 +67,25 @@ def __init__(self, args, comm=None):
PZSummarizer.__init__(self, args, comm=comm)
self.zgrid = None

def run(self):
# Redefining the chunk size so that all of the data is distributed at once in the
# nodes. This would fill all the memory if not enough nodes are allocated

input_data = self.get_handle("input", allow_missing=True)
def _setup_iterator(self):
    """Build the input iterator with a chunk size that spreads the whole
    dataset across all ranks at once.

    Returns
    -------
    iterator
        An iterator over (start, end, data) chunks of the "input" handle.
    """
    handle = self.get_handle("input", allow_missing=True)
    # Make sure the config exposes an hdf5_groupname; default to None when
    # the attribute is missing.
    try:
        self.config.hdf5_groupname
    except Exception:
        self.config.hdf5_groupname = None
    n_rows = handle.size(groupname=self.config.hdf5_groupname)
    # One chunk per rank: ceil so the final (possibly short) chunk is covered.
    per_rank_chunk = int(np.ceil(n_rows / self.size))

    return self.input_iterator("input", chunk_size=per_rank_chunk)


def run(self):
# Redefining the chunk size so that all of the data is distributed at once in the
# nodes. This would fill all the memory if not enough nodes are allocated

iterator = self.input_iterator("input")
iterator = self._setup_iterator()
self.zgrid = np.linspace(self.config.zmin, self.config.zmax, self.config.nzbins)
first = True
for s, e, test_data in iterator:
Expand Down
Loading