
Commit 60c575a

review comments
1 parent 634947f commit 60c575a

25 files changed: +151 -178 lines changed

ignite/contrib/engines/common.py

Lines changed: 47 additions & 47 deletions
@@ -35,19 +35,19 @@ def setup_common_training_handlers(
     trainer: Engine,
     train_sampler: Optional[DistributedSampler] = None,
     to_save: Optional[Dict[str, Any]] = None,
-    save_every_iters: Optional[int] = 1000,
+    save_every_iters: int = 1000,
     output_path: Optional[str] = None,
     lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
-    with_gpu_stats: Optional[bool] = False,
+    with_gpu_stats: bool = False,
     output_names: Optional[Iterable[str]] = None,
-    with_pbars: Optional[bool] = True,
-    with_pbar_on_iters: Optional[bool] = True,
-    log_every_iters: Optional[int] = 100,
+    with_pbars: bool = True,
+    with_pbar_on_iters: bool = True,
+    log_every_iters: int = 100,
     device: Optional[Union[str, torch.device]] = None,
-    stop_on_nan: Optional[bool] = True,
-    clear_cuda_cache: Optional[bool] = True,
+    stop_on_nan: bool = True,
+    clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
-    **kwargs: Mapping,
+    **kwargs: Any,
 ):
     """Helper method to setup trainer with common handlers (it also supports distributed configuration):
         - :class:`~ignite.handlers.TerminateOnNan`
@@ -126,18 +126,18 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check
 def _setup_common_training_handlers(
     trainer: Engine,
     to_save: Optional[Dict[str, Any]] = None,
-    save_every_iters: Optional[int] = 1000,
+    save_every_iters: int = 1000,
     output_path: Optional[str] = None,
     lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
-    with_gpu_stats: Optional[bool] = False,
+    with_gpu_stats: bool = False,
     output_names: Optional[Iterable[str]] = None,
-    with_pbars: Optional[bool] = True,
-    with_pbar_on_iters: Optional[bool] = True,
-    log_every_iters: Optional[int] = 100,
-    stop_on_nan: Optional[bool] = True,
-    clear_cuda_cache: Optional[bool] = True,
+    with_pbars: bool = True,
+    with_pbar_on_iters: bool = True,
+    log_every_iters: int = 100,
+    stop_on_nan: bool = True,
+    clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
-    **kwargs: Mapping,
+    **kwargs: Any,
 ):
     if output_path is not None and save_handler is not None:
         raise ValueError(
@@ -208,18 +208,18 @@ def _setup_common_distrib_training_handlers(
     trainer: Engine,
     train_sampler: Optional[DistributedSampler] = None,
     to_save: Optional[Dict[str, Any]] = None,
-    save_every_iters: Optional[int] = 1000,
+    save_every_iters: int = 1000,
     output_path: Optional[str] = None,
     lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
-    with_gpu_stats: Optional[bool] = False,
+    with_gpu_stats: bool = False,
     output_names: Optional[Iterable[str]] = None,
-    with_pbars: Optional[bool] = True,
-    with_pbar_on_iters: Optional[bool] = True,
-    log_every_iters: Optional[int] = 100,
-    stop_on_nan: Optional[bool] = True,
-    clear_cuda_cache: Optional[bool] = True,
+    with_pbars: bool = True,
+    with_pbar_on_iters: bool = True,
+    log_every_iters: int = 100,
+    stop_on_nan: bool = True,
+    clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
-    **kwargs: Mapping,
+    **kwargs: Any,
 ):
 
     _setup_common_training_handlers(
@@ -265,9 +265,9 @@ def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, lo
 def _setup_logging(
     logger: BaseLogger,
     trainer: Engine,
-    optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]],
-    evaluators: Optional[Union[Engine, Dict[str, Engine]]],
-    log_every_iters: Optional[int],
+    optimizers: Union[Optimizer, Dict[str, Optimizer]],
+    evaluators: Union[Engine, Dict[str, Engine]],
+    log_every_iters: int,
 ):
     if optimizers is not None:
         if not isinstance(optimizers, (Optimizer, Mapping)):
@@ -312,8 +312,8 @@ def setup_tb_logging(
     trainer: Engine,
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
-    log_every_iters: Optional[int] = 100,
-    **kwargs: Mapping,
+    log_every_iters: int = 100,
+    **kwargs: Any,
 ):
     """Method to setup TensorBoard logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -343,8 +343,8 @@ def setup_visdom_logging(
     trainer: Engine,
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
-    log_every_iters: Optional[int] = 100,
-    **kwargs: Mapping,
+    log_every_iters: int = 100,
+    **kwargs: Any,
 ):
     """Method to setup Visdom logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -373,8 +373,8 @@ def setup_mlflow_logging(
     trainer: Engine,
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
-    log_every_iters: Optional[int] = 100,
-    **kwargs: Mapping,
+    log_every_iters: int = 100,
+    **kwargs: Any,
 ):
     """Method to setup MLflow logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -403,8 +403,8 @@ def setup_neptune_logging(
     trainer: Engine,
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
-    log_every_iters: Optional[int] = 100,
-    **kwargs: Mapping,
+    log_every_iters: int = 100,
+    **kwargs: Any,
 ):
     """Method to setup Neptune logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -433,8 +433,8 @@ def setup_wandb_logging(
     trainer: Engine,
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
-    log_every_iters: Optional[int] = 100,
-    **kwargs: Mapping,
+    log_every_iters: int = 100,
+    **kwargs: Any,
 ):
     """Method to setup WandB logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -463,8 +463,8 @@ def setup_plx_logging(
     trainer: Engine,
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
-    log_every_iters: Optional[int] = 100,
-    **kwargs: Mapping,
+    log_every_iters: int = 100,
+    **kwargs: Any,
 ):
     """Method to setup Polyaxon logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -493,8 +493,8 @@ def setup_trains_logging(
     trainer: Engine,
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
-    log_every_iters: Optional[int] = 100,
-    **kwargs: Mapping,
+    log_every_iters: int = 100,
+    **kwargs: Any,
 ):
     """Method to setup Trains logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -532,10 +532,10 @@ def gen_save_best_models_by_val_score(
     evaluator: Engine,
     models: torch.nn.Module,
     metric_name: str,
-    n_saved: Optional[int] = 3,
+    n_saved: int = 3,
     trainer: Optional[Engine] = None,
-    tag: Optional[str] = "val",
-    **kwargs: Mapping,
+    tag: str = "val",
+    **kwargs: Any,
 ):
     """Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
@@ -590,10 +590,10 @@ def save_best_model_by_val_score(
     evaluator: Engine,
     model: torch.nn.Module,
     metric_name: str,
-    n_saved: Optional[int] = 3,
+    n_saved: int = 3,
    trainer: Optional[Engine] = None,
-    tag: Optional[str] = "val",
-    **kwargs: Mapping,
+    tag: str = "val",
+    **kwargs: Any,
 ):
     """Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
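
The pattern applied throughout this file: `Optional[X]` is reserved for parameters that genuinely accept `None`, so arguments with a non-None default get the plain type, and `**kwargs` is annotated `Any` because a `**kwargs: T` annotation types each individual keyword value as `T`. A minimal sketch of the difference follows; the `before`/`after` functions are invented for illustration and are not ignite code.

```python
# Hypothetical before/after signatures showing why the annotations change.
from typing import Any, Mapping, Optional


def before(save_every_iters: Optional[int] = 1000, **kwargs: Mapping) -> None:
    # Optional[int] tells a type checker that None is a valid value even though
    # the body never handles it, and **kwargs: Mapping types every keyword
    # value as a Mapping, so before(tag="val") fails type checking.
    ...


def after(save_every_iters: int = 1000, **kwargs: Any) -> None:
    # Plain int matches the non-None default; **kwargs: Any lets callers pass
    # arbitrary keyword values straight through to downstream handlers.
    ...


after(save_every_iters=500, tag="val")  # accepted by a type checker
# after(save_every_iters=None)          # flagged: None is no longer part of the declared type
```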

ignite/contrib/engines/tbptt.py

Lines changed: 5 additions & 5 deletions
@@ -35,10 +35,10 @@ def create_supervised_tbptt_trainer(
     optimizer: Optimizer,
     loss_fn: _Loss,
     tbtt_step: int,
-    dim: Optional[int] = 0,
+    dim: int = 0,
     device: Optional[str] = None,
-    non_blocking: Optional[bool] = False,
-    prepare_batch: Optional[Callable] = _prepare_batch,
+    non_blocking: bool = False,
+    prepare_batch: Callable = _prepare_batch,
 ):
     """Create a trainer for truncated backprop through time supervised models.
 
@@ -59,7 +59,7 @@ def create_supervised_tbptt_trainer(
         optimizer (`torch.optim.Optimizer`): the optimizer to use.
         loss_fn (torch.nn loss function): the loss function to use.
         tbtt_step (int): the length of time chunks (last one may be smaller).
-        dim (int, optional): axis representing the time dimension.
+        dim (int): axis representing the time dimension.
         device (str, optional): device type specification (default: None).
             Applies to batches.
         non_blocking (bool, optional): if True and this copy is between CPU and GPU,
@@ -84,7 +84,7 @@ def create_supervised_tbptt_trainer(
 
     """
 
-    def _update(engine: Engine, batch: Tuple[torch.Tensor, torch.Tensor]):
+    def _update(engine: Engine, batch: Sequence[torch.Tensor]):
         loss_list = []
         hidden = None
 
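
The `_update` annotation is also loosened from `Tuple[torch.Tensor, torch.Tensor]` to `Sequence[torch.Tensor]`, since batches coming out of a `DataLoader` may be lists rather than tuples. A small standalone sketch of why the wider type helps; the `consume` function is invented for illustration and is not ignite code.

```python
# Sequence[torch.Tensor] covers both a list and a plain tuple of tensors,
# whereas Tuple[torch.Tensor, torch.Tensor] only matches an exact 2-tuple.
from typing import Sequence

import torch


def consume(batch: Sequence[torch.Tensor]) -> int:
    x, y = batch[0], batch[1]  # works for any sequence of tensors
    return x.numel() + y.numel()


print(consume([torch.zeros(2), torch.ones(3)]))  # list input
print(consume((torch.zeros(2), torch.ones(3))))  # tuple input
```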
ignite/contrib/handlers/base_logger.py

Lines changed: 7 additions & 7 deletions
@@ -1,7 +1,7 @@
 import numbers
 import warnings
 from abc import ABCMeta, abstractmethod
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, List, Optional, Sequence, Union
 
 import torch
 from torch.optim import Optimizer
@@ -20,7 +20,7 @@ class BaseOptimizerParamsHandler(BaseHandler):
     Base handler for logging optimizer parameters
     """
 
-    def __init__(self, optimizer: Optimizer, param_name: Optional[str] = "lr", tag: Optional[str] = None):
+    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
         if not (
             isinstance(optimizer, Optimizer)
             or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
@@ -107,7 +107,7 @@ class BaseWeightsScalarHandler(BaseHandler):
     Helper handler to log model's weights as scalars.
     """
 
-    def __init__(self, model: torch.nn.Module, reduction: Optional[Callable] = torch.norm, tag: Optional[str] = None):
+    def __init__(self, model: torch.nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None):
         if not isinstance(model, torch.nn.Module):
             raise TypeError("Argument model should be of type torch.nn.Module, " "but given {}".format(type(model)))
 
@@ -166,7 +166,7 @@ def attach(self, engine: Engine, log_handler: Callable, event_name: Any):
 
         return engine.add_event_handler(event_name, log_handler, self, name)
 
-    def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Mapping):
+    def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any):
         """Shortcut method to attach `OutputHandler` to the logger.
 
         Args:
@@ -182,7 +182,7 @@ def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **k
         """
         return self.attach(engine, self._create_output_handler(*args, **kwargs), event_name=event_name)
 
-    def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Mapping):
+    def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any):
         """Shortcut method to attach `OptimizerParamsHandler` to the logger.
 
         Args:
@@ -199,11 +199,11 @@ def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any,
         self.attach(engine, self._create_opt_params_handler(*args, **kwargs), event_name=event_name)
 
     @abstractmethod
-    def _create_output_handler(self, engine, *args: Any, **kwargs: Mapping):
+    def _create_output_handler(self, engine, *args: Any, **kwargs: Any):
         pass
 
     @abstractmethod
-    def _create_opt_params_handler(self, *args: Any, **kwargs: Mapping):
+    def _create_opt_params_handler(self, *args: Any, **kwargs: Any):
         pass
 
     def __enter__(self):

ignite/contrib/handlers/lr_finder.py

Lines changed: 7 additions & 7 deletions
@@ -186,7 +186,7 @@ def get_results(self):
         """
         return self._history
 
-    def plot(self, skip_start: Optional[int] = 10, skip_end: Optional[int] = 5, log_lr: Optional[bool] = True):
+    def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True):
         """Plots the learning rate range test.
 
         This method requires `matplotlib` package to be installed:
@@ -255,12 +255,12 @@ def attach(
         self,
         trainer: Engine,
         to_save: Dict[str, Any],
-        output_transform: Optional[Callable] = lambda output: output,
+        output_transform: Callable = lambda output: output,
         num_iter: Optional[int] = None,
-        end_lr: Optional[float] = 10.0,
-        step_mode: Optional[str] = "exp",
-        smooth_f: Optional[float] = 0.05,
-        diverge_th: Optional[float] = 5.0,
+        end_lr: float = 10.0,
+        step_mode: str = "exp",
+        smooth_f: float = 0.05,
+        diverge_th: float = 5.0,
     ):
         """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.
 
@@ -372,7 +372,7 @@ class _ExponentialLR(_LRScheduler):
 
     """
 
-    def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: Optional[int] = -1):
+    def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
         self.end_lr = end_lr
         self.num_iter = num_iter
         super(_ExponentialLR, self).__init__(optimizer, last_epoch)

ignite/contrib/handlers/mlflow_logger.py

Lines changed: 4 additions & 4 deletions
@@ -1,6 +1,6 @@
 import numbers
 import warnings
-from typing import Any, Callable, List, Mapping, Optional, Union
+from typing import Any, Callable, List, Optional, Union
 
 import torch
 from torch.optim import Optimizer
@@ -113,10 +113,10 @@ def close(self):
 
         mlflow.end_run()
 
-    def _create_output_handler(self, *args: Any, **kwargs: Mapping):
+    def _create_output_handler(self, *args: Any, **kwargs: Any):
         return OutputHandler(*args, **kwargs)
 
-    def _create_opt_params_handler(self, *args: Any, **kwargs: Mapping):
+    def _create_opt_params_handler(self, *args: Any, **kwargs: Any):
         return OptimizerParamsHandler(*args, **kwargs)
 
 
@@ -290,7 +290,7 @@ class OptimizerParamsHandler(BaseOptimizerParamsHandler):
         tag (str, optional): common title for all produced plots. For example, 'generator'
     """
 
-    def __init__(self, optimizer: Optimizer, param_name: Optional[str] = "lr", tag: Optional[str] = None):
+    def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
         super(OptimizerParamsHandler, self).__init__(optimizer, param_name, tag)
 
     def __call__(self, engine: Engine, logger: MLflowLogger, event_name):
