Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: dynesty - reduce number of calls to add_live_points #872

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions bilby/core/sampler/dynamic_dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def sampler_class(self):
def finalize_sampler_kwargs(self, sampler_kwargs):
sampler_kwargs["maxcall"] = self.sampler.ncall + self.n_check_point

def _add_live(self):
pass

def _remove_live(self):
pass

def read_saved_state(self, continuing=False):
resume = super(DynamicDynesty, self).read_saved_state(continuing=continuing)
if not resume:
Expand Down
32 changes: 23 additions & 9 deletions bilby/core/sampler/dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,16 +667,19 @@ def _run_external_sampler_without_checkpointing(self):

def finalize_sampler_kwargs(self, sampler_kwargs):
sampler_kwargs["maxcall"] = self.n_check_point
sampler_kwargs["add_live"] = True
sampler_kwargs["add_live"] = False

def _run_external_sampler_with_checkpointing(self):
"""
In order to access the checkpointing, we run the sampler for short
periods of time (less than the checkpoint time) and if sufficient
time has passed, write a checkpoint before continuing. To get the most
informative checkpoint plots, the current live points are added to the
chain of nested samples within dynesty and have to be removed before
restarting the sampler.
chain of nested samples before making the plots and have to be removed
before restarting the sampler. We previously used the dynesty internal
version of this, but this is unsafe as dynesty is not capable of
determining if adding the live points was interrupted and so we want to
minimize the number of times this is done.
"""

logger.debug("Running sampler with checkpointing")
Expand All @@ -691,8 +694,7 @@ def _run_external_sampler_with_checkpointing(self):
)
while True:
self.finalize_sampler_kwargs(sampler_kwargs)
if getattr(self.sampler, "added_live", False):
self.sampler._remove_live_points()
self._remove_live()
self.sampler.run_nested(**sampler_kwargs)
if self.sampler.ncall == old_ncall:
break
Expand All @@ -706,15 +708,27 @@ def _run_external_sampler_with_checkpointing(self):
).total_seconds()
if last_checkpoint_s > self.check_point_delta_t:
self.write_current_state()
self._add_live()
self.plot_current_state()
if getattr(self.sampler, "added_live", False):
self.sampler._remove_live_points()
self._remove_live()

self._remove_live()
if "add_live" in sampler_kwargs:
sampler_kwargs["add_live"] = self.kwargs.get("add_live", True)
self.sampler.run_nested(**sampler_kwargs)
self.write_current_state()
self.plot_current_state()
return self.sampler.results

def _add_live(self):
if not self.sampler.added_live:
for _ in self.sampler.add_live_points():
pass

def _remove_live(self):
if self.sampler.added_live:
self.sampler._remove_live_points()

def _remove_checkpoint(self):
"""Remove checkpointed state"""
if os.path.isfile(self.resume_file):
Expand Down Expand Up @@ -774,8 +788,8 @@ def read_saved_state(self, continuing=False):
)
del sampler.versions
self.sampler = sampler
if getattr(self.sampler, "added_live", False) and continuing:
self.sampler._remove_live_points()
if continuing:
self._remove_live()
self.sampler.nqueue = -1
self.start_time = self.sampler.kwargs.pop("start_time")
self.sampling_time = self.sampler.kwargs.pop("sampling_time")
Expand Down
Loading