Commit 96689c5

Activate MyPy in ignite.contrib.engines (#1416)
* Activate mypy in ignite.contrib.engines
* Fix review comments
* fix extra event too
* Update to fix strict errors
1 parent ea086e1 commit 96689c5

4 files changed: +78 additions, -55 deletions

ignite/contrib/engines/common.py

Lines changed: 71 additions & 47 deletions
@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from functools import partial
-from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast

 import torch
 import torch.nn as nn
@@ -47,7 +47,7 @@ def setup_common_training_handlers(
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
     **kwargs: Any
-):
+) -> None:
     """Helper method to setup trainer with common handlers (it also supports distributed configuration):

         - :class:`~ignite.handlers.TerminateOnNan`
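
Most hunks in this file follow the pattern above: a bare `):` becomes `) -> None:` (or a concrete return type). A short sketch of why this matters, with hypothetical names: under mypy's default settings a completely unannotated function body is skipped, so even a plain `-> None` annotation is what opts the function into checking.

```python
from typing import List


def unchecked(xs):
    # No annotations anywhere: mypy skips this body by default.
    return xs + 1


def checked(xs: List[int]) -> None:
    # Annotated: mypy now checks the body and reports both the bad operand
    # types and the value returned from a function declared to return None.
    return xs + 1  # intentional type errors, for illustration only
```
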
@@ -88,24 +88,24 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check
         **kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.
     """

-    _kwargs = dict(
-        to_save=to_save,
-        save_every_iters=save_every_iters,
-        output_path=output_path,
-        lr_scheduler=lr_scheduler,
-        with_gpu_stats=with_gpu_stats,
-        output_names=output_names,
-        with_pbars=with_pbars,
-        with_pbar_on_iters=with_pbar_on_iters,
-        log_every_iters=log_every_iters,
-        stop_on_nan=stop_on_nan,
-        clear_cuda_cache=clear_cuda_cache,
-        save_handler=save_handler,
-    )
-    _kwargs.update(kwargs)
-
     if idist.get_world_size() > 1:
-        _setup_common_distrib_training_handlers(trainer, train_sampler=train_sampler, **_kwargs)
+        _setup_common_distrib_training_handlers(
+            trainer,
+            train_sampler=train_sampler,
+            to_save=to_save,
+            save_every_iters=save_every_iters,
+            output_path=output_path,
+            lr_scheduler=lr_scheduler,
+            with_gpu_stats=with_gpu_stats,
+            output_names=output_names,
+            with_pbars=with_pbars,
+            with_pbar_on_iters=with_pbar_on_iters,
+            log_every_iters=log_every_iters,
+            stop_on_nan=stop_on_nan,
+            clear_cuda_cache=clear_cuda_cache,
+            save_handler=save_handler,
+            **kwargs,
+        )
     else:
         if train_sampler is not None and isinstance(train_sampler, DistributedSampler):
             warnings.warn(
@@ -114,7 +114,22 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check
                 "Train sampler argument will be ignored",
                 UserWarning,
             )
-        _setup_common_training_handlers(trainer, **_kwargs)
+        _setup_common_training_handlers(
+            trainer,
+            to_save=to_save,
+            save_every_iters=save_every_iters,
+            output_path=output_path,
+            lr_scheduler=lr_scheduler,
+            with_gpu_stats=with_gpu_stats,
+            output_names=output_names,
+            with_pbars=with_pbars,
+            with_pbar_on_iters=with_pbar_on_iters,
+            log_every_iters=log_every_iters,
+            stop_on_nan=stop_on_nan,
+            clear_cuda_cache=clear_cuda_cache,
+            save_handler=save_handler,
+            **kwargs,
+        )


 setup_common_distrib_training_handlers = setup_common_training_handlers
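
The two hunks above drop the packed `_kwargs` dict in favour of forwarding each keyword explicitly. A minimal sketch of the underlying issue, with a hypothetical `configure` function: a heterogeneous `dict(...)` is inferred as `Dict[str, object]`, and unpacking it with `**` hides the individual parameter types from mypy.

```python
from typing import Optional


def configure(save_every_iters: int = 1000, output_path: Optional[str] = None) -> None:
    ...


# Inferred as Dict[str, object]: mypy reports an arg-type error at the call site,
# because `object` is not compatible with `int` / `Optional[str]`.
opts = dict(save_every_iters=500, output_path="/tmp/checkpoints")
configure(**opts)  # flagged by mypy (runs fine at runtime)

# Forwarding each keyword explicitly keeps every argument's type visible to the checker.
configure(save_every_iters=500, output_path="/tmp/checkpoints")
```
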
@@ -135,7 +150,7 @@ def _setup_common_training_handlers(
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
     **kwargs: Any
-):
+) -> None:
     if output_path is not None and save_handler is not None:
         raise ValueError(
             "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
@@ -146,7 +161,9 @@ def _setup_common_training_handlers(

     if lr_scheduler is not None:
         if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
-            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
+            trainer.add_event_handler(
+                Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
+            )
         elif isinstance(lr_scheduler, LRScheduler):
             trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
         else:
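
For context on the `cast` introduced above: `typing.cast` does nothing at runtime, it only tells the type checker what to assume. It is needed because mypy generally does not carry `isinstance` narrowing of a captured variable into a lambda, since the variable could be rebound before the lambda runs. A self-contained sketch with hypothetical names:

```python
from typing import Callable, Optional, cast


class StepScheduler:
    def step(self) -> None:
        print("stepped")


def make_handler(scheduler: Optional[StepScheduler]) -> Callable[[], None]:
    if isinstance(scheduler, StepScheduler):
        # Without the cast, mypy still sees the captured `scheduler` as
        # Optional[StepScheduler] inside the lambda and reports that None
        # has no attribute "step"; cast() is a no-op at runtime.
        return lambda: cast(StepScheduler, scheduler).step()
    raise TypeError("scheduler must be a StepScheduler")


make_handler(StepScheduler())()  # prints "stepped"
```
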
@@ -164,15 +181,19 @@ def _setup_common_training_handlers(
         if output_path is not None:
             save_handler = DiskSaver(dirname=output_path, require_empty=False)

-        checkpoint_handler = Checkpoint(to_save, save_handler, filename_prefix="training", **kwargs)
+        checkpoint_handler = Checkpoint(
+            to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
+        )
         trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

     if with_gpu_stats:
-        GpuInfo().attach(trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters))
+        GpuInfo().attach(
+            trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)  # type: ignore[arg-type]
+        )

     if output_names is not None:

-        def output_transform(x, index, name):
+        def output_transform(x: Any, index: int, name: str) -> Any:
             if isinstance(x, Mapping):
                 return x[name]
             elif isinstance(x, Sequence):
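
The `# type: ignore[arg-type]` above is scoped to a single mypy error code instead of silencing every diagnostic on the line. A small illustration (the functions are hypothetical and intentionally never called):

```python
def expects_int(x: int) -> int:
    return x + 1


def demo() -> None:  # never called: these lines only show mypy behaviour
    expects_int("42")                            # error: incompatible type [arg-type]
    expects_int("42")  # type: ignore            # suppresses every error code on this line
    expects_int("42")  # type: ignore[arg-type]  # suppresses only the arg-type code
```
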
@@ -217,7 +238,7 @@ def _setup_common_distrib_training_handlers(
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
     **kwargs: Any
-):
+) -> None:

     _setup_common_training_handlers(
         trainer,
@@ -241,18 +262,18 @@ def _setup_common_distrib_training_handlers(
             raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")

         @trainer.on(Events.EPOCH_STARTED)
-        def distrib_set_epoch(engine):
-            train_sampler.set_epoch(engine.state.epoch - 1)
+        def distrib_set_epoch(engine: Engine) -> None:
+            cast(DistributedSampler, train_sampler).set_epoch(engine.state.epoch - 1)


-def empty_cuda_cache(_):
+def empty_cuda_cache(_: Engine) -> None:
     torch.cuda.empty_cache()
     import gc

     gc.collect()


-def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters):
+def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None:  # type: ignore
     raise DeprecationWarning(
         "ignite.contrib.engines.common.setup_any_logging is deprecated since 0.4.0. and will be remove in 0.6.0. "
         "Please use instead: setup_tb_logging, setup_visdom_logging or setup_mlflow_logging etc."
@@ -262,10 +283,10 @@ def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, lo
 def _setup_logging(
     logger: BaseLogger,
     trainer: Engine,
-    optimizers: Union[Optimizer, Dict[str, Optimizer]],
-    evaluators: Union[Engine, Dict[str, Engine]],
+    optimizers: Optional[Union[Optimizer, Dict[str, Optimizer], Dict[None, Optimizer]]],
+    evaluators: Optional[Union[Engine, Dict[str, Engine]]],
     log_every_iters: int,
-):
+) -> None:
     if optimizers is not None:
         if not isinstance(optimizers, (Optimizer, Mapping)):
             raise TypeError("Argument optimizers should be either a single optimizer or a dictionary or optimizers")
@@ -311,7 +332,7 @@ def setup_tb_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> TensorboardLogger:
     """Method to setup TensorBoard logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -343,7 +364,7 @@ def setup_visdom_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> VisdomLogger:
     """Method to setup Visdom logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -374,7 +395,7 @@ def setup_mlflow_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> MLflowLogger:
     """Method to setup MLflow logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -405,7 +426,7 @@ def setup_neptune_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> NeptuneLogger:
     """Method to setup Neptune logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -436,7 +457,7 @@ def setup_wandb_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> WandBLogger:
     """Method to setup WandB logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -467,7 +488,7 @@ def setup_plx_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> PolyaxonLogger:
     """Method to setup Polyaxon logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -498,7 +519,7 @@ def setup_trains_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> TrainsLogger:
     """Method to setup Trains logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -523,8 +544,8 @@ def setup_trains_logging(
     return logger


-def get_default_score_fn(metric_name: str):
-    def wrapper(engine: Engine):
+def get_default_score_fn(metric_name: str) -> Any:
+    def wrapper(engine: Engine) -> Any:
         score = engine.state.metrics[metric_name]
         return score

@@ -540,7 +561,7 @@ def gen_save_best_models_by_val_score(
     trainer: Optional[Engine] = None,
     tag: str = "val",
     **kwargs: Any
-):
+) -> Checkpoint:
     """Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
     Models with highest metric value will be retained. The logic of how to store objects is delegated to
@@ -570,9 +591,10 @@ def gen_save_best_models_by_val_score(
     if trainer is not None:
         global_step_transform = global_step_from_engine(trainer)

-    to_save = models
     if isinstance(models, nn.Module):
-        to_save = {"model": models}
+        to_save = {"model": models}  # type: Dict[str, nn.Module]
+    else:
+        to_save = models

     best_model_handler = Checkpoint(
         to_save,
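
The `# type: Dict[str, nn.Module]` comment above is the comment form of a variable annotation, equivalent to `to_save: Dict[str, nn.Module] = ...` in Python 3.6+ syntax. Declaring the type at the first assignment keeps `to_save` a mapping in mypy's eyes rather than the wider union it would otherwise carry. A short sketch with a hypothetical helper:

```python
from typing import Dict, Union

import torch.nn as nn


def as_mapping(models: Union[nn.Module, Dict[str, nn.Module]]) -> Dict[str, nn.Module]:
    if isinstance(models, nn.Module):
        # Comment-style annotation: pins to_save to Dict[str, nn.Module].
        to_save = {"model": models}  # type: Dict[str, nn.Module]
    else:
        # isinstance() narrowing makes `models` a Dict[str, nn.Module] here.
        to_save = models
    return to_save
```
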
@@ -598,7 +620,7 @@ def save_best_model_by_val_score(
     trainer: Optional[Engine] = None,
     tag: str = "val",
     **kwargs: Any
-):
+) -> Checkpoint:
     """Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
     Models with highest metric value will be retained.
@@ -629,7 +651,9 @@ def save_best_model_by_val_score(
     )


-def add_early_stopping_by_val_score(patience: int, evaluator: Engine, trainer: Engine, metric_name: str):
+def add_early_stopping_by_val_score(
+    patience: int, evaluator: Engine, trainer: Engine, metric_name: str
+) -> EarlyStopping:
     """Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.
     Metric value should increase in order to keep training and not early stop.


ignite/contrib/engines/tbptt.py

Lines changed: 6 additions & 3 deletions
@@ -1,4 +1,5 @@
 # coding: utf-8
+import collections.abc as collections
 from typing import Callable, Mapping, Optional, Sequence, Union

 import torch
@@ -20,7 +21,9 @@ class Tbptt_Events(EventEnum):
     TIME_ITERATION_COMPLETED = "time_iteration_completed"


-def _detach_hidden(hidden: Union[torch.Tensor, Sequence, Mapping, str, bytes]):
+def _detach_hidden(
+    hidden: Union[torch.Tensor, Sequence, Mapping, str, bytes]
+) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
     """Cut backpropagation graph.

     Auxillary function to cut the backpropagation graph by detaching the hidden
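
The widened signature above belongs to a function that walks a possibly nested hidden state. The commit only annotates it; the following is a rough illustration of what such a recursive detach can look like (an assumption-laden sketch, not ignite's implementation), which may help read the `collections.abc` types in the return annotation:

```python
import collections.abc as collections
from typing import Mapping, Sequence, Union

import torch


def detach_nested(
    hidden: Union[torch.Tensor, Sequence, Mapping, str, bytes]
) -> Union[torch.Tensor, collections.Sequence, collections.Mapping, str, bytes]:
    # Tensors are detached from the autograd graph; containers are rebuilt
    # element by element; str/bytes pass through unchanged.
    if isinstance(hidden, torch.Tensor):
        return hidden.detach()
    if isinstance(hidden, (str, bytes)):
        return hidden
    if isinstance(hidden, collections.Mapping):
        return {k: detach_nested(v) for k, v in hidden.items()}
    if isinstance(hidden, collections.Sequence):
        return type(hidden)(detach_nested(v) for v in hidden)
    return hidden
```
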
@@ -38,7 +41,7 @@ def create_supervised_tbptt_trainer(
     device: Optional[str] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
-):
+) -> Engine:
     """Create a trainer for truncated backprop through time supervised models.

     Training recurrent model on long sequences is computationally intensive as
@@ -83,7 +86,7 @@ def create_supervised_tbptt_trainer(

     """

-    def _update(engine: Engine, batch: Sequence[torch.Tensor]):
+    def _update(engine: Engine, batch: Sequence[torch.Tensor]) -> float:
         loss_list = []
         hidden = None

ignite/contrib/handlers/tqdm_logger.py

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def attach(
         engine: Engine,
         metric_names: Optional[str] = None,
         output_transform: Optional[Callable] = None,
-        event_name: Events = Events.ITERATION_COMPLETED,
+        event_name: Union[CallableEventWithFilter, Events] = Events.ITERATION_COMPLETED,
         closing_event_name: Events = Events.EPOCH_COMPLETED,
     ):
         """

mypy.ini

Lines changed: 0 additions & 4 deletions
@@ -28,10 +28,6 @@ warn_unused_ignores = True

 ignore_errors = True

-[mypy-ignite.contrib.engines.*]
-
-ignore_errors = True
-
 [mypy-horovod.*]
 ignore_missing_imports = True

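
With the `ignore_errors` override removed, `ignite.contrib.engines` is type-checked like the other configured packages. A small sketch for reproducing the check from a local clone, assuming mypy is installed and `mypy.ini` is picked up from the repository root (equivalent to running `mypy ignite/contrib/engines` on the command line):

```python
from mypy import api

# api.run() returns (report, error output, exit status), just like the CLI.
stdout, stderr, exit_status = api.run(["ignite/contrib/engines"])
print(stdout or "no type errors found")
print("exit status:", exit_status)
```
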