Skip to content

Commit c7348f7

Browse files
committed
Revert "Adding max_iters as an optional arg in Engine run (#1381)"
This reverts commit 307ac11.
1 parent 7b3f1d2 commit c7348f7

File tree

2 files changed

+6
-31
lines changed

2 files changed

+6
-31
lines changed

ignite/engine/engine.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22
import logging
3-
import math
43
import time
54
import warnings
65
import weakref
@@ -680,7 +679,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
680679
`seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
681680
Iteration and epoch values are 0-based: the first iteration or epoch is zero.
682681
683-
This method does not remove any custom attributes added by user.
682+
This method does not remove any custom attributs added by user.
684683
685684
Args:
686685
state_dict: a dict with parameters
@@ -725,14 +724,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
725724

726725
@staticmethod
727726
def _is_done(state: State) -> bool:
728-
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
729-
is_done_count = (
730-
state.epoch_length is not None
731-
and state.max_epochs is not None
732-
and state.iteration >= state.epoch_length * state.max_epochs
733-
)
734-
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
735-
return is_done_iters or is_done_count or is_done_epochs
727+
return state.iteration == state.epoch_length * state.max_epochs # type: ignore[operator]
736728

737729
def set_data(self, data: Union[Iterable, DataLoader]) -> None:
738730
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -774,15 +766,14 @@ def run(
774766
self,
775767
data: Optional[Iterable] = None,
776768
max_epochs: Optional[int] = None,
777-
max_iters: Optional[int] = None,
778769
epoch_length: Optional[int] = None,
779770
seed: Optional[int] = None,
780771
) -> State:
781772
"""Runs the ``process_function`` over the passed data.
782773
783774
Engine has a state and the following logic is applied in this function:
784775
785-
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, `seed`, if provided.
776+
- At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed`, if provided.
786777
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
787778
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
788779
provided, state is kept and used in the function.
@@ -800,8 +791,6 @@ def run(
800791
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
801792
determined as the iteration on which data iterator raises `StopIteration`.
802793
This argument should not change if run is resuming from a state.
803-
max_iters: Number of iterations to run for.
804-
`max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
805794
seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`.
806795
807796
Returns:
@@ -860,6 +849,8 @@ def switch_batch(engine):
860849

861850
if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None):
862851
# Create new state
852+
if max_epochs is None:
853+
max_epochs = 1
863854
if epoch_length is None:
864855
if data is None:
865856
raise ValueError("epoch_length should be provided if data is None")
@@ -868,22 +859,9 @@ def switch_batch(engine):
868859
if epoch_length is not None and epoch_length < 1:
869860
raise ValueError("Input data has zero size. Please provide non-empty data")
870861

871-
if max_iters is None:
872-
if max_epochs is None:
873-
max_epochs = 1
874-
else:
875-
if max_epochs is not None:
876-
raise ValueError(
877-
"Arguments max_iters and max_epochs are mutually exclusive."
878-
"Please provide only max_epochs or max_iters."
879-
)
880-
if epoch_length is not None:
881-
max_epochs = math.ceil(max_iters / epoch_length)
882-
883862
self.state.iteration = 0
884863
self.state.epoch = 0
885864
self.state.max_epochs = max_epochs
886-
self.state.max_iters = max_iters
887865
self.state.epoch_length = epoch_length
888866
# Reset generator if previously used
889867
self._internal_run_generator = None
@@ -978,6 +956,7 @@ def _internal_run_as_gen(self) -> Generator:
978956
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
979957

980958
handlers_start_time = time.time()
959+
981960
self._fire_event(Events.EPOCH_COMPLETED)
982961
epoch_time_taken += time.time() - handlers_start_time
983962
# update time wrt handlers
@@ -1056,8 +1035,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10561035
if self.state.epoch_length is None:
10571036
# Define epoch length and stop the epoch
10581037
self.state.epoch_length = iter_counter
1059-
if self.state.max_iters is not None:
1060-
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
10611038
break
10621039

10631040
# Should exit while loop if we can not iterate

ignite/engine/events.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,6 @@ class State:
451451
state.dataloader # data passed to engine
452452
state.epoch_length # optional length of an epoch
453453
state.max_epochs # number of epochs to run
454-
state.max_iters # number of iterations to run
455454
state.batch # batch passed to `process_function`
456455
state.output # output of `process_function` after a single iteration
457456
state.metrics # dictionary with defined metrics if any
@@ -478,7 +477,6 @@ def __init__(self, **kwargs: Any) -> None:
478477
self.epoch = 0
479478
self.epoch_length: Optional[int] = None
480479
self.max_epochs: Optional[int] = None
481-
self.max_iters: Optional[int] = None
482480
self.output: Optional[int] = None
483481
self.batch: Optional[int] = None
484482
self.metrics: Dict[str, Any] = {}

0 commit comments

Comments
 (0)