 import numbers
 import warnings
 from functools import partial
-from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Union, cast

 import torch
 import torch.nn as nn
@@ -47,7 +47,7 @@ def setup_common_training_handlers(
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
     **kwargs: Any
-):
+) -> None:
     """Helper method to setup trainer with common handlers (it also supports distributed configuration):

         - :class:`~ignite.handlers.TerminateOnNan`
@@ -88,24 +88,24 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check
         **kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.
     """

-    _kwargs = dict(
-        to_save=to_save,
-        save_every_iters=save_every_iters,
-        output_path=output_path,
-        lr_scheduler=lr_scheduler,
-        with_gpu_stats=with_gpu_stats,
-        output_names=output_names,
-        with_pbars=with_pbars,
-        with_pbar_on_iters=with_pbar_on_iters,
-        log_every_iters=log_every_iters,
-        stop_on_nan=stop_on_nan,
-        clear_cuda_cache=clear_cuda_cache,
-        save_handler=save_handler,
-    )
-    _kwargs.update(kwargs)
-
     if idist.get_world_size() > 1:
-        _setup_common_distrib_training_handlers(trainer, train_sampler=train_sampler, **_kwargs)
+        _setup_common_distrib_training_handlers(
+            trainer,
+            train_sampler=train_sampler,
+            to_save=to_save,
+            save_every_iters=save_every_iters,
+            output_path=output_path,
+            lr_scheduler=lr_scheduler,
+            with_gpu_stats=with_gpu_stats,
+            output_names=output_names,
+            with_pbars=with_pbars,
+            with_pbar_on_iters=with_pbar_on_iters,
+            log_every_iters=log_every_iters,
+            stop_on_nan=stop_on_nan,
+            clear_cuda_cache=clear_cuda_cache,
+            save_handler=save_handler,
+            **kwargs,
+        )
     else:
         if train_sampler is not None and isinstance(train_sampler, DistributedSampler):
             warnings.warn(
@@ -114,7 +114,22 @@ class to use to store ``to_save``. See :class:`~ignite.handlers.checkpoint.Check
                 "Train sampler argument will be ignored",
                 UserWarning,
             )
-        _setup_common_training_handlers(trainer, **_kwargs)
+        _setup_common_training_handlers(
+            trainer,
+            to_save=to_save,
+            save_every_iters=save_every_iters,
+            output_path=output_path,
+            lr_scheduler=lr_scheduler,
+            with_gpu_stats=with_gpu_stats,
+            output_names=output_names,
+            with_pbars=with_pbars,
+            with_pbar_on_iters=with_pbar_on_iters,
+            log_every_iters=log_every_iters,
+            stop_on_nan=stop_on_nan,
+            clear_cuda_cache=clear_cuda_cache,
+            save_handler=save_handler,
+            **kwargs,
+        )


 setup_common_distrib_training_handlers = setup_common_training_handlers
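For context (not part of this PR): a minimal sketch of how setup_common_training_handlers is typically wired into a training script. The model, optimizer, paths and step function below are illustrative assumptions only.

import torch
import torch.nn as nn

from ignite.contrib.engines import common
from ignite.engine import Engine

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()


def train_step(engine, batch):
    # Plain supervised update step; returns the scalar loss.
    model.train()
    x, y = batch
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


trainer = Engine(train_step)

# to_save / save_every_iters / output_path drive the Checkpoint handler the
# helper attaches; stop_on_nan and clear_cuda_cache keep their defaults.
common.setup_common_training_handlers(
    trainer,
    to_save={"model": model, "optimizer": optimizer},
    save_every_iters=1000,
    output_path="/tmp/checkpoints",  # assumed output directory
    with_pbars=True,
    log_every_iters=10,
)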
@@ -135,7 +150,7 @@ def _setup_common_training_handlers(
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
     **kwargs: Any
-):
+) -> None:
     if output_path is not None and save_handler is not None:
         raise ValueError(
             "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
@@ -146,7 +161,9 @@ def _setup_common_training_handlers(

     if lr_scheduler is not None:
         if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
-            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
+            trainer.add_event_handler(
+                Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
+            )
         elif isinstance(lr_scheduler, LRScheduler):
             trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
         else:
@@ -164,15 +181,19 @@ def _setup_common_training_handlers(
         if output_path is not None:
             save_handler = DiskSaver(dirname=output_path, require_empty=False)

-        checkpoint_handler = Checkpoint(to_save, save_handler, filename_prefix="training", **kwargs)
+        checkpoint_handler = Checkpoint(
+            to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
+        )
         trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

     if with_gpu_stats:
-        GpuInfo().attach(trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters))
+        GpuInfo().attach(
+            trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)  # type: ignore[arg-type]
+        )

     if output_names is not None:

-        def output_transform(x, index, name):
+        def output_transform(x: Any, index: int, name: str) -> Any:
             if isinstance(x, Mapping):
                 return x[name]
             elif isinstance(x, Sequence):
@@ -217,7 +238,7 @@ def _setup_common_distrib_training_handlers(
     clear_cuda_cache: bool = True,
     save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
     **kwargs: Any
-):
+) -> None:

     _setup_common_training_handlers(
         trainer,
@@ -241,18 +262,18 @@ def _setup_common_distrib_training_handlers(
             raise TypeError("Train sampler should be torch DistributedSampler and have `set_epoch` method")

         @trainer.on(Events.EPOCH_STARTED)
-        def distrib_set_epoch(engine):
-            train_sampler.set_epoch(engine.state.epoch - 1)
+        def distrib_set_epoch(engine: Engine) -> None:
+            cast(DistributedSampler, train_sampler).set_epoch(engine.state.epoch - 1)


-def empty_cuda_cache(_):
+def empty_cuda_cache(_: Engine) -> None:
     torch.cuda.empty_cache()
     import gc

     gc.collect()


-def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters):
+def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters) -> None:  # type: ignore
     raise DeprecationWarning(
         "ignite.contrib.engines.common.setup_any_logging is deprecated since 0.4.0. and will be remove in 0.6.0. "
         "Please use instead: setup_tb_logging, setup_visdom_logging or setup_mlflow_logging etc."
@@ -262,10 +283,10 @@ def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, lo
 def _setup_logging(
     logger: BaseLogger,
     trainer: Engine,
-    optimizers: Union[Optimizer, Dict[str, Optimizer]],
-    evaluators: Union[Engine, Dict[str, Engine]],
+    optimizers: Optional[Union[Optimizer, Dict[str, Optimizer], Dict[None, Optimizer]]],
+    evaluators: Optional[Union[Engine, Dict[str, Engine]]],
     log_every_iters: int,
-):
+) -> None:
     if optimizers is not None:
         if not isinstance(optimizers, (Optimizer, Mapping)):
             raise TypeError("Argument optimizers should be either a single optimizer or a dictionary or optimizers")
@@ -311,7 +332,7 @@ def setup_tb_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> TensorboardLogger:
     """Method to setup TensorBoard logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -343,7 +364,7 @@ def setup_visdom_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> VisdomLogger:
     """Method to setup Visdom logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -374,7 +395,7 @@ def setup_mlflow_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> MLflowLogger:
     """Method to setup MLflow logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -405,7 +426,7 @@ def setup_neptune_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> NeptuneLogger:
     """Method to setup Neptune logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -436,7 +457,7 @@ def setup_wandb_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> WandBLogger:
     """Method to setup WandB logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -467,7 +488,7 @@ def setup_plx_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> PolyaxonLogger:
     """Method to setup Polyaxon logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -498,7 +519,7 @@ def setup_trains_logging(
     evaluators: Optional[Union[Engine, Dict[str, Engine]]] = None,
     log_every_iters: int = 100,
     **kwargs: Any
-):
+) -> TrainsLogger:
     """Method to setup Trains logging on trainer and a list of evaluators. Logged metrics are:

         - Training metrics, e.g. running average loss values
@@ -523,8 +544,8 @@ def setup_trains_logging(
     return logger


-def get_default_score_fn(metric_name: str):
-    def wrapper(engine: Engine):
+def get_default_score_fn(metric_name: str) -> Any:
+    def wrapper(engine: Engine) -> Any:
         score = engine.state.metrics[metric_name]
         return score

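Side note (sketch, not part of the diff): the wrapper returned by get_default_score_fn is what the helpers below pass as score_function. Roughly, with a hypothetical model and evaluator already defined:

from ignite.contrib.engines import common
from ignite.engine import Events
from ignite.handlers import Checkpoint, DiskSaver

# Score is read from evaluator.state.metrics["accuracy"] after each evaluation run.
score_fn = common.get_default_score_fn("accuracy")
best_model_handler = Checkpoint(
    {"model": model},  # `model` assumed to exist
    DiskSaver("/tmp/best", require_empty=False),  # assumed output directory
    score_function=score_fn,
    score_name="val_accuracy",
    n_saved=3,
)
evaluator.add_event_handler(Events.COMPLETED, best_model_handler)  # `evaluator` assumed to exist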
@@ -540,7 +561,7 @@ def gen_save_best_models_by_val_score(
     trainer: Optional[Engine] = None,
     tag: str = "val",
     **kwargs: Any
-):
+) -> Checkpoint:
     """Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
     Models with highest metric value will be retained. The logic of how to store objects is delegated to
@@ -570,9 +591,10 @@ def gen_save_best_models_by_val_score(
     if trainer is not None:
         global_step_transform = global_step_from_engine(trainer)

-    to_save = models
     if isinstance(models, nn.Module):
-        to_save = {"model": models}
+        to_save = {"model": models}  # type: Dict[str, nn.Module]
+    else:
+        to_save = models

     best_model_handler = Checkpoint(
         to_save,
@@ -598,7 +620,7 @@ def save_best_model_by_val_score(
     trainer: Optional[Engine] = None,
     tag: str = "val",
     **kwargs: Any
-):
+) -> Checkpoint:
     """Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
     (named by ``metric_name``) provided by ``evaluator`` (i.e. ``evaluator.state.metrics[metric_name]``).
     Models with highest metric value will be retained.
@@ -629,7 +651,9 @@ def save_best_model_by_val_score(
     )


-def add_early_stopping_by_val_score(patience: int, evaluator: Engine, trainer: Engine, metric_name: str):
+def add_early_stopping_by_val_score(
+    patience: int, evaluator: Engine, trainer: Engine, metric_name: str
+) -> EarlyStopping:
     """Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.
     Metric value should increase in order to keep training and not early stop.

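Finally, a sketch (illustrative, not from this PR) of the call sites whose return types the diff now annotates. trainer, evaluator, model and optimizer are assumed to be defined as in the earlier sketch; paths and metric names are placeholders.

from ignite.contrib.engines import common

tb_logger = common.setup_tb_logging(  # annotated to return TensorboardLogger
    output_path="/tmp/tb-logs",  # assumed log directory
    trainer=trainer,
    optimizers=optimizer,
    evaluators={"validation": evaluator},
    log_every_iters=100,
)

best_ckpt = common.save_best_model_by_val_score(  # annotated to return Checkpoint
    output_path="/tmp/best-models",
    evaluator=evaluator,
    model=model,
    metric_name="accuracy",
    n_saved=3,
    trainer=trainer,
)

es_handler = common.add_early_stopping_by_val_score(  # annotated to return EarlyStopping
    patience=5, evaluator=evaluator, trainer=trainer, metric_name="accuracy"
)

tb_logger.close()  # the annotated return type lets calls like this type-check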