Skip to content

Commit 2f0428b

Browse files
authored
Merge branch 'master' into device-index-warning
2 parents 90a4d02 + a6664de commit 2f0428b

File tree

3 files changed

+34
-1
lines changed

3 files changed

+34
-1
lines changed

ignite/engine/engine.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,18 @@ def run(
639639
def switch_batch(engine):
640640
engine.state.batch = preprocess_batch(engine.state.batch)
641641
642+
Restart the training from the beginning. User can reset `max_epochs = None` or either call
643+
`trainer.state.restart()`:
644+
645+
.. code-block:: python
646+
647+
# ...
648+
trainer.run(train_loader, max_epochs=5)
649+
650+
# Reset model weights etc. and restart the training
651+
trainer.state.restart() # equivalent to trainer.state.max_epochs = None
652+
trainer.run(train_loader, max_epochs=2)
653+
642654
"""
643655
if seed is not None:
644656
warnings.warn(
@@ -655,7 +667,10 @@ def switch_batch(engine):
655667
if max_epochs < self.state.epoch:
656668
raise ValueError(
657669
"Argument max_epochs should be larger than the start epoch "
658-
"defined in the state: {} vs {}".format(max_epochs, self.state.epoch)
670+
"defined in the state: {} vs {}. Please, call state.restart() "
671+
"before calling engine.run() in order to restart the training from the beginning.".format(
672+
max_epochs, self.state.epoch
673+
)
659674
)
660675
self.state.max_epochs = max_epochs
661676
if epoch_length is not None:

ignite/engine/events.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,9 @@ def get_event_attrib_value(self, event_name: Union[CallableEventWithFilter, Enum
341341
raise RuntimeError("Unknown event name '{}'".format(event_name))
342342
return getattr(self, State.event_to_attr[event_name])
343343

344+
def restart(self) -> None:
345+
self.max_epochs = None
346+
344347
def __repr__(self) -> str:
345348
s = "State:\n"
346349
for attr, value in self.__dict__.items():

tests/ignite/engine/test_engine_state_dict.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,3 +267,18 @@ def check_custom_attr():
267267

268268
_test()
269269
_test(with_load_state_dict=True)
270+
271+
272+
def test_restart_training():
273+
data = range(10)
274+
engine = Engine(lambda e, b: 1)
275+
state = engine.run(data, max_epochs=5)
276+
with pytest.raises(
277+
ValueError,
278+
match=r"Argument max_epochs should be larger than the start epoch defined in the state: 2 vs 5. "
279+
r"Please, call state.restart\(\) "
280+
r"before calling engine.run\(\) in order to restart the training from the beginning.",
281+
):
282+
state = engine.run(data, max_epochs=2)
283+
state.restart()
284+
engine.run(data, max_epochs=2)

0 commit comments

Comments
 (0)