
Commit 37914c3

review changes
1 parent 60c575a commit 37914c3

File tree

9 files changed, +33 −34 lines


ignite/contrib/engines/common.py

Lines changed: 16 additions & 16 deletions
@@ -34,7 +34,7 @@
 def setup_common_training_handlers(
     trainer: Engine,
     train_sampler: Optional[DistributedSampler] = None,
-    to_save: Optional[Dict[str, Any]] = None,
+    to_save: Optional[Mapping] = None,
     save_every_iters: int = 1000,
     output_path: Optional[str] = None,
     lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
@@ -47,7 +47,7 @@ def setup_common_training_handlers(
     stop_on_nan: bool = True,
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Helper method to setup trainer with common handlers (it also supports distributed configuration):
         - :class:`~ignite.handlers.TerminateOnNan`
@@ -125,7 +125,7 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check

 def _setup_common_training_handlers(
     trainer: Engine,
-    to_save: Optional[Dict[str, Any]] = None,
+    to_save: Optional[Mapping] = None,
     save_every_iters: int = 1000,
     output_path: Optional[str] = None,
     lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
@@ -137,7 +137,7 @@ def _setup_common_training_handlers(
     stop_on_nan: bool = True,
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
-    **kwargs: Any,
+    **kwargs: Any
 ):
     if output_path is not None and save_handler is not None:
         raise ValueError(
@@ -207,7 +207,7 @@ def output_transform(x, index, name):
 def _setup_common_distrib_training_handlers(
     trainer: Engine,
     train_sampler: Optional[DistributedSampler] = None,
-    to_save: Optional[Dict[str, Any]] = None,
+    to_save: Optional[Mapping] = None,
     save_every_iters: int = 1000,
     output_path: Optional[str] = None,
     lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
@@ -219,7 +219,7 @@ def _setup_common_distrib_training_handlers(
     stop_on_nan: bool = True,
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
-    **kwargs: Any,
+    **kwargs: Any
 ):

     _setup_common_training_handlers(
@@ -313,7 +313,7 @@ def setup_tb_logging(
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Method to setup TensorBoard logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -344,7 +344,7 @@ def setup_visdom_logging(
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Method to setup Visdom logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -374,7 +374,7 @@ def setup_mlflow_logging(
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Method to setup MLflow logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -404,7 +404,7 @@ def setup_neptune_logging(
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Method to setup Neptune logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -434,7 +434,7 @@ def setup_wandb_logging(
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Method to setup WandB logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -464,7 +464,7 @@ def setup_plx_logging(
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Method to setup Polyaxon logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -494,7 +494,7 @@ def setup_trains_logging(
     optimizers: Optional[Union[Optimizer, Dict[str, Optimizer]]] = None,
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Method to setup Trains logging on trainer and a list of evaluators. Logged metrics are:
         - Training metrics, e.g. running average loss values
@@ -530,12 +530,12 @@ def wrapper(engine: Engine):
 def gen_save_best_models_by_val_score(
     save_handler: Union[Callable, BaseSaveHandler],
     evaluator: Engine,
-    models: torch.nn.Module,
+    models: Union[torch.nn.Module, Dict[str, torch.nn.Module]],
     metric_name: str,
     n_saved: int = 3,
     trainer: Optional[Engine] = None,
     tag: str = "val",
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
@@ -593,7 +593,7 @@ def save_best_model_by_val_score(
     n_saved: int = 3,
     trainer: Optional[Engine] = None,
     tag: str = "val",
-    **kwargs: Any,
+    **kwargs: Any
 ):
     """Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
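
Usage sketch for the hunks above (hedged): `to_save` in the setup helpers is now annotated as Optional[Mapping], and `gen_save_best_models_by_val_score` accepts either a single module or a dict of modules. The toy model, dummy update function, and output path below are placeholders for illustration only, not part of this commit.

import torch
import torch.nn as nn

from ignite.contrib.engines.common import setup_common_training_handlers
from ignite.engine import Engine

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
trainer = Engine(lambda engine, batch: 0.0)  # dummy update function

# With to_save typed as Optional[Mapping], any mapping of checkpointable
# objects is accepted, e.g. a plain dict or an OrderedDict.
to_save = {"model": model, "optimizer": optimizer}

setup_common_training_handlers(
    trainer,
    to_save=to_save,
    save_every_iters=1000,
    output_path="/tmp/checkpoints",  # placeholder path
)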

ignite/contrib/engines/tbptt.py

Lines changed: 4 additions & 5 deletions
@@ -1,9 +1,8 @@
 # coding: utf-8
-from typing import Callable, Mapping, Optional, Sequence, Tuple, Union
+from typing import Callable, Mapping, Optional, Sequence, Union

 import torch
-from torch.nn import Module
-from torch.nn.modules.loss import _Loss
+import torch.nn as nn
 from torch.optim.optimizer import Optimizer

 from ignite.engine import Engine, EventEnum, _prepare_batch
@@ -31,9 +30,9 @@ def _detach_hidden(hidden: Union[torch.Tensor, Sequence, Mapping, str, bytes]):


 def create_supervised_tbptt_trainer(
-    model: Module,
+    model: nn.Module,
     optimizer: Optimizer,
-    loss_fn: _Loss,
+    loss_fn: nn.Module,
     tbtt_step: int,
     dim: int = 0,
     device: Optional[str] = None,
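
The hunks above drop the private `_Loss` base class in favour of `nn.Module`, so any module that computes a loss can be passed as `loss_fn`. A hedged sketch; the RNN sizes, hyper-parameters, and custom loss module are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn

from ignite.contrib.engines.tbptt import create_supervised_tbptt_trainer

# nn.RNN returns (output, hidden), which is the shape of output the TBPTT update expects.
model = nn.RNN(input_size=8, hidden_size=16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# loss_fn is now typed as nn.Module, so a custom loss module that does not
# derive from the private _Loss class is accepted.
class SequenceMSELoss(nn.Module):
    def forward(self, y_pred, y):
        return nn.functional.mse_loss(y_pred, y)

trainer = create_supervised_tbptt_trainer(
    model=model,
    optimizer=optimizer,
    loss_fn=SequenceMSELoss(),
    tbtt_step=4,
)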

ignite/contrib/handlers/lr_finder.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import tempfile
 import warnings
 from pathlib import Path
-from typing import Any, Callable, Dict, Mapping, Optional
+from typing import Callable, Mapping, Optional

 import torch
 from torch.optim import Optimizer
@@ -254,7 +254,7 @@ def lr_suggestion(self):
     def attach(
         self,
         trainer: Engine,
-        to_save: Dict[str, Any],
+        to_save: Mapping,
         output_transform: Callable = lambda output: output,
         num_iter: Optional[int] = None,
         end_lr: float = 10.0,
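
With the change above, `FastaiLRFinder.attach` takes any Mapping for `to_save`. A hedged, self-contained sketch; the toy model, optimizer, data, and `end_lr` value are assumptions for illustration, not part of this commit.

import torch
import torch.nn as nn

from ignite.contrib.handlers import FastaiLRFinder
from ignite.engine import create_supervised_trainer

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
trainer = create_supervised_trainer(model, optimizer, nn.MSELoss())

# Toy data: a few (x, y) batches.
data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(10)]

lr_finder = FastaiLRFinder()
# Any Mapping of checkpointable objects now satisfies the annotation;
# it must still contain the "optimizer" entry the finder restores.
to_save = {"model": model, "optimizer": optimizer}

with lr_finder.attach(trainer, to_save=to_save, end_lr=1.0) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(data)

print(lr_finder.lr_suggestion())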

ignite/contrib/handlers/param_scheduler.py

Lines changed: 5 additions & 5 deletions
@@ -6,7 +6,7 @@
 from collections import OrderedDict
 from copy import copy
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union

 import torch
 from torch.optim.lr_scheduler import _LRScheduler
@@ -142,7 +142,7 @@ def get_param(self) -> Union[List[float], float]:
         pass

     @classmethod
-    def simulate_values(cls, num_events: int, **scheduler_kwargs: Mapping):
+    def simulate_values(cls, num_events: int, **scheduler_kwargs: Any):
         """Method to simulate scheduled values during `num_events` events.

         Args:
@@ -518,7 +518,7 @@ def state_dict(self):
             state_dict["schedulers"].append(s.state_dict())
         return state_dict

-    def load_state_dict(self, state_dict: Dict[str, Any]):
+    def load_state_dict(self, state_dict: Mapping):
         """Copies parameters from :attr:`state_dict` into this ConcatScheduler.

         Args:
@@ -583,7 +583,7 @@ def simulate_values(
         schedulers: List[ParamScheduler],
         durations: List[int],
         param_names: Union[List[str], Tuple[str]] = None,
-        **kwargs: Any,
+        **kwargs: Any
     ):
         """Method to simulate scheduled values during num_events events.

@@ -1048,7 +1048,7 @@ def state_dict(self):
             state_dict["schedulers"].append((n, s.state_dict()))
         return state_dict

-    def load_state_dict(self, state_dict: Dict):
+    def load_state_dict(self, state_dict: Mapping):
         """Copies parameters from :attr:`state_dict` into this ParamScheduler.

         Args:
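
Two of the annotations above in one hedged sketch: `**scheduler_kwargs` in `simulate_values` are plain keyword arguments (each is an individual value, so `Any` is the accurate type), and `ConcatScheduler.load_state_dict` accepts any Mapping. The optimizer and parameter values below are toy assumptions, not part of this commit.

import torch

from ignite.contrib.handlers import ConcatScheduler, LinearCyclicalScheduler

param = torch.zeros(1, requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.1)

# **scheduler_kwargs are ordinary keyword arguments, not a single Mapping.
values = LinearCyclicalScheduler.simulate_values(
    num_events=20, param_name="lr", start_value=0.1, end_value=0.01, cycle_size=10
)
print(values[:3])  # list of [event_index, value] pairs

scheduler1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.1, end_value=0.5, cycle_size=20)
scheduler2 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.5, end_value=0.01, cycle_size=20)
concat = ConcatScheduler([scheduler1, scheduler2], durations=[10])

# load_state_dict is now typed to accept any Mapping, e.g. a round-tripped state dict.
concat.load_state_dict(concat.state_dict())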

ignite/contrib/handlers/time_profilers.py

Lines changed: 1 addition & 1 deletion
@@ -388,7 +388,7 @@ def to_str(v: Union[str, tuple]):
                 return "{:.5f}/{}".format(v[0], v[1])
             return "{:.5f}".format(v)

        def odict_to_str(d: Dict):
+        def odict_to_str(d: Mapping):
             out = " | ".join([to_str(v) for v in d.values()])
             return out

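
The change above is internal: `odict_to_str` only reads `d.values()`, so `Mapping` covers both plain dicts and the OrderedDicts the profiler builds. A standalone sketch of that reasoning (this is a hypothetical re-implementation, not the ignite helper itself):

from collections import OrderedDict
from typing import Mapping

def odict_to_str(d: Mapping) -> str:
    # Only reads values, never mutates, so Mapping is the accurate annotation.
    return " | ".join("{:.5f}".format(v) for v in d.values())

print(odict_to_str({"processing": 0.00123, "dataflow": 0.00045}))
print(odict_to_str(OrderedDict(processing=0.00123, dataflow=0.00045)))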

ignite/contrib/handlers/tqdm_logger.py

Lines changed: 1 addition & 1 deletion
@@ -232,7 +232,7 @@ def __init__(
         description: str,
         metric_names: Optional[str] = None,
         output_transform: Optional[Callable] = None,
-        closing_event_name: Any = Events.EPOCH_COMPLETED,
+        closing_event_name: EventEnum = Events.EPOCH_COMPLETED,
     ):
         if metric_names is None and output_transform is None:
             # This helps to avoid 'Either metric_names or output_transform should be defined' of BaseOutputHandler
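
The handler above is internal, but the annotation surfaces in `ProgressBar.attach`, whose `closing_event_name` is now expected to be an `EventEnum` member rather than an arbitrary value. A hedged sketch with a dummy engine; the update function and event choices are illustrative assumptions only.

from ignite.contrib.handlers import ProgressBar
from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: batch)  # dummy update function

pbar = ProgressBar()
# closing_event_name should be an EventEnum member such as Events.EPOCH_COMPLETED.
pbar.attach(
    trainer,
    event_name=Events.ITERATION_COMPLETED,
    closing_event_name=Events.EPOCH_COMPLETED,
)

trainer.run(range(10), max_epochs=2)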

ignite/contrib/metrics/average_precision.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ class AveragePrecision(EpochMetric):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        check_compute_fn (bool, optional): Optional default False. If True, `average_precision_score
+        check_compute_fn (bool): Default False. If True, `average_precision_score
             <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
             #sklearn.metrics.average_precision_score>`_ is run on the first batch of data to ensure there are
             no issues. User will be warned in case there are any issues computing the function.

ignite/contrib/metrics/precision_recall_curve.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ class PrecisionRecallCurve(EpochMetric):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        check_compute_fn (bool, optional): Optional default False. If True, `precision_recall_curve
+        check_compute_fn (bool): Default False. If True, `precision_recall_curve
             <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html
             #sklearn.metrics.precision_recall_curve>`_ is run on the first batch of data to ensure there are
             no issues. User will be warned in case there are any issues computing the function.

ignite/contrib/metrics/roc_auc.py

Lines changed: 2 additions & 2 deletions
@@ -36,7 +36,7 @@ class ROC_AUC(EpochMetric):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        check_compute_fn (bool, optional): Optional default False. If True, `roc_curve
+        check_compute_fn (bool): Default False. If True, `roc_curve
             <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#
             sklearn.metrics.roc_auc_score>`_ is run on the first batch of data to ensure there are
             no issues. User will be warned in case there are any issues computing the function.
@@ -72,7 +72,7 @@ class RocCurve(EpochMetric):
             :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
-        check_compute_fn (bool, optional): Optional default False. If True, `sklearn.metrics.roc_curve
+        check_compute_fn (bool): Default False. If True, `sklearn.metrics.roc_curve
             <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#
             sklearn.metrics.roc_curve>`_ is run on the first batch of data to ensure there are
             no issues. User will be warned in case there are any issues computing the function.
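
The docstring fix is the same across AveragePrecision, PrecisionRecallCurve, ROC_AUC, and RocCurve: `check_compute_fn` defaults to False and, when True, runs the underlying scikit-learn function on the first batch to surface problems early. A hedged sketch with a dummy evaluator (the eval step, metric name, and data are placeholder assumptions; scikit-learn must be installed):

import torch

from ignite.contrib.metrics import ROC_AUC
from ignite.engine import Engine

# Dummy evaluator whose output is (y_pred, y), as EpochMetric-based metrics expect.
def eval_step(engine, batch):
    y_pred = torch.rand(8)
    y = torch.randint(0, 2, size=(8,))
    return y_pred, y

evaluator = Engine(eval_step)

# check_compute_fn=True runs sklearn's roc_auc_score on the first batch and
# warns if it fails; the default remains False.
roc_auc = ROC_AUC(check_compute_fn=True)
roc_auc.attach(evaluator, "roc_auc")

state = evaluator.run([0] * 4, max_epochs=1)
print(state.metrics["roc_auc"])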
