Skip to content

Commit e16b3e8

Browse files
lauragustafsonfacebook-github-bot
authored andcommitted
Exponential Parameter Scheduler (#715)
Summary: Pull Request resolved: facebookresearch/ClassyVision#715 Exponential parameter scheduler for fvcore/ClassyVision. This type of scheduler is useful as a LR scheduler when using EMA. Differential Revision: D27058657 fbshipit-source-id: 176ee5bcd074d5d517afa660a7dbf86f83468979
1 parent 1f43d07 commit e16b3e8

File tree

2 files changed

+57
-0
lines changed

2 files changed

+57
-0
lines changed

fvcore/common/param_scheduler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"ParamScheduler",
88
"ConstantParamScheduler",
99
"CosineParamScheduler",
10+
"ExponentialParamScheduler",
1011
"LinearParamScheduler",
1112
"CompositeParamScheduler",
1213
"MultiStepParamScheduler",
@@ -85,6 +86,33 @@ def __call__(self, where: float) -> float:
8586
)
8687

8788

89+
class ExponentialParamScheduler(ParamScheduler):
90+
"""
91+
Exponetial schedule based on start value and decay, where value is caluated
92+
for timestep t out of T total stepsm by
93+
param_t = start_value * (decay ** t/T).The schedule is updated after every train
94+
step by default based on the fraction of samples seen.
95+
96+
Example:
97+
98+
.. code-block:: python
99+
ExponentialParamScheduler(start_value=2.0, decay=0.02)
100+
101+
Corresponds to a decreasing schedule with values in [2.0, 0.04).
102+
"""
103+
104+
def __init__(
105+
self,
106+
start_value: float,
107+
decay: float,
108+
) -> None:
109+
self._start_value = start_value
110+
self._decay = decay
111+
112+
def __call__(self, where: float) -> float:
113+
return self._start_value * (self._decay ** where)
114+
115+
88116
class LinearParamScheduler(ParamScheduler):
89117
"""
90118
Linearly interpolates parameter between ``start_value`` and ``end_value``.
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
3+
import unittest
4+
5+
from fvcore.common.param_scheduler import ExponentialParamScheduler
6+
7+
8+
class TestExponentialScheduler(unittest.TestCase):
9+
_num_epochs = 10
10+
11+
def _get_valid_config(self):
12+
return {"start_value": 2.0, "decay": 0.1}
13+
14+
def _get_valid_intermediate_values(self):
15+
return [1.5887, 1.2619, 1.0024, 0.7962, 0.6325, 0.5024, 0.3991, 0.3170, 0.2518]
16+
17+
def test_scheduler(self):
18+
config = self._get_valid_config()
19+
20+
scheduler = ExponentialParamScheduler(**config)
21+
schedule = [
22+
round(scheduler(epoch_num / self._num_epochs), 4)
23+
for epoch_num in range(self._num_epochs)
24+
]
25+
expected_schedule = [
26+
config["start_value"]
27+
] + self._get_valid_intermediate_values()
28+
29+
self.assertEqual(schedule, expected_schedule)

0 commit comments

Comments
 (0)