
Commit 7581255

Working ParameterScheduler (#1949)
* Working ParameterScheduler: added a new ParameterScheduler handler and the required tests.

Signed-off-by: Petru-Daniel Tudosiu <[email protected]>
1 parent d6ed956 commit 7581255

File tree: 5 files changed, +293 -0 lines changed


docs/source/handlers.rst

Lines changed: 5 additions & 0 deletions

@@ -111,6 +111,11 @@ SmartCache handler
 .. autoclass:: SmartCacheHandler
     :members:
 
+Parameter Scheduler handler
+---------------------------
+.. autoclass:: ParamSchedulerHandler
+    :members:
+
 EarlyStop handler
 -----------------
 .. autoclass:: EarlyStopHandler

monai/handlers/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@
 from .mean_dice import MeanDice
 from .metric_logger import MetricLogger, MetricLoggerKeys
 from .metrics_saver import MetricsSaver
+from .parameter_scheduler import ParamSchedulerHandler
 from .roc_auc import ROCAUC
 from .segmentation_saver import SegmentationSaver
 from .smartcache_handler import SmartCacheHandler
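With this export in place, the handler is available from the public handlers namespace; a one-line sketch (not part of the diff):

from monai.handlers import ParamSchedulerHandler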

monai/handlers/parameter_scheduler.py (new file)

Lines changed: 163 additions & 0 deletions

import logging
from bisect import bisect_right
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from monai.utils import exact_version, optional_import

if TYPE_CHECKING:
    from ignite.engine import Engine, Events
else:
    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
    Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")


class ParamSchedulerHandler:
    """
    General purpose scheduler for parameter values. By default it can schedule with a linear,
    exponential, step or multistep function. One can also pass a Callable for customized
    scheduling logic.

    Args:
        parameter_setter (Callable): Function that sets the required parameter.
        value_calculator (Union[str, Callable]): Either a string ('linear', 'exponential',
            'step' or 'multistep') or a Callable for custom logic.
        vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator.
        epoch_level (bool): Whether the step is based on epoch or iteration. Defaults to False.
        name (Optional[str]): Identifier of logging.logger to use; if None, defaults to ``engine.logger``.
        event (Optional[str]): Event to which the handler attaches. Defaults to Events.ITERATION_COMPLETED.
    """

    def __init__(
        self,
        parameter_setter: Callable,
        value_calculator: Union[str, Callable],
        vc_kwargs: Dict,
        epoch_level: bool = False,
        name: Optional[str] = None,
        event=Events.ITERATION_COMPLETED,
    ):
        self.epoch_level = epoch_level
        self.event = event

        self._calculators = {
            "linear": self._linear,
            "exponential": self._exponential,
            "step": self._step,
            "multistep": self._multistep,
        }

        self._parameter_setter = parameter_setter
        self._vc_kwargs = vc_kwargs
        self._value_calculator = self._get_value_calculator(value_calculator=value_calculator)

        self.logger = logging.getLogger(name)
        self._name = name

    def _get_value_calculator(self, value_calculator):
        if isinstance(value_calculator, str):
            return self._calculators[value_calculator]
        if callable(value_calculator):
            return value_calculator
        raise ValueError(
            f"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable."
        )

    def __call__(self, engine: Engine):
        if self.epoch_level:
            self._vc_kwargs["current_step"] = engine.state.epoch
        else:
            self._vc_kwargs["current_step"] = engine.state.iteration

        new_value = self._value_calculator(**self._vc_kwargs)
        self._parameter_setter(new_value)

    def attach(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine that is used for training.
        """
        if self._name is None:
            self.logger = engine.logger
        engine.add_event_handler(self.event, self)

    @staticmethod
    def _linear(
        initial_value: float, step_constant: int, step_max_value: int, max_value: float, current_step: int
    ) -> float:
        """
        Keeps the parameter value constant at initial_value until step_constant steps have passed,
        then increases it linearly until it reaches max_value at step_max_value, and keeps it at
        max_value thereafter.

        Args:
            initial_value (float): Starting value of the parameter.
            step_constant (int): Step index until which the parameter's value is kept constant.
            step_max_value (int): Step index at which the parameter's value becomes max_value.
            max_value (float): Max parameter value.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        if current_step <= step_constant:
            delta = 0.0
        elif current_step > step_max_value:
            delta = max_value - initial_value
        else:
            delta = (max_value - initial_value) / (step_max_value - step_constant) * (current_step - step_constant)

        return initial_value + delta

    @staticmethod
    def _exponential(initial_value: float, gamma: float, current_step: int) -> float:
        """
        Decays the parameter value by gamma every step.

        Based on the closed form of ExponentialLR from PyTorch:
        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L457

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        return initial_value * gamma ** current_step

    @staticmethod
    def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float:
        """
        Decays the parameter value by gamma every step_size steps.

        Based on StepLR from PyTorch:
        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L377

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            step_size (int): Period of parameter value decay.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        return initial_value * gamma ** (current_step // step_size)

    @staticmethod
    def _multistep(initial_value: float, gamma: float, milestones: List[int], current_step: int) -> float:
        """
        Decays the parameter value by gamma once the number of steps reaches one of the milestones.

        Based on MultiStepLR from PyTorch:
        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L424

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            milestones (List[int]): List of step indices. Must be increasing.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        return initial_value * gamma ** bisect_right(milestones, current_step)
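As a quick orientation before the tests, here is a minimal usage sketch (not part of the commit) that ramps up a hypothetical loss weight with the built-in 'linear' calculator; WeightedLoss and the no-op trainer are illustrative stand-ins:

from ignite.engine import Engine, Events

from monai.handlers import ParamSchedulerHandler


class WeightedLoss:
    # Illustrative container for a scalar weight the handler will update.
    def __init__(self, weight: float = 0.0):
        self.weight = weight

    def set_weight(self, value: float) -> None:
        self.weight = value


loss = WeightedLoss()
trainer = Engine(lambda engine, batch: None)  # placeholder process function

# Hold the weight at 0.0 for the first 10 epochs, ramp it linearly to 1.0
# by epoch 50, then keep it constant.
ParamSchedulerHandler(
    parameter_setter=loss.set_weight,
    value_calculator="linear",
    vc_kwargs={"initial_value": 0.0, "step_constant": 10, "step_max_value": 50, "max_value": 1.0},
    epoch_level=True,
    event=Events.EPOCH_COMPLETED,
).attach(trainer)

trainer.run([0] * 4, max_epochs=60)
print(loss.weight)  # 1.0, since epoch 60 is past step_max_value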

tests/min_tests.py

Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@ def run_testsuit():
         "test_handler_prob_map_producer",
         "test_handler_rocauc",
         "test_handler_rocauc_dist",
+        "test_handler_parameter_scheduler",
         "test_handler_segmentation_saver",
         "test_handler_smartcache",
         "test_handler_stats",
tests/test_handler_parameter_scheduler.py (new file)

Lines changed: 123 additions & 0 deletions

import unittest

import torch
from ignite.engine import Engine, Events
from torch.nn import Module

from monai.handlers.parameter_scheduler import ParamSchedulerHandler


class ToyNet(Module):
    def __init__(self, value):
        super(ToyNet, self).__init__()
        self.value = value

    def forward(self, input):
        return input

    def get_value(self):
        return self.value

    def set_value(self, value):
        self.value = value


class TestHandlerParameterScheduler(unittest.TestCase):
    def test_linear_scheduler(self):
        # Testing step_constant
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="linear",
            vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=2)
        torch.testing.assert_allclose(net.get_value(), 0)

        # Testing linear increase
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="linear",
            vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=3)
        torch.testing.assert_allclose(net.get_value(), 3.333333, atol=0.001, rtol=0.0)

        # Testing max_value
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="linear",
            vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=10)
        torch.testing.assert_allclose(net.get_value(), 10)

    def test_exponential_scheduler(self):
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="exponential",
            vc_kwargs={"initial_value": 10, "gamma": 0.99},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=2)
        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)

    def test_step_scheduler(self):
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="step",
            vc_kwargs={"initial_value": 10, "gamma": 0.99, "step_size": 5},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=10)
        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)

    def test_multistep_scheduler(self):
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="multistep",
            vc_kwargs={"initial_value": 10, "gamma": 0.99, "milestones": [3, 6]},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=10)
        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)

    def test_custom_scheduler(self):
        def custom_logic(initial_value, gamma, current_step):
            return initial_value * gamma ** (current_step % 9)

        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator=custom_logic,
            vc_kwargs={"initial_value": 10, "gamma": 0.99},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=2)
        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)


if __name__ == "__main__":
    unittest.main()
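As a sanity check on the expected values above (a standalone sketch, not part of the test suite): with milestones [3, 6], bisect_right([3, 6], 10) == 2, so by epoch 10 the multistep schedule has applied gamma twice; the step schedule with step_size=5 at epoch 10 likewise decays 10 // 5 == 2 times; and the linear case at epoch 3 yields (10 - 0) / (5 - 2) * (3 - 2) ≈ 3.3333.

from bisect import bisect_right
from math import isclose

initial_value, gamma = 10.0, 0.99

# multistep: two milestones passed by epoch 10 -> gamma applied twice
assert isclose(initial_value * gamma ** bisect_right([3, 6], 10), 10 * 0.99 * 0.99)

# step: epoch 10 with step_size 5 -> 10 // 5 == 2 decays
assert isclose(initial_value * gamma ** (10 // 5), 10 * 0.99 * 0.99)

# linear: at epoch 3 (constant until epoch 2, reaching 10.0 at epoch 5)
delta = (10.0 - 0.0) / (5 - 2) * (3 - 2)
print(0.0 + delta)  # ~3.3333, matching the atol=0.001 assertion above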
