Working ParameterScheduler #1949

Merged (17 commits, Apr 7, 2021)
5 changes: 5 additions & 0 deletions docs/source/handlers.rst
@@ -111,6 +111,11 @@ SmartCache handler
.. autoclass:: SmartCacheHandler
:members:

Parameter Scheduler handler
---------------------------
.. autoclass:: ParamSchedulerHandler
:members:

EarlyStop handler
-----------------
.. autoclass:: EarlyStopHandler
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
@@ -21,6 +21,7 @@
from .mean_dice import MeanDice
from .metric_logger import MetricLogger, MetricLoggerKeys
from .metrics_saver import MetricsSaver
from .parameter_scheduler import ParamSchedulerHandler
from .roc_auc import ROCAUC
from .segmentation_saver import SegmentationSaver
from .smartcache_handler import SmartCacheHandler
163 changes: 163 additions & 0 deletions monai/handlers/parameter_scheduler.py
@@ -0,0 +1,163 @@
import logging
from bisect import bisect_right
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from monai.utils import exact_version, optional_import

if TYPE_CHECKING:
from ignite.engine import Engine, Events
else:
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")


class ParamSchedulerHandler:
"""
    General-purpose scheduler for parameter values. By default it can schedule following a linear, exponential,
    step or multistep function. One can also pass a Callable to implement customized scheduling logic.

Review comment (Member): I think providing an example of usage would be of great help for the users.

Review comment (Contributor Author): I completely forgot about the use-case example. I will do another pull request to amend this one.

Args:
parameter_setter (Callable): Function that sets the required parameter
value_calculator (Union[str,Callable]): Either a string ('linear', 'exponential', 'step' or 'multistep')
or Callable for custom logic.
vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator.
        epoch_level (bool): Whether the step is based on epoch or iteration. Defaults to False.
        name (Optional[str]): Identifier of logging.logger to use; if None, defaults to ``engine.logger``.
        event (Optional[str]): Event to which the handler attaches. Defaults to Events.ITERATION_COMPLETED.
"""

def __init__(
self,
parameter_setter: Callable,
value_calculator: Union[str, Callable],
vc_kwargs: Dict,
epoch_level: bool = False,
name: Optional[str] = None,
event=Events.ITERATION_COMPLETED,
):
self.epoch_level = epoch_level
self.event = event

self._calculators = {
"linear": self._linear,
"exponential": self._exponential,
"step": self._step,
"multistep": self._multistep,
}

self._parameter_setter = parameter_setter
self._vc_kwargs = vc_kwargs
self._value_calculator = self._get_value_calculator(value_calculator=value_calculator)

self.logger = logging.getLogger(name)
self._name = name

def _get_value_calculator(self, value_calculator):
if isinstance(value_calculator, str):
return self._calculators[value_calculator]
if callable(value_calculator):
return value_calculator
raise ValueError(
f"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable."
)

def __call__(self, engine: Engine):
if self.epoch_level:
self._vc_kwargs["current_step"] = engine.state.epoch
else:
self._vc_kwargs["current_step"] = engine.state.iteration

new_value = self._value_calculator(**self._vc_kwargs)
self._parameter_setter(new_value)

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine that is used for training.
"""
if self._name is None:
self.logger = engine.logger
engine.add_event_handler(self.event, self)

@staticmethod
def _linear(
initial_value: float, step_constant: int, step_max_value: int, max_value: float, current_step: int
) -> float:
"""
        Keeps the parameter value at initial_value until step_constant steps have passed, then increases it
        linearly until it reaches max_value at step_max_value, after which the value stays at max_value.

Args:
initial_value (float): Starting value of the parameter.
step_constant (int): Step index until parameter's value is kept constant.
step_max_value (int): Step index at which parameter's value becomes max_value.
max_value (float): Max parameter value.
current_step (int): Current step index.

Returns:
float: new parameter value
"""
if current_step <= step_constant:
delta = 0.0
elif current_step > step_max_value:
delta = max_value - initial_value
else:
delta = (max_value - initial_value) / (step_max_value - step_constant) * (current_step - step_constant)

return initial_value + delta
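To make the piecewise behaviour concrete, a quick sanity check with the same values the tests below use (a sketch, not part of the diff):

from monai.handlers import ParamSchedulerHandler

# initial_value=0, step_constant=2, step_max_value=5, max_value=10
assert ParamSchedulerHandler._linear(0, 2, 5, 10, 2) == 0.0  # still in the constant phase
assert abs(ParamSchedulerHandler._linear(0, 2, 5, 10, 3) - 10 / 3) < 1e-6  # on the linear ramp
assert ParamSchedulerHandler._linear(0, 2, 5, 10, 7) == 10.0  # clamped at max_value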

@staticmethod
def _exponential(initial_value: float, gamma: float, current_step: int) -> float:
"""
Decays the parameter value by gamma every step.

        Based on the closed form of ExponentialLR from PyTorch:
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L457

Args:
initial_value (float): Starting value of the parameter.
gamma (float): Multiplicative factor of parameter value decay.
current_step (int): Current step index.

Returns:
float: new parameter value
"""
return initial_value * gamma ** current_step

@staticmethod
def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float:
"""
Decays the parameter value by gamma every step_size.

        Based on StepLR from PyTorch:
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L377

Args:
initial_value (float): Starting value of the parameter.
gamma (float): Multiplicative factor of parameter value decay.
step_size (int): Period of parameter value decay.
current_step (int): Current step index.

        Returns:
float: new parameter value
"""
return initial_value * gamma ** (current_step // step_size)
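For concreteness, the decay exponents these two closed forms produce, mirroring the exponential/step tests below (a sketch, not part of the diff):

from monai.handlers import ParamSchedulerHandler

# Exponential: one factor of gamma per step.
assert ParamSchedulerHandler._exponential(10, 0.99, 2) == 10 * 0.99 ** 2
# Step: the exponent advances only every step_size steps (10 // 5 == 2).
assert ParamSchedulerHandler._step(10, 0.99, 5, 10) == 10 * 0.99 ** 2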

@staticmethod
def _multistep(initial_value: float, gamma: float, milestones: List[int], current_step: int) -> float:
"""
Decays the parameter value by gamma once the number of steps reaches one of the milestones.

        Based on MultiStepLR from PyTorch:
https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L424

Args:
initial_value (float): Starting value of the parameter.
gamma (float): Multiplicative factor of parameter value decay.
milestones (List[int]): List of step indices. Must be increasing.
current_step (int): Current step index.

Returns:
float: new parameter value
"""
return initial_value * gamma ** bisect_right(milestones, current_step)
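Because value_calculator also accepts any Callable, a custom schedule is just a function of current_step plus whatever vc_kwargs supplies. A hedged sketch using a cosine decay (the function and its parameters are illustrative, not part of this PR):

import math

from monai.handlers import ParamSchedulerHandler

def cosine_decay(initial_value: float, max_steps: int, current_step: int) -> float:
    # Anneal from initial_value down to 0 over max_steps, then stay there.
    progress = min(current_step, max_steps) / max_steps
    return initial_value * 0.5 * (1 + math.cos(math.pi * progress))

handler = ParamSchedulerHandler(
    parameter_setter=lambda value: None,  # plug a real setter in here
    value_calculator=cosine_decay,
    vc_kwargs={"initial_value": 0.1, "max_steps": 1000},
)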
1 change: 1 addition & 0 deletions tests/min_tests.py
@@ -47,6 +47,7 @@ def run_testsuit():
"test_handler_prob_map_producer",
"test_handler_rocauc",
"test_handler_rocauc_dist",
"test_handler_parameter_scheduler",
"test_handler_segmentation_saver",
"test_handler_smartcache",
"test_handler_stats",
123 changes: 123 additions & 0 deletions tests/test_handler_parameter_scheduler.py
@@ -0,0 +1,123 @@
import unittest

import torch
from ignite.engine import Engine, Events
from torch.nn import Module

from monai.handlers.parameter_scheduler import ParamSchedulerHandler


class ToyNet(Module):
def __init__(self, value):
super(ToyNet, self).__init__()
self.value = value

    def forward(self, x):
        return x

def get_value(self):
return self.value

def set_value(self, value):
self.value = value


class TestHandlerParameterScheduler(unittest.TestCase):
def test_linear_scheduler(self):
# Testing step_constant
net = ToyNet(value=-1)
engine = Engine(lambda e, b: None)
ParamSchedulerHandler(
parameter_setter=net.set_value,
value_calculator="linear",
vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
epoch_level=True,
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=2)
torch.testing.assert_allclose(net.get_value(), 0)

# Testing linear increase
net = ToyNet(value=-1)
engine = Engine(lambda e, b: None)
ParamSchedulerHandler(
parameter_setter=net.set_value,
value_calculator="linear",
vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
epoch_level=True,
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=3)
torch.testing.assert_allclose(net.get_value(), 3.333333, atol=0.001, rtol=0.0)

# Testing max_value
net = ToyNet(value=-1)
engine = Engine(lambda e, b: None)
ParamSchedulerHandler(
parameter_setter=net.set_value,
value_calculator="linear",
vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
epoch_level=True,
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=10)
torch.testing.assert_allclose(net.get_value(), 10)

def test_exponential_scheduler(self):
net = ToyNet(value=-1)
engine = Engine(lambda e, b: None)
ParamSchedulerHandler(
parameter_setter=net.set_value,
value_calculator="exponential",
vc_kwargs={"initial_value": 10, "gamma": 0.99},
epoch_level=True,
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=2)
torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)

def test_step_scheduler(self):
net = ToyNet(value=-1)
engine = Engine(lambda e, b: None)
ParamSchedulerHandler(
parameter_setter=net.set_value,
value_calculator="step",
vc_kwargs={"initial_value": 10, "gamma": 0.99, "step_size": 5},
epoch_level=True,
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=10)
torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)

def test_multistep_scheduler(self):
net = ToyNet(value=-1)
engine = Engine(lambda e, b: None)
ParamSchedulerHandler(
parameter_setter=net.set_value,
value_calculator="multistep",
vc_kwargs={"initial_value": 10, "gamma": 0.99, "milestones": [3, 6]},
epoch_level=True,
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=10)
torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)

def test_custom_scheduler(self):
def custom_logic(initial_value, gamma, current_step):
return initial_value * gamma ** (current_step % 9)

net = ToyNet(value=-1)
engine = Engine(lambda e, b: None)
ParamSchedulerHandler(
parameter_setter=net.set_value,
value_calculator=custom_logic,
vc_kwargs={"initial_value": 10, "gamma": 0.99},
epoch_level=True,
event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0] * 8, max_epochs=2)
torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)


if __name__ == "__main__":
unittest.main()