Working ParameterScheduler #1949
Merged: wyli merged 17 commits into Project-MONAI:master from danieltudosiu:1552_parameter_scheduler on Apr 7, 2021.
Changes from all commits (17 commits):

- f4ba155 Working ParameterScheduler (danieltudosiu)
- 732bc81 [MONAI] python code formatting (monai-bot)
- caa6f40 Modified docs & Excluded fro min_tests (danieltudosiu)
- 01064f9 Merge remote-tracking branch 'origin/1552_parameter_scheduler' into 1… (danieltudosiu)
- f73b1a1 Merge branch 'master' into 1552_parameter_scheduler (danieltudosiu)
- a7f76e9 Fixes (danieltudosiu)
- ea05bf2 Merge remote-tracking branch 'origin/1552_parameter_scheduler' into 1… (danieltudosiu)
- 8161439 Fixes (danieltudosiu)
- c5b4732 Deleting scratchpad (danieltudosiu)
- a32bf1e Fixes (danieltudosiu)
- bc764a5 type fixes (wyli)
- 4b9ab9b Fixes (danieltudosiu)
- 53a9321 Merge remote-tracking branch 'origin/1552_parameter_scheduler' into 1… (danieltudosiu)
- 60e6738 Fixes (danieltudosiu)
- a6cd950 fixes precision (wyli)
- d59c8b9 Merge branch 'master' into 1552_parameter_scheduler (wyli)
- c1f9537 Merge branch 'master' into 1552_parameter_scheduler (wyli)
monai/handlers/parameter_scheduler.py (new file, +163 lines):

```python
import logging
from bisect import bisect_right
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

from monai.utils import exact_version, optional_import

if TYPE_CHECKING:
    from ignite.engine import Engine, Events
else:
    Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
    Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")


class ParamSchedulerHandler:
    """
    General-purpose scheduler for parameter values. It supports linear, exponential, step and multistep
    schedules out of the box; a Callable can also be passed in for customized scheduling logic.

    Args:
        parameter_setter (Callable): Function that sets the required parameter value.
        value_calculator (Union[str,Callable]): Either a string ('linear', 'exponential', 'step' or 'multistep')
            or a Callable for custom logic.
        vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator.
        epoch_level (bool): Whether the step is based on epoch or iteration. Defaults to False.
        name (Optional[str]): Identifier of logging.logger to use, if None, defaulting to ``engine.logger``.
        event (Optional[str]): Event to which the handler attaches. Defaults to Events.ITERATION_COMPLETED.
    """

    def __init__(
        self,
        parameter_setter: Callable,
        value_calculator: Union[str, Callable],
        vc_kwargs: Dict,
        epoch_level: bool = False,
        name: Optional[str] = None,
        event=Events.ITERATION_COMPLETED,
    ):
        self.epoch_level = epoch_level
        self.event = event

        self._calculators = {
            "linear": self._linear,
            "exponential": self._exponential,
            "step": self._step,
            "multistep": self._multistep,
        }

        self._parameter_setter = parameter_setter
        self._vc_kwargs = vc_kwargs
        self._value_calculator = self._get_value_calculator(value_calculator=value_calculator)

        self.logger = logging.getLogger(name)
        self._name = name

    def _get_value_calculator(self, value_calculator):
        if isinstance(value_calculator, str):
            return self._calculators[value_calculator]
        if callable(value_calculator):
            return value_calculator
        raise ValueError(
            f"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable."
        )

    def __call__(self, engine: Engine):
        if self.epoch_level:
            self._vc_kwargs["current_step"] = engine.state.epoch
        else:
            self._vc_kwargs["current_step"] = engine.state.iteration

        new_value = self._value_calculator(**self._vc_kwargs)
        self._parameter_setter(new_value)

    def attach(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine that is used for training.
        """
        if self._name is None:
            self.logger = engine.logger
        engine.add_event_handler(self.event, self)

    @staticmethod
    def _linear(
        initial_value: float, step_constant: int, step_max_value: int, max_value: float, current_step: int
    ) -> float:
        """
        Keeps the parameter value constant at initial_value until step_constant steps have passed, then
        increases it linearly until it reaches max_value at step_max_value, after which it stays at max_value.

        Args:
            initial_value (float): Starting value of the parameter.
            step_constant (int): Step index until parameter's value is kept constant.
            step_max_value (int): Step index at which parameter's value becomes max_value.
            max_value (float): Max parameter value.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        if current_step <= step_constant:
            delta = 0.0
        elif current_step > step_max_value:
            delta = max_value - initial_value
        else:
            delta = (max_value - initial_value) / (step_max_value - step_constant) * (current_step - step_constant)

        return initial_value + delta

    @staticmethod
    def _exponential(initial_value: float, gamma: float, current_step: int) -> float:
        """
        Decays the parameter value by gamma every step.

        Based on the closed form of ExponentialLR from PyTorch:
        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L457

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        return initial_value * gamma ** current_step

    @staticmethod
    def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float:
        """
        Decays the parameter value by gamma every step_size steps.

        Based on StepLR from PyTorch:
        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L377

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            step_size (int): Period of parameter value decay.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        return initial_value * gamma ** (current_step // step_size)

    @staticmethod
    def _multistep(initial_value: float, gamma: float, milestones: List[int], current_step: int) -> float:
        """
        Decays the parameter value by gamma once the number of steps reaches one of the milestones.

        Based on MultiStepLR from PyTorch:
        https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L424

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            milestones (List[int]): List of step indices. Must be increasing.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        """
        return initial_value * gamma ** bisect_right(milestones, current_step)
```
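
For reference, attaching the handler is a one-liner once a setter and a schedule are chosen. Below is a minimal usage sketch, not part of this PR: `WeightedLoss`, `my_loss` and `trainer` are hypothetical placeholders. It linearly ramps a loss weight from 0 to 1 between iterations 1000 and 5000 using the built-in "linear" calculator:

```python
from ignite.engine import Engine, Events

from monai.handlers.parameter_scheduler import ParamSchedulerHandler


class WeightedLoss:
    """Toy stand-in for a loss whose weight is scheduled during training."""

    def __init__(self) -> None:
        self.weight = 0.0

    def set_weight(self, value: float) -> None:
        self.weight = value


my_loss = WeightedLoss()
trainer = Engine(lambda engine, batch: None)  # placeholder training loop

ParamSchedulerHandler(
    parameter_setter=my_loss.set_weight,
    value_calculator="linear",
    vc_kwargs={"initial_value": 0.0, "step_constant": 1000, "step_max_value": 5000, "max_value": 1.0},
    epoch_level=False,  # schedule on iterations rather than epochs
    event=Events.ITERATION_COMPLETED,
).attach(trainer)
```

The handler injects `current_step` into `vc_kwargs` on every triggering event, so the dictionary only needs to cover the remaining arguments of the chosen calculator.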
tests/test_handler_parameter_scheduler.py (new file, +123 lines):

```python
import unittest

import torch
from ignite.engine import Engine, Events
from torch.nn import Module

from monai.handlers.parameter_scheduler import ParamSchedulerHandler


class ToyNet(Module):
    def __init__(self, value):
        super(ToyNet, self).__init__()
        self.value = value

    def forward(self, input):
        return input

    def get_value(self):
        return self.value

    def set_value(self, value):
        self.value = value


class TestHandlerParameterScheduler(unittest.TestCase):
    def test_linear_scheduler(self):
        # Testing step_constant
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="linear",
            vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=2)
        torch.testing.assert_allclose(net.get_value(), 0)

        # Testing linear increase
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="linear",
            vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=3)
        torch.testing.assert_allclose(net.get_value(), 3.333333, atol=0.001, rtol=0.0)

        # Testing max_value
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="linear",
            vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=10)
        torch.testing.assert_allclose(net.get_value(), 10)

    def test_exponential_scheduler(self):
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="exponential",
            vc_kwargs={"initial_value": 10, "gamma": 0.99},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=2)
        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)

    def test_step_scheduler(self):
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="step",
            vc_kwargs={"initial_value": 10, "gamma": 0.99, "step_size": 5},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=10)
        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)

    def test_multistep_scheduler(self):
        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator="multistep",
            vc_kwargs={"initial_value": 10, "gamma": 0.99, "milestones": [3, 6]},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=10)
        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)

    def test_custom_scheduler(self):
        def custom_logic(initial_value, gamma, current_step):
            return initial_value * gamma ** (current_step % 9)

        net = ToyNet(value=-1)
        engine = Engine(lambda e, b: None)
        ParamSchedulerHandler(
            parameter_setter=net.set_value,
            value_calculator=custom_logic,
            vc_kwargs={"initial_value": 10, "gamma": 0.99},
            epoch_level=True,
            event=Events.EPOCH_COMPLETED,
        ).attach(engine)
        engine.run([0] * 8, max_epochs=2)
        torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99)


if __name__ == "__main__":
    unittest.main()
```
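
The `test_custom_scheduler` case above demonstrates the Callable interface: any function whose keyword arguments are covered by `vc_kwargs` plus the injected `current_step` will work. As a further hedged sketch, not from the PR (`cosine_calculator` and all names here are illustrative), a cosine decay schedule can be plugged in the same way:

```python
import math

from ignite.engine import Engine, Events

from monai.handlers.parameter_scheduler import ParamSchedulerHandler


def cosine_calculator(initial_value: float, max_steps: int, current_step: int) -> float:
    # Illustrative custom schedule: cosine decay from initial_value towards 0
    # over max_steps; the handler supplies current_step before each call.
    step = min(current_step, max_steps)
    return initial_value * 0.5 * (1 + math.cos(math.pi * step / max_steps))


values = []
engine = Engine(lambda e, b: None)
ParamSchedulerHandler(
    parameter_setter=values.append,  # record each scheduled value
    value_calculator=cosine_calculator,
    vc_kwargs={"initial_value": 1.0, "max_steps": 10},
    epoch_level=True,
    event=Events.EPOCH_COMPLETED,
).attach(engine)
engine.run([0], max_epochs=10)
# values decreases from roughly 0.98 at epoch 1 to 0.0 at epoch 10
```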
Review comment: I think providing an example of usage would be of great help for the users.
Reply: I completely forgot about the use-case example. I will do another pull request to amend this one.