diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 2679e97ccc8f..22406badb2fd 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -695,25 +695,32 @@ def _internal_run(self) -> State: self._setup_engine() time_taken = self._run_once_on_dataset() + # time is available for handlers but must be update after fire + self.state.times[Events.EPOCH_COMPLETED.name] = time_taken + handlers_start_time = time.time() + if self.should_terminate: + self._fire_event(Events.TERMINATE) + else: + self._fire_event(Events.EPOCH_COMPLETED) + time_taken += time.time() - handlers_start_time + # update time wrt handlers self.state.times[Events.EPOCH_COMPLETED.name] = time_taken hours, mins, secs = _to_hours_mins_secs(time_taken) - elapsed_time_message = "Epoch[%s] Complete. Time taken: %02d:%02d:%02d" % ( - self.state.epoch, - hours, - mins, - secs, + self.logger.info( + "Epoch[%s] Complete. Time taken: %02d:%02d:%02d" % (self.state.epoch, hours, mins, secs) ) if self.should_terminate: - self._fire_event(Events.TERMINATE) - self.logger.info(elapsed_time_message) break - self._fire_event(Events.EPOCH_COMPLETED) - self.logger.info(elapsed_time_message) time_taken = time.time() - start_time - hours, mins, secs = _to_hours_mins_secs(time_taken) + # time is available for handlers but must be update after fire self.state.times[Events.COMPLETED.name] = time_taken + handlers_start_time = time.time() self._fire_event(Events.COMPLETED) + time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.COMPLETED.name] = time_taken + hours, mins, secs = _to_hours_mins_secs(time_taken) self.logger.info("Engine run complete. Time taken: %02d:%02d:%02d" % (hours, mins, secs)) except BaseException as e: diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 7c1f0e5d0455..8c2d966b11db 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -377,19 +377,29 @@ def test_state_get_event_attrib_value(): def test_time_stored_in_state(): def _test(data, max_epochs, epoch_length): sleep_time = 0.01 + extra_sleep_time = 0.1 engine = Engine(lambda e, b: time.sleep(sleep_time)) - def check_epoch_time(engine): + @engine.on(Events.EPOCH_COMPLETED) + def check_epoch_time(): assert engine.state.times[Events.EPOCH_COMPLETED.name] >= sleep_time * epoch_length + time.sleep(extra_sleep_time) - def check_completed_time(engine): - assert engine.state.times[Events.COMPLETED.name] >= sleep_time * epoch_length * max_epochs - - engine.add_event_handler(Events.EPOCH_COMPLETED, lambda e: check_epoch_time(e)) - engine.add_event_handler(Events.COMPLETED, lambda e: check_completed_time(e)) + @engine.on(Events.COMPLETED) + def check_completed_time(): + assert ( + engine.state.times[Events.COMPLETED.name] >= (sleep_time * epoch_length + extra_sleep_time) * max_epochs + ) + time.sleep(extra_sleep_time) engine.run(data, max_epochs=max_epochs, epoch_length=epoch_length) + assert engine.state.times[Events.EPOCH_COMPLETED.name] >= sleep_time * epoch_length + extra_sleep_time + assert ( + engine.state.times[Events.COMPLETED.name] + >= (sleep_time * epoch_length + extra_sleep_time) * max_epochs + extra_sleep_time + ) + _test(list(range(100)), max_epochs=2, epoch_length=100) _test(list(range(200)), max_epochs=2, epoch_length=100) _test(list(range(200)), max_epochs=5, epoch_length=100)