From f4ba1555987a40d50b0d2df60a30d79fd21cae15 Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Mon, 5 Apr 2021 18:12:12 +0100 Subject: [PATCH 01/11] Working ParameterScheduler Added a new ParameterScheduler handler and the required tests. Signed-off-by: Petru-Daniel Tudosiu --- monai/handlers/parameter_scheduler.py | 162 ++++++++++++++++++++++ tests/test_handler_parameter_scheduler.py | 119 ++++++++++++++++ 2 files changed, 281 insertions(+) create mode 100644 monai/handlers/parameter_scheduler.py create mode 100644 tests/test_handler_parameter_scheduler.py diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py new file mode 100644 index 0000000000..7742d0d926 --- /dev/null +++ b/monai/handlers/parameter_scheduler.py @@ -0,0 +1,162 @@ +import logging +from bisect import bisect_right +from typing import Callable, Union, Dict, Optional, List + +from ignite.engine import Engine, Events, EventEnum + + +class ParamSchedulerHandler: + """ + General purpose scheduler for parameters values. By default it can schedule in a linear, exponential, step or + multistep function. One can also pass Callables to have customized scheduling logic. + + Args: + parameter_setter (Callable): Function that sets the required parameter + value_calculator (Union[str,Callable]): Either a string ('linear', 'exponential', 'step' or 'multistep') + or Callable for custom logic. + vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator. + epoch_level (bool): Whether the the step is based on epoch or iteration. Defaults to False. + name (Optional[str]): Identifier of logging.logger to use, if None, defaulting to ``engine.logger``. + event (Optional[str]): Event to which the handler attaches. Defaults to Events.ITERATION_COMPLETED. + """ + + def __init__( + self, + parameter_setter: Callable, + value_calculator: Union[str, Callable], + vc_kwargs: Dict, + epoch_level: bool = False, + name: Optional[str] = None, + event: EventEnum = Events.ITERATION_COMPLETED, + ): + self.epoch_level = epoch_level + self.event = event + + self._calculators = { + "linear": self._linear, + "exponential": self._exponential, + "step": self._step, + "multistep": self._multistep, + } + + self._parameter_setter = parameter_setter + self._vc_kwargs = vc_kwargs + self._value_calculator = self._get_value_calculator(value_calculator=value_calculator) + + self.logger = logging.getLogger(name) + self._name = name + + def _get_value_calculator(self, value_calculator: Union[str, Callable]): + if isinstance(value_calculator, str): + return self._calculators[value_calculator] + elif isinstance(value_calculator, Callable): + return value_calculator + else: + raise ValueError( + f"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable." + ) + + def __call__(self, engine: Engine): + if self.epoch_level: + self._vc_kwargs["current_step"] = engine.state.epoch + else: + self._vc_kwargs["current_step"] = engine.state.iteration + + new_value = self._value_calculator(**self._vc_kwargs) + self._parameter_setter(new_value) + + def attach(self, engine: Engine) -> None: + """ + Args: + engine: Ignite Engine that is used for training. + """ + if self._name is None: + self.logger = engine.logger + engine.add_event_handler(self.event, self) + + @staticmethod + def _linear( + initial_value: float, step_constant: int, step_max_value: int, max_value: float, current_step: int + ) -> float: + """ + Keeps the parameter value to zero until step_zero steps passed and then linearly increases it to 1 until an + additional step_one steps passed. Continues the trend until it reaches max_value. + + Args: + initial_value (float): Starting value of the parameter. + step_constant (int): Step index until parameter's value is kept constant. + step_max_value (int): Step index at which parameter's value becomes max_value. + max_value (float): Max parameter value. + current_step (int): Current step index. + + Returns: + float: new parameter value + """ + if current_step <= step_constant: + delta = 0 + elif current_step > step_max_value: + delta = max_value - initial_value + else: + delta = ( + (max_value - initial_value) + / (step_max_value - step_constant) + * (current_step - step_constant) + ) + + return initial_value + delta + + @staticmethod + def _exponential(initial_value: float, gamma: float, current_step: int) -> float: + """ + Decays the parameter value by gamma every step. + + Based on the closed form of ExponentialLR from Pytorch + https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L457 + + Args: + initial_value (float): Starting value of the parameter. + gamma (float): Multiplicative factor of parameter value decay. + current_step (int): Current step index. + + Returns: + float: new parameter value + """ + return initial_value * gamma ** current_step + + @staticmethod + def _step(initial_value: float, gamma: float, step_size: int, current_step: int) -> float: + """ + Decays the parameter value by gamma every step_size. + + Based on StepLR from Pytorch. + https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L377 + + Args: + initial_value (float): Starting value of the parameter. + gamma (float): Multiplicative factor of parameter value decay. + step_size (int): Period of parameter value decay. + current_step (int): Current step index. + + Returns + float: new parameter value + """ + return initial_value * gamma ** (current_step // step_size) + + @staticmethod + def _multistep(initial_value: float, gamma: float, milestones: List[int], current_step: int) -> float: + """ + Decays the parameter value by gamma once the number of steps reaches one of the milestones. + + Based on MultiStepLR from Pytorch. + https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py#L424 + + Args: + initial_value (float): Starting value of the parameter. + gamma (float): Multiplicative factor of parameter value decay. + milestones (List[int]): List of step indices. Must be increasing. + current_step (int): Current step index. + + Returns: + float: new parameter value + """ + return initial_value * gamma ** bisect_right(milestones, current_step) diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py new file mode 100644 index 0000000000..03ef7304b7 --- /dev/null +++ b/tests/test_handler_parameter_scheduler.py @@ -0,0 +1,119 @@ +import unittest +import torch +from torch.nn import Module +from ignite.engine import Engine, Events + +from monai.handlers.parameter_scheduler import ParamSchedulerHandler +from ignite.engine import Events + + +class ToyNet(Module): + def __init__(self, value): + super(ToyNet, self).__init__() + self.value = value + + def forward(self, input): + return input + + def get_value(self): + return self.value + + def set_value(self, value): + self.value = value + + +class TestHandlerParameterScheduler(unittest.TestCase): + def test_linear_scheduler(self): + # Testing step_constant + net = ToyNet(value=-1) + engine = Engine(lambda e, b: None) + ParamSchedulerHandler( + parameter_setter=net.set_value, + value_calculator="linear", + vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10}, + epoch_level=True, + event=Events.EPOCH_COMPLETED, + ).attach(engine) + engine.run([0] * 8, max_epochs=2) + torch.testing.assert_allclose(net.get_value(), 0) + + # Testing linear increase + net = ToyNet(value=-1) + engine = Engine(lambda e, b: None) + ParamSchedulerHandler( + parameter_setter=net.set_value, + value_calculator="linear", + vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10}, + epoch_level=True, + event=Events.EPOCH_COMPLETED, + ).attach(engine) + engine.run([0] * 8, max_epochs=2) + torch.testing.assert_allclose(net.get_value(), 3.33333333) + + # Testing max_value + net = ToyNet(value=-1) + engine = Engine(lambda e, b: None) + ParamSchedulerHandler( + parameter_setter=net.set_value, + value_calculator="linear", + vc_kwargs={"initial_value": 0, "step_constant": 2, "step_max_value": 5, "max_value": 10}, + epoch_level=True, + event=Events.EPOCH_COMPLETED, + ).attach(engine) + engine.run([0] * 8, max_epochs=10) + torch.testing.assert_allclose(net.get_value(), 10) + + def test_exponential_scheduler(self): + net = ToyNet(value=-1) + engine = Engine(lambda e, b: None) + ParamSchedulerHandler( + parameter_setter=net.set_value, + value_calculator="exponential", + vc_kwargs={"initial_value": 10, "gamma": 0.99}, + epoch_level=True, + event=Events.EPOCH_COMPLETED, + ).attach(engine) + engine.run([0] * 8, max_epochs=2) + torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99) + + def test_step_scheduler(self): + net = ToyNet(value=-1) + engine = Engine(lambda e, b: None) + ParamSchedulerHandler( + parameter_setter=net.set_value, + value_calculator="step", + vc_kwargs={"initial_value": 10, "gamma": 0.99, "step_size": 5}, + epoch_level=True, + event=Events.EPOCH_COMPLETED, + ).attach(engine) + engine.run([0] * 8, max_epochs=10) + torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99) + + def test_multistep_scheduler(self): + net = ToyNet(value=-1) + engine = Engine(lambda e, b: None) + ParamSchedulerHandler( + parameter_setter=net.set_value, + value_calculator="multistep", + vc_kwargs={"initial_value": 10, "gamma": 0.99, "milestones": [3, 6]}, + epoch_level=True, + event=Events.EPOCH_COMPLETED, + ).attach(engine) + engine.run([0] * 8, max_epochs=10) + torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99) + + def test_custom_scheduler(self): + def custom_logic(initial_value, gamma, current_step): + return initial_value * gamma ** (current_step % 9) + + net = ToyNet(value=-1) + engine = Engine(lambda e, b: None) + ParamSchedulerHandler( + parameter_setter=net.set_value, + value_calculator=custom_logic, + vc_kwargs={"initial_value": 10, "gamma": 0.99}, + epoch_level=True, + event=Events.EPOCH_COMPLETED, + ).attach(engine) + engine.run([0] * 8, max_epochs=2) + torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99) From 732bc812ac3d8ad75a7565f368085a6529a8bfea Mon Sep 17 00:00:00 2001 From: monai-bot Date: Mon, 5 Apr 2021 17:23:56 +0000 Subject: [PATCH 02/11] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/parameter_scheduler.py | 10 +++------- tests/test_handler_parameter_scheduler.py | 4 ++-- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py index 7742d0d926..96120a25fe 100644 --- a/monai/handlers/parameter_scheduler.py +++ b/monai/handlers/parameter_scheduler.py @@ -1,8 +1,8 @@ import logging from bisect import bisect_right -from typing import Callable, Union, Dict, Optional, List +from typing import Callable, Dict, List, Optional, Union -from ignite.engine import Engine, Events, EventEnum +from ignite.engine import Engine, EventEnum, Events class ParamSchedulerHandler: @@ -97,11 +97,7 @@ def _linear( elif current_step > step_max_value: delta = max_value - initial_value else: - delta = ( - (max_value - initial_value) - / (step_max_value - step_constant) - * (current_step - step_constant) - ) + delta = (max_value - initial_value) / (step_max_value - step_constant) * (current_step - step_constant) return initial_value + delta diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py index 03ef7304b7..3cfdf1ebac 100644 --- a/tests/test_handler_parameter_scheduler.py +++ b/tests/test_handler_parameter_scheduler.py @@ -1,10 +1,10 @@ import unittest + import torch -from torch.nn import Module from ignite.engine import Engine, Events +from torch.nn import Module from monai.handlers.parameter_scheduler import ParamSchedulerHandler -from ignite.engine import Events class ToyNet(Module): From caa6f404465337030007c9e9cab1bfe77442b30c Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Mon, 5 Apr 2021 20:05:36 +0100 Subject: [PATCH 03/11] Modified docs & Excluded fro min_tests Signed-off-by: Petru-Daniel Tudosiu --- docs/source/handlers.rst | 5 +++++ tests/min_tests.py | 1 + 2 files changed, 6 insertions(+) diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index a629b28b27..0b09379af7 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -110,3 +110,8 @@ SmartCache handler ------------------ .. autoclass:: SmartCacheHandler :members: + +Parameter Scheduler handler +------------------------ +.. autoclass:: ParamSchedulerHandler + :members: diff --git a/tests/min_tests.py b/tests/min_tests.py index 83c1ceea9f..6bba16fdfb 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -45,6 +45,7 @@ def run_testsuit(): "test_handler_mean_dice", "test_handler_rocauc", "test_handler_rocauc_dist", + "test_handler_parameter_scheduler", "test_handler_segmentation_saver", "test_handler_smartcache", "test_handler_stats", From a7f76e9339744d6f45c4235196c17a1a9d4446c5 Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Mon, 5 Apr 2021 21:32:20 +0100 Subject: [PATCH 04/11] Fixes Signed-off-by: Petru-Daniel Tudosiu --- monai/handlers/__init__.py | 1 + monai/handlers/parameter_scheduler.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 5669e8a9ee..c53ca5e9cf 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -21,6 +21,7 @@ from .metrics_saver import MetricsSaver from .roc_auc import ROCAUC from .segmentation_saver import SegmentationSaver +from .parameter_scheduler import ParamSchedulerHandler from .smartcache_handler import SmartCacheHandler from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py index 96120a25fe..ce0062f71d 100644 --- a/monai/handlers/parameter_scheduler.py +++ b/monai/handlers/parameter_scheduler.py @@ -93,7 +93,7 @@ def _linear( float: new parameter value """ if current_step <= step_constant: - delta = 0 + delta = 0.0 elif current_step > step_max_value: delta = max_value - initial_value else: From 8161439fac5583262353dc288dbbc73d34800f5b Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Mon, 5 Apr 2021 21:42:34 +0100 Subject: [PATCH 05/11] Fixes Signed-off-by: Petru-Daniel Tudosiu --- docs/source/handlers.rst | 2 +- monai/handlers/__init__.py | 2 +- scratchpad.py | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 2 deletions(-) create mode 100644 scratchpad.py diff --git a/docs/source/handlers.rst b/docs/source/handlers.rst index da01227551..9030fa3ced 100644 --- a/docs/source/handlers.rst +++ b/docs/source/handlers.rst @@ -112,7 +112,7 @@ SmartCache handler :members: Parameter Scheduler handler ------------------------- +--------------------------- .. autoclass:: ParamSchedulerHandler :members: diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 450d2561cf..f88531ea8e 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -21,9 +21,9 @@ from .mean_dice import MeanDice from .metric_logger import MetricLogger, MetricLoggerKeys from .metrics_saver import MetricsSaver +from .parameter_scheduler import ParamSchedulerHandler from .roc_auc import ROCAUC from .segmentation_saver import SegmentationSaver -from .parameter_scheduler import ParamSchedulerHandler from .smartcache_handler import SmartCacheHandler from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance diff --git a/scratchpad.py b/scratchpad.py new file mode 100644 index 0000000000..4aa37d8a37 --- /dev/null +++ b/scratchpad.py @@ -0,0 +1,36 @@ +import unittest + +from torch.nn import Module +from ignite.engine import Engine, Events + +from monai.handlers.parameter_scheduler import ParamSchedulerHandler +from ignite.engine import Events + + +class ToyNet(Module): + def __init__(self, value): + super(ToyNet, self).__init__() + self.value = value + + def forward(self, input): + return input + + def get_value(self): + return self.value + + def set_value(self, value): + self.value = value + + +net = ToyNet(value=1) +engine = Engine(lambda e, b: None) +ParamSchedulerHandler( + parameter_setter=net.get_value, + value_calculator="linear", + vc_kwargs={"initial_value": 1, "step_constant": 2, "step_max_value": 4, "max_value": 5}, + epoch_level=True, + event=Events.EPOCH_COMPLETED, +).attach(engine) +engine.run([0] * 8, max_epochs=5) + +print(net.get_value()) From c5b47325068f34ea2cfaf448b730b3b4ab1c6c16 Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Mon, 5 Apr 2021 21:56:24 +0100 Subject: [PATCH 06/11] Deleting scratchpad Signed-off-by: Petru-Daniel Tudosiu --- scratchpad.py | 36 ------------------------------------ 1 file changed, 36 deletions(-) delete mode 100644 scratchpad.py diff --git a/scratchpad.py b/scratchpad.py deleted file mode 100644 index 4aa37d8a37..0000000000 --- a/scratchpad.py +++ /dev/null @@ -1,36 +0,0 @@ -import unittest - -from torch.nn import Module -from ignite.engine import Engine, Events - -from monai.handlers.parameter_scheduler import ParamSchedulerHandler -from ignite.engine import Events - - -class ToyNet(Module): - def __init__(self, value): - super(ToyNet, self).__init__() - self.value = value - - def forward(self, input): - return input - - def get_value(self): - return self.value - - def set_value(self, value): - self.value = value - - -net = ToyNet(value=1) -engine = Engine(lambda e, b: None) -ParamSchedulerHandler( - parameter_setter=net.get_value, - value_calculator="linear", - vc_kwargs={"initial_value": 1, "step_constant": 2, "step_max_value": 4, "max_value": 5}, - epoch_level=True, - event=Events.EPOCH_COMPLETED, -).attach(engine) -engine.run([0] * 8, max_epochs=5) - -print(net.get_value()) From a32bf1e937d0d5c4c6f1628c841b703227bdb5e5 Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Tue, 6 Apr 2021 08:39:52 +0100 Subject: [PATCH 07/11] Fixes Signed-off-by: Petru-Daniel Tudosiu --- tests/test_handler_parameter_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py index 3cfdf1ebac..4e4a820201 100644 --- a/tests/test_handler_parameter_scheduler.py +++ b/tests/test_handler_parameter_scheduler.py @@ -48,7 +48,7 @@ def test_linear_scheduler(self): event=Events.EPOCH_COMPLETED, ).attach(engine) engine.run([0] * 8, max_epochs=2) - torch.testing.assert_allclose(net.get_value(), 3.33333333) + torch.testing.assert_allclose(net.get_value(), 3.33, atol=0.001) # Testing max_value net = ToyNet(value=-1) From bc764a57941a1c32f87ea154c1bf27563b7c174d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 6 Apr 2021 09:02:12 +0100 Subject: [PATCH 08/11] type fixes Signed-off-by: Wenqi Li --- monai/handlers/parameter_scheduler.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py index ce0062f71d..2aa0224a5a 100644 --- a/monai/handlers/parameter_scheduler.py +++ b/monai/handlers/parameter_scheduler.py @@ -1,8 +1,14 @@ import logging from bisect import bisect_right -from typing import Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union -from ignite.engine import Engine, EventEnum, Events +from monai.utils import exact_version, optional_import + +if TYPE_CHECKING: + from ignite.engine import Engine, Events +else: + Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine") + Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events") class ParamSchedulerHandler: @@ -27,7 +33,7 @@ def __init__( vc_kwargs: Dict, epoch_level: bool = False, name: Optional[str] = None, - event: EventEnum = Events.ITERATION_COMPLETED, + event=Events.ITERATION_COMPLETED, ): self.epoch_level = epoch_level self.event = event @@ -46,15 +52,14 @@ def __init__( self.logger = logging.getLogger(name) self._name = name - def _get_value_calculator(self, value_calculator: Union[str, Callable]): + def _get_value_calculator(self, value_calculator): if isinstance(value_calculator, str): return self._calculators[value_calculator] - elif isinstance(value_calculator, Callable): + if callable(value_calculator): return value_calculator - else: - raise ValueError( - f"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable." - ) + raise ValueError( + f"value_calculator must be either a string from {list(self._calculators.keys())} or a Callable." + ) def __call__(self, engine: Engine): if self.epoch_level: From 4b9ab9b2eddfd929721eb69c0ed52bfe90e217ab Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Tue, 6 Apr 2021 17:31:48 +0100 Subject: [PATCH 09/11] Fixes Signed-off-by: Petru-Daniel Tudosiu --- tests/test_handler_parameter_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py index 4e4a820201..ff5430f7d3 100644 --- a/tests/test_handler_parameter_scheduler.py +++ b/tests/test_handler_parameter_scheduler.py @@ -48,7 +48,7 @@ def test_linear_scheduler(self): event=Events.EPOCH_COMPLETED, ).attach(engine) engine.run([0] * 8, max_epochs=2) - torch.testing.assert_allclose(net.get_value(), 3.33, atol=0.001) + torch.testing.assert_allclose(net.get_value(), 3.33, atol=0.001, rtol=0.0) # Testing max_value net = ToyNet(value=-1) From 60e67389bc7dd5c5117fcffee076277faa01bb87 Mon Sep 17 00:00:00 2001 From: Petru-Daniel Tudosiu Date: Tue, 6 Apr 2021 18:10:00 +0100 Subject: [PATCH 10/11] Fixes Signed-off-by: Petru-Daniel Tudosiu --- tests/test_handler_parameter_scheduler.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py index ff5430f7d3..414ce5a6a4 100644 --- a/tests/test_handler_parameter_scheduler.py +++ b/tests/test_handler_parameter_scheduler.py @@ -47,7 +47,7 @@ def test_linear_scheduler(self): epoch_level=True, event=Events.EPOCH_COMPLETED, ).attach(engine) - engine.run([0] * 8, max_epochs=2) + engine.run([0] * 8, max_epochs=3) torch.testing.assert_allclose(net.get_value(), 3.33, atol=0.001, rtol=0.0) # Testing max_value @@ -117,3 +117,7 @@ def custom_logic(initial_value, gamma, current_step): ).attach(engine) engine.run([0] * 8, max_epochs=2) torch.testing.assert_allclose(net.get_value(), 10 * 0.99 * 0.99) + + +if __name__ == "__main__": + unittest.main() From a6cd950e81c7914fe3de21cb154dfd8ab9fe70b9 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 7 Apr 2021 09:20:16 +0100 Subject: [PATCH 11/11] fixes precision Signed-off-by: Wenqi Li --- tests/test_handler_parameter_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_handler_parameter_scheduler.py b/tests/test_handler_parameter_scheduler.py index 414ce5a6a4..5b3e845ace 100644 --- a/tests/test_handler_parameter_scheduler.py +++ b/tests/test_handler_parameter_scheduler.py @@ -48,7 +48,7 @@ def test_linear_scheduler(self): event=Events.EPOCH_COMPLETED, ).attach(engine) engine.run([0] * 8, max_epochs=3) - torch.testing.assert_allclose(net.get_value(), 3.33, atol=0.001, rtol=0.0) + torch.testing.assert_allclose(net.get_value(), 3.333333, atol=0.001, rtol=0.0) # Testing max_value net = ToyNet(value=-1)