
Commit 2d11cd7

Revert "Adding max_iters as an optional arg in Engine run (#1381)"
1 parent 3e67049 commit 2d11cd7
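
With this revert, `Engine.run` no longer accepts `max_iters`; an iteration budget has to be expressed through `max_epochs`/`epoch_length` or an event handler. Below is a minimal sketch of the handler approach, assuming ignite's standard `Events.ITERATION_COMPLETED(once=...)` event filter and `Engine.terminate()`; the no-op `train_step` and the budget of 100 iterations are only illustrative.

from ignite.engine import Engine, Events

def train_step(engine, batch):
    # Placeholder processing function; a real step would run the model and optimizer.
    return 0.0

engine = Engine(train_step)

# Stop after exactly 100 iterations: the filtered event fires once, then terminate().
@engine.on(Events.ITERATION_COMPLETED(once=100))
def stop_training(engine):
    engine.terminate()

# max_epochs is set high enough that the handler, not the epoch count, ends the run.
engine.run([0] * 20, max_epochs=10)
assert engine.state.iteration == 100

An equivalent budget can often be written without a handler at all, e.g. `engine.run(data, max_epochs=1, epoch_length=100)`, provided the data iterator can supply 100 batches per epoch.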

File tree

ignite/engine/engine.py
ignite/engine/events.py
ignite/handlers/lr_finder.py
tests/ignite/engine/test_engine.py
tests/ignite/handlers/test_lr_finder.py

5 files changed: +9 -90 lines

ignite/engine/engine.py

Lines changed: 8 additions & 45 deletions
@@ -1,6 +1,5 @@
 import functools
 import logging
-import math
 import time
 import warnings
 import weakref
@@ -731,14 +730,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
 
     @staticmethod
     def _is_done(state: State) -> bool:
-        is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
         is_done_count = (
             state.epoch_length is not None
             and state.max_epochs is not None
             and state.iteration >= state.epoch_length * state.max_epochs
         )
         is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
-        return is_done_iters or is_done_count or is_done_epochs
+        return is_done_count or is_done_epochs
 
     def set_data(self, data: Union[Iterable, DataLoader]) -> None:
         """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -780,14 +778,13 @@ def run(
         self,
         data: Optional[Iterable] = None,
         max_epochs: Optional[int] = None,
-        max_iters: Optional[int] = None,
         epoch_length: Optional[int] = None,
     ) -> State:
         """Runs the ``process_function`` over the passed data.
 
         Engine has a state and the following logic is applied in this function:
 
-        - At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, if provided.
+        - At the first call, new state is defined by `max_epochs`, `epoch_length`, if provided.
          A timer for total and per-epoch time is initialized when Events.STARTED is handled.
        - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
          provided, state is kept and used in the function.
@@ -805,9 +802,6 @@ def run(
                `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
                determined as the iteration on which data iterator raises `StopIteration`.
                This argument should not change if run is resuming from a state.
-            max_iters: Number of iterations to run for.
-                `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
-
        Returns:
            State: output state.
 
@@ -858,6 +852,8 @@ def switch_batch(engine):
 
         if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None):
             # Create new state
+            if max_epochs is None:
+                max_epochs = 1
             if epoch_length is None:
                 if data is None:
                     raise ValueError("epoch_length should be provided if data is None")
@@ -866,22 +862,9 @@ def switch_batch(engine):
             if epoch_length is not None and epoch_length < 1:
                 raise ValueError("Input data has zero size. Please provide non-empty data")
 
-            if max_iters is None:
-                if max_epochs is None:
-                    max_epochs = 1
-            else:
-                if max_epochs is not None:
-                    raise ValueError(
-                        "Arguments max_iters and max_epochs are mutually exclusive."
-                        "Please provide only max_epochs or max_iters."
-                    )
-                if epoch_length is not None:
-                    max_epochs = math.ceil(max_iters / epoch_length)
-
             self.state.iteration = 0
             self.state.epoch = 0
             self.state.max_epochs = max_epochs
-            self.state.max_iters = max_iters
             self.state.epoch_length = epoch_length
             # Reset generator if previously used
             self._internal_run_generator = None
@@ -1062,18 +1045,12 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
                     if self.state.epoch_length is None:
                         # Define epoch length and stop the epoch
                         self.state.epoch_length = iter_counter
-                        if self.state.max_iters is not None:
-                            self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
                         break
 
                     # Should exit while loop if we can not iterate
                     if should_exit:
-                        if not self._is_done(self.state):
-                            total_iters = (
-                                self.state.epoch_length * self.state.max_epochs
-                                if self.state.max_epochs is not None
-                                else self.state.max_iters
-                            )
+                        if not self._is_done(self.state) and self.state.max_epochs is not None:
+                            total_iters = self.state.epoch_length * self.state.max_epochs
 
                             warnings.warn(
                                 "Data iterator can not provide data anymore but required total number of "
@@ -1104,10 +1081,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
                 if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
                     break
 
-                if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
-                    self.should_terminate = True
-                    raise _EngineTerminateException()
-
         except _EngineTerminateSingleEpochException:
             self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
             self.should_terminate_single_epoch = False
@@ -1229,18 +1202,12 @@ def _run_once_on_dataset_legacy(self) -> float:
                     if self.state.epoch_length is None:
                         # Define epoch length and stop the epoch
                         self.state.epoch_length = iter_counter
-                        if self.state.max_iters is not None:
-                            self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
                         break
 
                     # Should exit while loop if we can not iterate
                     if should_exit:
-                        if not self._is_done(self.state):
-                            total_iters = (
-                                self.state.epoch_length * self.state.max_epochs
-                                if self.state.max_epochs is not None
-                                else self.state.max_iters
-                            )
+                        if not self._is_done(self.state) and self.state.max_epochs is not None:
+                            total_iters = self.state.epoch_length * self.state.max_epochs
 
                             warnings.warn(
                                 "Data iterator can not provide data anymore but required total number of "
@@ -1271,10 +1238,6 @@ def _run_once_on_dataset_legacy(self) -> float:
                 if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
                     break
 
-                if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
-                    self.should_terminate = True
-                    raise _EngineTerminateException()
-
         except _EngineTerminateSingleEpochException:
             self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
             self.should_terminate_single_epoch = False
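
With `max_iters` removed, `_is_done` reduces to the epoch-based checks shown above. A small illustration, adapted from the deleted `test_is_done_with_max_iters` test with the `max_iters` field dropped; the concrete numbers are only for demonstration.

from ignite.engine import Engine, State

# Not done: 250 iterations is still short of epoch_length * max_epochs = 300,
# and epoch 1 < max_epochs 3.
state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100)
assert not Engine._is_done(state)

# Done: the iteration count has reached epoch_length * max_epochs.
state = State(iteration=300, epoch=3, max_epochs=3, epoch_length=100)
assert Engine._is_done(state)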

ignite/engine/events.py

Lines changed: 0 additions & 2 deletions
@@ -443,7 +443,6 @@ class State:
         state.dataloader        # data passed to engine
         state.epoch_length      # optional length of an epoch
         state.max_epochs        # number of epochs to run
-        state.max_iters         # number of iterations to run
         state.batch             # batch passed to `process_function`
         state.output            # output of `process_function` after a single iteration
         state.metrics           # dictionary with defined metrics if any
@@ -470,7 +469,6 @@ def __init__(self, **kwargs: Any) -> None:
         self.epoch = 0
         self.epoch_length: Optional[int] = None
         self.max_epochs: Optional[int] = None
-        self.max_iters: Optional[int] = None
         self.output: Optional[int] = None
         self.batch: Optional[int] = None
         self.metrics: Dict[str, Any] = {}
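
Because `State.__init__` no longer defines `max_iters`, code that previously read `engine.state.max_iters` should switch to the epoch-based fields. A quick check of what the state exposes after this revert, assuming `State` only carries the attributes assigned in `__init__` (plus any user-passed kwargs):

from ignite.engine import Engine

engine = Engine(lambda e, b: None)
engine.run([0] * 4, max_epochs=2)

assert engine.state.max_epochs == 2
assert engine.state.epoch_length == 4
assert not hasattr(engine.state, "max_iters")  # attribute removed by this revert
# The total iteration budget is now derived from the remaining fields:
total_iters = engine.state.epoch_length * engine.state.max_epochs  # 8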

ignite/handlers/lr_finder.py

Lines changed: 0 additions & 1 deletion
@@ -105,7 +105,6 @@ def _run(
             max_iter = trainer.state.epoch_length * trainer.state.max_epochs  # type: ignore[operator]
             if max_iter < num_iter:
                 max_iter = num_iter
-                trainer.state.max_iters = num_iter
                 trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)  # type: ignore[operator]
 
         if not trainer.has_event_handler(self._reached_num_iterations):
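
Without `trainer.state.max_iters`, the finder can only round `max_epochs` up to cover `num_iter`, so the trainer is scheduled for whole epochs and the run is cut short by the `_reached_num_iterations` handler rather than stopping exactly at `num_iter`. A rough worked example of the arithmetic (the epoch length of 60 is a hypothetical dataloader length; `num_iter=150` matches the test below):

from math import ceil

epoch_length = 60   # hypothetical number of batches per epoch
num_iter = 150      # requested number of LR-finder iterations
max_epochs = 1

max_iter = epoch_length * max_epochs          # 60, not enough for 150 iterations
if max_iter < num_iter:
    max_epochs = ceil(num_iter / epoch_length)   # ceil(150 / 60) = 3

# The trainer is now allowed up to 3 * 60 = 180 iterations; the handler that
# watches the iteration count terminates the run just past num_iter, which is
# what the updated assertion below (iteration == 150 + 1) reflects.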

tests/ignite/engine/test_engine.py

Lines changed: 0 additions & 41 deletions
@@ -1026,47 +1026,6 @@ def switch_dataloader():
 
         trainer.run(data1, max_epochs=10)
 
-    def test_run_with_max_iters(self):
-        max_iters = 8
-        engine = Engine(lambda e, b: 1)
-        engine.run([0] * 20, max_iters=max_iters)
-        assert engine.state.iteration == max_iters
-        assert engine.state.max_iters == max_iters
-
-    def test_run_with_max_iters_greater_than_epoch_length(self):
-        max_iters = 73
-        engine = Engine(lambda e, b: 1)
-        engine.run([0] * 20, max_iters=max_iters)
-        assert engine.state.iteration == max_iters
-
-    def test_run_with_invalid_max_iters_and_max_epoch(self):
-        max_iters = 12
-        max_epochs = 2
-        engine = Engine(lambda e, b: 1)
-        with pytest.raises(
-            ValueError,
-            match=r"Arguments max_iters and max_epochs are mutually exclusive."
-            "Please provide only max_epochs or max_iters.",
-        ):
-            engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs)
-
-    def test_epoch_events_fired_max_iters(self):
-        max_iters = 32
-        engine = Engine(lambda e, b: 1)
-
-        @engine.on(Events.EPOCH_COMPLETED)
-        def fired_event(engine):
-            assert engine.state.iteration % engine.state.epoch_length == 0
-
-        engine.run([0] * 10, max_iters=max_iters)
-
-    def test_is_done_with_max_iters(self):
-        state = State(iteration=100, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
-        assert not Engine._is_done(state)
-
-        state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
-        assert Engine._is_done(state)
-
     @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
     def test_batch_is_released_before_new_one_is_loaded_on_cuda(self):
         torch.cuda.empty_cache()

tests/ignite/handlers/test_lr_finder.py

Lines changed: 1 addition & 1 deletion
@@ -357,7 +357,7 @@ def test_num_iter_is_not_enough(lr_finder, to_save, dummy_engine, dataloader):
         trainer_with_finder.run(dataloader)
     assert_output_sizes(lr_finder, dummy_engine)
     assert dummy_engine.state.iteration != len(dataloader)
-    assert dummy_engine.state.iteration == 150
+    assert dummy_engine.state.iteration == 150 + 1
 
 
 def test_detach_terminates(lr_finder, to_save, dummy_engine, dataloader):

0 commit comments
