1
1
import functools
2
2
import logging
3
- import math
4
3
import time
5
4
import warnings
6
5
import weakref
@@ -731,14 +730,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
731
730
732
731
@staticmethod
733
732
def _is_done (state : State ) -> bool :
734
- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
735
733
is_done_count = (
736
734
state .epoch_length is not None
737
735
and state .max_epochs is not None
738
736
and state .iteration >= state .epoch_length * state .max_epochs
739
737
)
740
738
is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
741
- return is_done_iters or is_done_count or is_done_epochs
739
+ return is_done_count or is_done_epochs
742
740
743
741
def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
744
742
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -780,14 +778,13 @@ def run(
780
778
self ,
781
779
data : Optional [Iterable ] = None ,
782
780
max_epochs : Optional [int ] = None ,
783
- max_iters : Optional [int ] = None ,
784
781
epoch_length : Optional [int ] = None ,
785
782
) -> State :
786
783
"""Runs the ``process_function`` over the passed data.
787
784
788
785
Engine has a state and the following logic is applied in this function:
789
786
790
- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, if provided.
787
+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, if provided.
791
788
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
792
789
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
793
790
provided, state is kept and used in the function.
@@ -805,9 +802,6 @@ def run(
805
802
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
806
803
determined as the iteration on which data iterator raises `StopIteration`.
807
804
This argument should not change if run is resuming from a state.
808
- max_iters: Number of iterations to run for.
809
- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
810
-
811
805
Returns:
812
806
State: output state.
813
807
@@ -858,6 +852,8 @@ def switch_batch(engine):
858
852
859
853
if self .state .max_epochs is None or (self ._is_done (self .state ) and self ._internal_run_generator is None ):
860
854
# Create new state
855
+ if max_epochs is None :
856
+ max_epochs = 1
861
857
if epoch_length is None :
862
858
if data is None :
863
859
raise ValueError ("epoch_length should be provided if data is None" )
@@ -866,22 +862,9 @@ def switch_batch(engine):
866
862
if epoch_length is not None and epoch_length < 1 :
867
863
raise ValueError ("Input data has zero size. Please provide non-empty data" )
868
864
869
- if max_iters is None :
870
- if max_epochs is None :
871
- max_epochs = 1
872
- else :
873
- if max_epochs is not None :
874
- raise ValueError (
875
- "Arguments max_iters and max_epochs are mutually exclusive."
876
- "Please provide only max_epochs or max_iters."
877
- )
878
- if epoch_length is not None :
879
- max_epochs = math .ceil (max_iters / epoch_length )
880
-
881
865
self .state .iteration = 0
882
866
self .state .epoch = 0
883
867
self .state .max_epochs = max_epochs
884
- self .state .max_iters = max_iters
885
868
self .state .epoch_length = epoch_length
886
869
# Reset generator if previously used
887
870
self ._internal_run_generator = None
@@ -1062,18 +1045,12 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
1062
1045
if self .state .epoch_length is None :
1063
1046
# Define epoch length and stop the epoch
1064
1047
self .state .epoch_length = iter_counter
1065
- if self .state .max_iters is not None :
1066
- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
1067
1048
break
1068
1049
1069
1050
# Should exit while loop if we can not iterate
1070
1051
if should_exit :
1071
- if not self ._is_done (self .state ):
1072
- total_iters = (
1073
- self .state .epoch_length * self .state .max_epochs
1074
- if self .state .max_epochs is not None
1075
- else self .state .max_iters
1076
- )
1052
+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1053
+ total_iters = self .state .epoch_length * self .state .max_epochs
1077
1054
1078
1055
warnings .warn (
1079
1056
"Data iterator can not provide data anymore but required total number of "
@@ -1104,10 +1081,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
1104
1081
if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
1105
1082
break
1106
1083
1107
- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1108
- self .should_terminate = True
1109
- raise _EngineTerminateException ()
1110
-
1111
1084
except _EngineTerminateSingleEpochException :
1112
1085
self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
1113
1086
self .should_terminate_single_epoch = False
@@ -1229,18 +1202,12 @@ def _run_once_on_dataset_legacy(self) -> float:
1229
1202
if self .state .epoch_length is None :
1230
1203
# Define epoch length and stop the epoch
1231
1204
self .state .epoch_length = iter_counter
1232
- if self .state .max_iters is not None :
1233
- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
1234
1205
break
1235
1206
1236
1207
# Should exit while loop if we can not iterate
1237
1208
if should_exit :
1238
- if not self ._is_done (self .state ):
1239
- total_iters = (
1240
- self .state .epoch_length * self .state .max_epochs
1241
- if self .state .max_epochs is not None
1242
- else self .state .max_iters
1243
- )
1209
+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1210
+ total_iters = self .state .epoch_length * self .state .max_epochs
1244
1211
1245
1212
warnings .warn (
1246
1213
"Data iterator can not provide data anymore but required total number of "
@@ -1271,10 +1238,6 @@ def _run_once_on_dataset_legacy(self) -> float:
1271
1238
if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
1272
1239
break
1273
1240
1274
- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1275
- self .should_terminate = True
1276
- raise _EngineTerminateException ()
1277
-
1278
1241
except _EngineTerminateSingleEpochException :
1279
1242
self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
1280
1243
self .should_terminate_single_epoch = False
0 commit comments