From 26f7cec873b382a500204aa0121bff1661248462 Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 29 Aug 2022 14:08:48 +0200 Subject: [PATCH] Fixing engine terminate behaviour when resumed (#2678) --- ignite/engine/engine.py | 25 ++++++++++++++++++------- tests/ignite/engine/test_engine.py | 2 +- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 938b7133cb4..d583747c57d 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -310,7 +310,7 @@ def execute_something(): except ValueError: _check_signature(handler, "handler", *(event_args + args), **kwargs) self._event_handlers[event_name].append((handler, args, kwargs)) - self.logger.debug(f"added handler for event {event_name}") + self.logger.debug(f"Added handler for event {event_name}") return RemovableEventHandle(event_name, handler, self) @@ -406,7 +406,7 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) -> **event_kwargs: optional keyword args to be passed to all handlers. """ - self.logger.debug(f"firing handlers for event {event_name}") + self.logger.debug(f"{self.state.epoch} | {self.state.iteration}, Firing handlers for event {event_name}") self.last_event_name = event_name for func, args, kwargs in self._event_handlers[event_name]: kwargs.update(event_kwargs) @@ -720,6 +720,11 @@ def switch_batch(engine): if self.state.epoch_length is None and data is None: raise ValueError("epoch_length should be provided if data is None") + if self.should_terminate: + # If engine was terminated and now is resuming from terminated state + # we need to initialize iter_counter as 0 + self._init_iter.append(0) + self.state.dataloader = data return self._internal_run() @@ -750,12 +755,13 @@ def _setup_dataloader_iter(self) -> None: def _setup_engine(self) -> None: self._setup_dataloader_iter() - iteration = self.state.iteration - # Below we define initial counter value for _run_once_on_dataset to measure a single epoch - if self.state.epoch_length is not None: - iteration %= self.state.epoch_length - self._init_iter.append(iteration) + if len(self._init_iter) == 0: + iteration = self.state.iteration + # Below we define initial counter value for _run_once_on_dataset to measure a single epoch + if self.state.epoch_length is not None: + iteration %= self.state.epoch_length + self._init_iter.append(iteration) def _internal_run(self) -> State: self.should_terminate = self.should_terminate_single_epoch = False @@ -826,6 +832,11 @@ def _run_once_on_dataset(self) -> float: start_time = time.time() # We need to setup iter_counter > 0 if we resume from an iteration + if len(self._init_iter) > 1: + raise RuntimeError( + "Internal error, len(self._init_iter) should 0 or 1, " + f"but got: {len(self._init_iter)}, {self._init_iter}" + ) iter_counter = self._init_iter.pop() if len(self._init_iter) > 0 else 0 should_exit = False try: diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index d046ed5c675..2f18c3e720d 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -206,7 +206,7 @@ def check_iter_and_data(): assert state.epoch == max_epochs assert not engine.should_terminate - assert state.iteration == real_epoch_length * (max_epochs - 1) + assert state.iteration == real_epoch_length * (max_epochs - 1) + (iteration_to_stop % real_epoch_length) class RecordedEngine(Engine):