
[2] [contrib/handlers] setup typing in contrib part of the library #1362


Merged: 10 commits merged on Oct 6, 2020
32 changes: 19 additions & 13 deletions ignite/contrib/handlers/base_logger.py
@@ -1,10 +1,10 @@
import numbers
import warnings
from abc import ABCMeta, abstractmethod
from collections.abc import Sequence
from typing import Any, Mapping
from typing import Any, Callable, List, Optional, Sequence, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

from ignite.engine import Engine, State
@@ -21,7 +21,7 @@ class BaseOptimizerParamsHandler(BaseHandler):
Base handler for logging optimizer parameters
"""

def __init__(self, optimizer, param_name="lr", tag=None):
def __init__(self, optimizer: Optimizer, param_name: str = "lr", tag: Optional[str] = None):
if not (
isinstance(optimizer, Optimizer)
or (hasattr(optimizer, "param_groups") and isinstance(optimizer.param_groups, Sequence))
@@ -41,7 +41,13 @@ class BaseOutputHandler(BaseHandler):
Helper handler to log engine's output and/or metrics
"""

def __init__(self, tag, metric_names=None, output_transform=None, global_step_transform=None):
def __init__(
self,
tag: str,
metric_names: Optional[Union[str, List[str]]] = None,
output_transform: Optional[Callable] = None,
global_step_transform: Optional[Callable] = None,
):

if metric_names is not None:
if not (isinstance(metric_names, list) or (isinstance(metric_names, str) and metric_names == "all")):
@@ -70,7 +76,7 @@ def global_step_transform(engine, event_name):
self.output_transform = output_transform
self.global_step_transform = global_step_transform

def _setup_output_metrics(self, engine):
def _setup_output_metrics(self, engine: Engine):
"""Helper method to setup metrics to log
"""
metrics = {}
@@ -102,14 +108,14 @@ class BaseWeightsScalarHandler(BaseHandler):
Helper handler to log model's weights as scalars.
"""

def __init__(self, model, reduction=torch.norm, tag=None):
def __init__(self, model: nn.Module, reduction: Callable = torch.norm, tag: Optional[str] = None):
if not isinstance(model, torch.nn.Module):
raise TypeError("Argument model should be of type torch.nn.Module, " "but given {}".format(type(model)))

if not callable(reduction):
raise TypeError("Argument reduction should be callable, " "but given {}".format(type(reduction)))

def _is_0D_tensor(t):
def _is_0D_tensor(t: torch.Tensor):
return isinstance(t, torch.Tensor) and t.ndimension() == 0

# Test reduction function on a tensor
@@ -127,7 +133,7 @@ class BaseWeightsHistHandler(BaseHandler):
Helper handler to log model's weights as histograms.
"""

def __init__(self, model, tag=None):
def __init__(self, model: nn.Module, tag: Optional[str] = None):
if not isinstance(model, torch.nn.Module):
raise TypeError("Argument model should be of type torch.nn.Module, " "but given {}".format(type(model)))

@@ -141,7 +147,7 @@ class BaseLogger(metaclass=ABCMeta):

"""

def attach(self, engine, log_handler, event_name):
def attach(self, engine: Engine, log_handler: Callable, event_name: Any):
"""Attach the logger to the engine and execute `log_handler` function at `event_name` events.

Args:
@@ -161,7 +167,7 @@ def attach(self, engine, log_handler, event_name):

return engine.add_event_handler(event_name, log_handler, self, name)

def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Mapping):
def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any):
"""Shortcut method to attach `OutputHandler` to the logger.

Args:
@@ -177,7 +183,7 @@ def attach_output_handler(self, engine: Engine, event_name: Any, *args: Any, **k
"""
return self.attach(engine, self._create_output_handler(*args, **kwargs), event_name=event_name)

def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Mapping):
def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any, **kwargs: Any):
"""Shortcut method to attach `OptimizerParamsHandler` to the logger.

Args:
@@ -194,11 +200,11 @@ def attach_opt_params_handler(self, engine: Engine, event_name: Any, *args: Any,
self.attach(engine, self._create_opt_params_handler(*args, **kwargs), event_name=event_name)

@abstractmethod
def _create_output_handler(self, engine, *args, **kwargs):
def _create_output_handler(self, engine: Engine, *args: Any, **kwargs: Any):
pass

@abstractmethod
def _create_opt_params_handler(self, *args, **kwargs):
def _create_opt_params_handler(self, *args: Any, **kwargs: Any):
pass

def __enter__(self):
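
For context, here is a minimal standalone sketch (not part of the PR; the class name OutputHandlerSketch is hypothetical) of the annotation pattern the diff above applies to handler constructors, using Optional, Callable, and Union[str, List[str]] for arguments that previously carried no hints:

from typing import Callable, List, Optional, Union


class OutputHandlerSketch:
    """Illustrative stand-in mirroring the annotation style of BaseOutputHandler."""

    def __init__(
        self,
        tag: str,
        metric_names: Optional[Union[str, List[str]]] = None,
        output_transform: Optional[Callable] = None,
        global_step_transform: Optional[Callable] = None,
    ) -> None:
        self.tag = tag
        self.metric_names = metric_names
        self.output_transform = output_transform
        self.global_step_transform = global_step_transform


# A static checker such as mypy accepts this call ...
handler = OutputHandlerSketch("training", metric_names=["loss", "accuracy"])
# ... and would flag calls such as OutputHandlerSketch(tag=123) or metric_names=42.
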
43 changes: 27 additions & 16 deletions ignite/contrib/handlers/lr_finder.py
@@ -3,10 +3,11 @@
import logging
import tempfile
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import Callable, Mapping, Optional

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from ignite.contrib.handlers.param_scheduler import LRScheduler, PiecewiseLinear
@@ -77,7 +78,17 @@ def __init__(self):
self._lr_schedule = None
self.logger = logging.getLogger(__name__)

def _run(self, trainer, optimizer, output_transform, num_iter, end_lr, step_mode, smooth_f, diverge_th):
def _run(
self,
trainer: Engine,
optimizer: Optimizer,
output_transform: Callable,
num_iter: int,
end_lr: float,
step_mode: str,
smooth_f: float,
diverge_th: float,
):

self._history = {"lr": [], "loss": []}
self._best_loss = None
@@ -116,13 +127,13 @@ def _run(self, trainer, optimizer, output_transform, num_iter, end_lr, step_mode
if not trainer.has_event_handler(self._lr_schedule):
trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter)

def _reset(self, trainer):
def _reset(self, trainer: Engine):
self.logger.debug("Completed LR finder run")
trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED)
trainer.remove_event_handler(self._log_lr_and_loss, Events.ITERATION_COMPLETED)
trainer.remove_event_handler(self._reached_num_iterations, Events.ITERATION_COMPLETED)

def _log_lr_and_loss(self, trainer, output_transform, smooth_f, diverge_th):
def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f: float, diverge_th: float):
output = trainer.state.output
loss = output_transform(output)
lr = self._lr_schedule.get_param()
@@ -142,7 +153,7 @@ def _log_lr_and_loss(self, trainer, output_transform, smooth_f, diverge_th):
self.logger.info("Stopping early, the loss has diverged")
trainer.terminate()

def _reached_num_iterations(self, trainer, num_iter):
def _reached_num_iterations(self, trainer: Engine, num_iter: int):
if trainer.state.iteration > num_iter:
trainer.terminate()

@@ -154,7 +165,7 @@ def _warning(self, _):
UserWarning,
)

def _detach(self, trainer):
def _detach(self, trainer: Engine):
"""
Detaches lr_finder from trainer.

@@ -175,7 +186,7 @@ def get_results(self):
"""
return self._history

def plot(self, skip_start=10, skip_end=5, log_lr=True):
def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True):
"""Plots the learning rate range test.

This method requires `matplotlib` package to be installed:
@@ -242,14 +253,14 @@ def lr_suggestion(self):
@contextlib.contextmanager
def attach(
self,
trainer,
to_save,
output_transform=lambda output: output,
num_iter=None,
end_lr=10.0,
step_mode="exp",
smooth_f=0.05,
diverge_th=5.0,
trainer: Engine,
to_save: Mapping,
output_transform: Callable = lambda output: output,
num_iter: Optional[int] = None,
end_lr: float = 10.0,
step_mode: str = "exp",
smooth_f: float = 0.05,
diverge_th: float = 5.0,
):
"""Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.

@@ -361,7 +372,7 @@ class _ExponentialLR(_LRScheduler):

"""

def __init__(self, optimizer, end_lr, num_iter, last_epoch=-1):
def __init__(self, optimizer: Optimizer, end_lr: float, num_iter: int, last_epoch: int = -1):
self.end_lr = end_lr
self.num_iter = num_iter
super(_ExponentialLR, self).__init__(optimizer, last_epoch)
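
A hedged usage sketch of the now-annotated attach signature (assuming FastaiLRFinder and create_supervised_trainer behave as described in the ignite documentation of this period; the model, optimizer, and data below are placeholders):

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset

from ignite.contrib.handlers.lr_finder import FastaiLRFinder
from ignite.engine import create_supervised_trainer

model = nn.Linear(10, 2)
optimizer = SGD(model.parameters(), lr=1e-4)
trainer = create_supervised_trainer(model, optimizer, nn.CrossEntropyLoss())
data = DataLoader(TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=8)

lr_finder = FastaiLRFinder()
# The keyword arguments line up with the new hints: to_save is a Mapping,
# end_lr a float, step_mode a str, num_iter an Optional[int].
with lr_finder.attach(trainer, to_save={"model": model, "optimizer": optimizer}) as finder_trainer:
    finder_trainer.run(data)

print(lr_finder.get_results()["lr"][:5])  # learning rates explored during the range test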