
Commit 71b72c6

Revert "Adding max_iters as an optional arg in Engine run (#1381)"
This reverts commit 307ac11.
1 parent 2ed480c commit 71b72c6

3 files changed: +6 -139 lines changed

ignite/engine/engine.py

Lines changed: 6 additions & 34 deletions
@@ -1,6 +1,5 @@
 import functools
 import logging
-import math
 import time
 import warnings
 import weakref
@@ -508,7 +507,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
         `seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
         Iteration and epoch values are 0-based: the first iteration or epoch is zero.

-        This method does not remove any custom attributes added by user.
+        This method does not remove any custom attributs added by user.

         Args:
             state_dict: a dict with parameters
@@ -553,14 +552,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:

     @staticmethod
     def _is_done(state: State) -> bool:
-        is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
-        is_done_count = (
-            state.epoch_length is not None
-            and state.max_epochs is not None
-            and state.iteration >= state.epoch_length * state.max_epochs
-        )
-        is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
-        return is_done_iters or is_done_count or is_done_epochs
+        return state.iteration == state.epoch_length * state.max_epochs  # type: ignore[operator]

     def set_data(self, data: Union[Iterable, DataLoader]) -> None:
         """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -602,15 +594,14 @@ def run(
         self,
         data: Optional[Iterable] = None,
         max_epochs: Optional[int] = None,
-        max_iters: Optional[int] = None,
         epoch_length: Optional[int] = None,
         seed: Optional[int] = None,
     ) -> State:
         """Runs the ``process_function`` over the passed data.

         Engine has a state and the following logic is applied in this function:

-        - At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, `seed`, if provided.
+        - At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed`, if provided.
           A timer for total and per-epoch time is initialized when Events.STARTED is handled.
         - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
           provided, state is kept and used in the function.
@@ -628,8 +619,6 @@ def run(
                 `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
                 determined as the iteration on which data iterator raises `StopIteration`.
                 This argument should not change if run is resuming from a state.
-            max_iters: Number of iterations to run for.
-                `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
             seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`.

         Returns:
@@ -688,6 +677,8 @@ def switch_batch(engine):

         if self.state.max_epochs is None or self._is_done(self.state):
             # Create new state
+            if max_epochs is None:
+                max_epochs = 1
             if epoch_length is None:
                 if data is None:
                     raise ValueError("epoch_length should be provided if data is None")
@@ -696,22 +687,9 @@ def switch_batch(engine):
             if epoch_length is not None and epoch_length < 1:
                 raise ValueError("Input data has zero size. Please provide non-empty data")

-            if max_iters is None:
-                if max_epochs is None:
-                    max_epochs = 1
-            else:
-                if max_epochs is not None:
-                    raise ValueError(
-                        "Arguments max_iters and max_epochs are mutually exclusive."
-                        "Please provide only max_epochs or max_iters."
-                    )
-                if epoch_length is not None:
-                    max_epochs = math.ceil(max_iters / epoch_length)
-
             self.state.iteration = 0
             self.state.epoch = 0
             self.state.max_epochs = max_epochs
-            self.state.max_iters = max_iters
             self.state.epoch_length = epoch_length
             self.logger.info(f"Engine run starting with max_epochs={max_epochs}.")
         else:
@@ -765,7 +743,7 @@ def _internal_run(self) -> State:
         try:
             start_time = time.time()
             self._fire_event(Events.STARTED)
-            while not self._is_done(self.state) and not self.should_terminate:
+            while self.state.epoch < self.state.max_epochs and not self.should_terminate:  # type: ignore[operator]
                 self.state.epoch += 1
                 self._fire_event(Events.EPOCH_STARTED)

@@ -835,8 +813,6 @@ def _run_once_on_dataset(self) -> float:
                     if self.state.epoch_length is None:
                         # Define epoch length and stop the epoch
                         self.state.epoch_length = iter_counter
-                        if self.state.max_iters is not None:
-                            self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
                         break

                     # Should exit while loop if we can not iterate
@@ -876,10 +852,6 @@ def _run_once_on_dataset(self) -> float:
                 if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
                     break

-                if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
-                    self.should_terminate = True
-                    break
-
         except Exception as e:
             self.logger.error(f"Current run is terminating due to exception: {e}")
             self._handle_exception(e)
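
Migration note (not part of this commit): after the revert, Engine.run accepts only max_epochs, so a caller that previously bounded a run with max_iters can enforce an iteration budget from a handler instead. The sketch below uses the public ignite API (Engine, Events, engine.terminate()); the names step and iteration_budget are illustrative and do not come from the repository.

import math

from ignite.engine import Engine, Events

def step(engine, batch):
    # Hypothetical no-op process_function standing in for a real training step.
    return batch

engine = Engine(step)

iteration_budget = 8  # illustrative stand-in for the removed max_iters argument
data = [0] * 20       # epoch_length defaults to len(data) = 20

@engine.on(Events.ITERATION_COMPLETED)
def stop_on_budget(engine):
    # Terminate the run once the global iteration counter reaches the budget.
    if engine.state.iteration >= iteration_budget:
        engine.terminate()

# Size max_epochs so the budget fits inside the run, mirroring the
# math.ceil(max_iters / epoch_length) conversion that this commit removes.
engine.run(data, max_epochs=math.ceil(iteration_budget / len(data)))
assert engine.state.iteration == iteration_budget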

ignite/engine/events.py

Lines changed: 0 additions & 2 deletions
@@ -375,7 +375,6 @@ class State:
         state.dataloader        # data passed to engine
         state.epoch_length      # optional length of an epoch
         state.max_epochs        # number of epochs to run
-        state.max_iters         # number of iterations to run
         state.batch             # batch passed to `process_function`
         state.output            # output of `process_function` after a single iteration
         state.metrics           # dictionary with defined metrics if any
@@ -402,7 +401,6 @@ def __init__(self, **kwargs: Any) -> None:
         self.epoch = 0
         self.epoch_length = None  # type: Optional[int]
         self.max_epochs = None  # type: Optional[int]
-        self.max_iters = None  # type: Optional[int]
         self.output = None  # type: Optional[int]
         self.batch = None  # type: Optional[int]
         self.metrics = {}  # type: Dict[str, Any]
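
With state.max_iters gone, the length of a run is fully determined by the remaining fields, and Engine._is_done reduces to the iteration == epoch_length * max_epochs check shown in the engine.py hunk above. A minimal sketch of the bookkeeping handlers can rely on after this revert (values are illustrative):

from ignite.engine import Engine

engine = Engine(lambda e, b: None)
engine.run([0] * 5, max_epochs=3)  # epoch_length defaults to len(data) = 5

state = engine.state
# A completed run satisfies the post-revert termination criterion exactly.
assert state.iteration == state.epoch_length * state.max_epochs  # 5 * 3 == 15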

tests/ignite/engine/test_engine.py

Lines changed: 0 additions & 103 deletions
@@ -894,109 +894,6 @@ def switch_dataloader():
     trainer.run(data1, max_epochs=10)


-def test_run_with_max_iters():
-    max_iters = 8
-    engine = Engine(lambda e, b: 1)
-    engine.run([0] * 20, max_iters=max_iters)
-    assert engine.state.iteration == max_iters
-    assert engine.state.max_iters == max_iters
-
-
-def test_run_with_max_iters_greater_than_epoch_length():
-    max_iters = 73
-    engine = Engine(lambda e, b: 1)
-    engine.run([0] * 20, max_iters=max_iters)
-    assert engine.state.iteration == max_iters
-
-
-def test_run_with_invalid_max_iters_and_max_epoch():
-    max_iters = 12
-    max_epochs = 2
-    engine = Engine(lambda e, b: 1)
-    with pytest.raises(
-        ValueError,
-        match=r"Arguments max_iters and max_epochs are mutually exclusive."
-        "Please provide only max_epochs or max_iters.",
-    ):
-        engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs)
-
-
-def test_epoch_events_fired():
-    max_iters = 32
-    engine = Engine(lambda e, b: 1)
-
-    @engine.on(Events.EPOCH_COMPLETED)
-    def fired_event(engine):
-        assert engine.state.iteration % engine.state.epoch_length == 0
-
-    engine.run([0] * 10, max_iters=max_iters)
-
-
-def test_is_done_with_max_iters():
-    state = State(iteration=100, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
-    assert not Engine._is_done(state)
-
-    state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
-    assert Engine._is_done(state)
-
-
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_batch_is_released_before_new_one_is_loaded_on_cuda():
-    torch.cuda.empty_cache()
-
-    engine = Engine(lambda e, b: None)
-
-    def _test():
-        mem_consumption = []
-
-        def dataloader():
-            for _ in range(4):
-                mem_consumption.append(torch.cuda.memory_allocated())
-                batch = torch.randn(10).cuda()
-                mem_consumption.append(torch.cuda.memory_allocated())
-                yield batch
-
-        engine.run(dataloader(), max_epochs=2, epoch_length=2)
-        return mem_consumption
-
-    mem_consumption1 = _test()
-    # mem_consumption should look like [0, 512, 512, 512, 512, 512, 512, 512]
-    assert len(set(mem_consumption1[1:])) == 1
-
-    mem_consumption2 = _test()
-    assert len(set(mem_consumption2[1:])) == 1
-
-    assert mem_consumption1 == mem_consumption2
-
-
-@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test_output_is_released_before_new_one_is_assigned_on_cuda():
-    torch.cuda.empty_cache()
-
-    def _test():
-        mem_consumption = []
-
-        def update_fn(engine, batch):
-            mem_consumption.append(torch.cuda.memory_allocated())
-            output = torch.rand(10).cuda()
-            mem_consumption.append(torch.cuda.memory_allocated())
-            return output
-
-        engine = Engine(update_fn)
-        engine.run([0, 1], max_epochs=2)
-
-        return mem_consumption
-
-    mem_consumption1 = _test()
-    # mem_consumption ~ [0, 512, 0, 512, 0, 512, 0, 512]
-    assert len(set(mem_consumption1)) == 2
-
-    mem_consumption2 = _test()
-    assert len(set(mem_consumption2)) == 2
-
-    assert mem_consumption1 == mem_consumption2
-
-
 def test_engine_no_data_asserts():
     trainer = Engine(lambda e, b: None)
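
The removed test_epoch_events_fired exercised epoch-boundary alignment under max_iters; the property it checked still holds, and can still be asserted, in a max_epochs-only run. A small sketch (not part of this commit; the handler name is illustrative):

from ignite.engine import Engine, Events

engine = Engine(lambda e, b: 1)

@engine.on(Events.EPOCH_COMPLETED)
def check_alignment(engine):
    # Without max_iters, every epoch boundary falls on a multiple of epoch_length.
    assert engine.state.iteration % engine.state.epoch_length == 0

engine.run([0] * 10, max_epochs=3)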
