1
1
import functools
2
2
import logging
3
- import math
4
3
import time
5
4
import warnings
6
5
import weakref
@@ -680,7 +679,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
680
679
`seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
681
680
Iteration and epoch values are 0-based: the first iteration or epoch is zero.
682
681
683
- This method does not remove any custom attributes added by user.
682
+ This method does not remove any custom attributs added by user.
684
683
685
684
Args:
686
685
state_dict: a dict with parameters
@@ -725,14 +724,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
725
724
726
725
@staticmethod
727
726
def _is_done (state : State ) -> bool :
728
- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
729
- is_done_count = (
730
- state .epoch_length is not None
731
- and state .max_epochs is not None
732
- and state .iteration >= state .epoch_length * state .max_epochs
733
- )
734
- is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
735
- return is_done_iters or is_done_count or is_done_epochs
727
+ return state .iteration == state .epoch_length * state .max_epochs # type: ignore[operator]
736
728
737
729
def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
738
730
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -774,15 +766,14 @@ def run(
774
766
self ,
775
767
data : Optional [Iterable ] = None ,
776
768
max_epochs : Optional [int ] = None ,
777
- max_iters : Optional [int ] = None ,
778
769
epoch_length : Optional [int ] = None ,
779
770
seed : Optional [int ] = None ,
780
771
) -> State :
781
772
"""Runs the ``process_function`` over the passed data.
782
773
783
774
Engine has a state and the following logic is applied in this function:
784
775
785
- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, `seed`, if provided.
776
+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed`, if provided.
786
777
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
787
778
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
788
779
provided, state is kept and used in the function.
@@ -800,8 +791,6 @@ def run(
800
791
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
801
792
determined as the iteration on which data iterator raises `StopIteration`.
802
793
This argument should not change if run is resuming from a state.
803
- max_iters: Number of iterations to run for.
804
- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
805
794
seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`.
806
795
807
796
Returns:
@@ -860,6 +849,8 @@ def switch_batch(engine):
860
849
861
850
if self .state .max_epochs is None or (self ._is_done (self .state ) and self ._internal_run_generator is None ):
862
851
# Create new state
852
+ if max_epochs is None :
853
+ max_epochs = 1
863
854
if epoch_length is None :
864
855
if data is None :
865
856
raise ValueError ("epoch_length should be provided if data is None" )
@@ -868,22 +859,9 @@ def switch_batch(engine):
868
859
if epoch_length is not None and epoch_length < 1 :
869
860
raise ValueError ("Input data has zero size. Please provide non-empty data" )
870
861
871
- if max_iters is None :
872
- if max_epochs is None :
873
- max_epochs = 1
874
- else :
875
- if max_epochs is not None :
876
- raise ValueError (
877
- "Arguments max_iters and max_epochs are mutually exclusive."
878
- "Please provide only max_epochs or max_iters."
879
- )
880
- if epoch_length is not None :
881
- max_epochs = math .ceil (max_iters / epoch_length )
882
-
883
862
self .state .iteration = 0
884
863
self .state .epoch = 0
885
864
self .state .max_epochs = max_epochs
886
- self .state .max_iters = max_iters
887
865
self .state .epoch_length = epoch_length
888
866
# Reset generator if previously used
889
867
self ._internal_run_generator = None
@@ -978,6 +956,7 @@ def _internal_run_as_gen(self) -> Generator:
978
956
self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
979
957
980
958
handlers_start_time = time .time ()
959
+
981
960
self ._fire_event (Events .EPOCH_COMPLETED )
982
961
epoch_time_taken += time .time () - handlers_start_time
983
962
# update time wrt handlers
@@ -1056,8 +1035,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
1056
1035
if self .state .epoch_length is None :
1057
1036
# Define epoch length and stop the epoch
1058
1037
self .state .epoch_length = iter_counter
1059
- if self .state .max_iters is not None :
1060
- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
1061
1038
break
1062
1039
1063
1040
# Should exit while loop if we can not iterate
0 commit comments