1
1
import functools
2
2
import logging
3
- import math
4
3
import time
5
4
import warnings
6
5
import weakref
@@ -508,7 +507,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
508
507
`seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
509
508
Iteration and epoch values are 0-based: the first iteration or epoch is zero.
510
509
511
- This method does not remove any custom attributes added by user.
510
+ This method does not remove any custom attributs added by user.
512
511
513
512
Args:
514
513
state_dict: a dict with parameters
@@ -553,14 +552,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
553
552
554
553
@staticmethod
555
554
def _is_done (state : State ) -> bool :
556
- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
557
- is_done_count = (
558
- state .epoch_length is not None
559
- and state .max_epochs is not None
560
- and state .iteration >= state .epoch_length * state .max_epochs
561
- )
562
- is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
563
- return is_done_iters or is_done_count or is_done_epochs
555
+ return state .iteration == state .epoch_length * state .max_epochs # type: ignore[operator]
564
556
565
557
def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
566
558
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -602,15 +594,14 @@ def run(
602
594
self ,
603
595
data : Optional [Iterable ] = None ,
604
596
max_epochs : Optional [int ] = None ,
605
- max_iters : Optional [int ] = None ,
606
597
epoch_length : Optional [int ] = None ,
607
598
seed : Optional [int ] = None ,
608
599
) -> State :
609
600
"""Runs the ``process_function`` over the passed data.
610
601
611
602
Engine has a state and the following logic is applied in this function:
612
603
613
- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, `seed`, if provided.
604
+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed`, if provided.
614
605
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
615
606
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
616
607
provided, state is kept and used in the function.
@@ -628,8 +619,6 @@ def run(
628
619
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
629
620
determined as the iteration on which data iterator raises `StopIteration`.
630
621
This argument should not change if run is resuming from a state.
631
- max_iters: Number of iterations to run for.
632
- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
633
622
seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`.
634
623
635
624
Returns:
@@ -688,6 +677,8 @@ def switch_batch(engine):
688
677
689
678
if self .state .max_epochs is None or self ._is_done (self .state ):
690
679
# Create new state
680
+ if max_epochs is None :
681
+ max_epochs = 1
691
682
if epoch_length is None :
692
683
if data is None :
693
684
raise ValueError ("epoch_length should be provided if data is None" )
@@ -696,22 +687,9 @@ def switch_batch(engine):
696
687
if epoch_length is not None and epoch_length < 1 :
697
688
raise ValueError ("Input data has zero size. Please provide non-empty data" )
698
689
699
- if max_iters is None :
700
- if max_epochs is None :
701
- max_epochs = 1
702
- else :
703
- if max_epochs is not None :
704
- raise ValueError (
705
- "Arguments max_iters and max_epochs are mutually exclusive."
706
- "Please provide only max_epochs or max_iters."
707
- )
708
- if epoch_length is not None :
709
- max_epochs = math .ceil (max_iters / epoch_length )
710
-
711
690
self .state .iteration = 0
712
691
self .state .epoch = 0
713
692
self .state .max_epochs = max_epochs
714
- self .state .max_iters = max_iters
715
693
self .state .epoch_length = epoch_length
716
694
self .logger .info (f"Engine run starting with max_epochs={ max_epochs } ." )
717
695
else :
@@ -765,7 +743,7 @@ def _internal_run(self) -> State:
765
743
try :
766
744
start_time = time .time ()
767
745
self ._fire_event (Events .STARTED )
768
- while not self ._is_done ( self .state ) and not self .should_terminate :
746
+ while self .state . epoch < self .state . max_epochs and not self .should_terminate : # type: ignore[operator]
769
747
self .state .epoch += 1
770
748
self ._fire_event (Events .EPOCH_STARTED )
771
749
@@ -835,8 +813,6 @@ def _run_once_on_dataset(self) -> float:
835
813
if self .state .epoch_length is None :
836
814
# Define epoch length and stop the epoch
837
815
self .state .epoch_length = iter_counter
838
- if self .state .max_iters is not None :
839
- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
840
816
break
841
817
842
818
# Should exit while loop if we can not iterate
@@ -876,10 +852,6 @@ def _run_once_on_dataset(self) -> float:
876
852
if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
877
853
break
878
854
879
- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
880
- self .should_terminate = True
881
- break
882
-
883
855
except Exception as e :
884
856
self .logger .error (f"Current run is terminating due to exception: { e } " )
885
857
self ._handle_exception (e )
0 commit comments