Issue 1249: fix ParamGroupScheduler with schedulers based on different optimizers #1274

Merged: 15 commits, Sep 15, 2020
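
For context, the change allows a ParamGroupScheduler to drive schedulers attached to different optimizers. A minimal sketch of that scenario (the tensors and optimizers below are illustrative assumptions, not taken from the diff):

    # Illustrative setup (assumed, not part of the diff): two optimizers, one scheduler each.
    import torch
    from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler, ParamGroupScheduler

    t1 = torch.zeros([1], requires_grad=True)
    t2 = torch.zeros([1], requires_grad=True)
    optimizer_1 = torch.optim.SGD(params=[t1], lr=0.1)
    optimizer_2 = torch.optim.SGD(params=[t2], lr=0.1)

    scheduler_1 = LinearCyclicalScheduler(optimizer_1, "lr", start_value=1.0, end_value=0.0, cycle_size=10)
    scheduler_2 = LinearCyclicalScheduler(optimizer_2, "lr", start_value=1.0, end_value=0.0, cycle_size=10)

    # Before this fix, ParamGroupScheduler required all schedulers to share one optimizer
    # and raised "schedulers should be related to same optimizer"; now both are grouped.
    scheduler = ParamGroupScheduler(schedulers=[scheduler_1, scheduler_2], names=["lr_1", "lr_2"])

The grouped scheduler is then attached to a trainer with `trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)`, as in the updated test further below.
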
63 changes: 42 additions & 21 deletions ignite/contrib/handlers/param_scheduler.py
@@ -1,3 +1,4 @@
import itertools
import math
import numbers
import tempfile
@@ -380,8 +381,8 @@ class CosineAnnealingScheduler(CyclicalScheduler):

optimizer = SGD(
[
{"params": model.base.parameters(), 'lr': 0.001),
{"params": model.fc.parameters(), 'lr': 0.01),
{"params": model.base.parameters(), 'lr': 0.001},
{"params": model.fc.parameters(), 'lr': 0.01},
]
)

@@ -459,7 +460,7 @@ def __init__(self, schedulers, durations, save_history=False):
)

for i, scheduler in enumerate(schedulers):
if not isinstance(scheduler, ParamScheduler):
if not isinstance(scheduler, ParamScheduler) and not isinstance(scheduler, ParamGroupScheduler):
raise TypeError(
"Value at index {} of schedulers should be a parameter scheduler, "
"but given {}".format(i, type(scheduler))
@@ -468,15 +469,29 @@ def __init__(self, schedulers, durations, save_history=False):
self.schedulers = schedulers
self.durations = durations

self.optimizer = self.schedulers[0].optimizer
if not (all(id(s.optimizer) == id(self.optimizer) for s in self.schedulers)):
param_optimizers = [s.optimizer for s in self.schedulers]
param_optimizers = [s if isinstance(s, list) else [s] for s in param_optimizers]
param_optimizers = list(itertools.chain(*param_optimizers))

optimizer = list(set(param_optimizers))
if len(optimizer) != 1:
raise ValueError("schedulers should be related to same optimizer")

param_names = [s.param_name for s in self.schedulers]
param_names = [s if isinstance(s, list) else [s] for s in param_names]
param_names = list(itertools.chain(*param_names))

param_name = list(set(param_names))
if len(param_name) != 1:
raise ValueError("schedulers should be related to same param_name")

        # child schedulers should keep save_history in sync with this scheduler
for s in schedulers:
s.save_history = save_history

super(ConcatScheduler, self).__init__(optimizer=self.optimizer, param_name="", save_history=save_history)
super(ConcatScheduler, self).__init__(
optimizer=optimizer[0], param_name=param_name[0], save_history=save_history
)

self._scheduler_index = 0
self._current_scheduler = None
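
The block above flattens each scheduler's optimizer (a single object for a ParamScheduler, a list for a ParamGroupScheduler after this change) and requires the result to deduplicate to exactly one optimizer; param_name is handled the same way. A small sketch of that flattening with stand-in objects (illustrative assumption, not from the diff):

    import itertools

    opt = object()  # stand-in for a single shared torch optimizer (assumption for illustration)

    # One entry per scheduler: plain schedulers contribute a single optimizer,
    # a ParamGroupScheduler contributes a list of its children's optimizers.
    param_optimizers = [opt, [opt, opt]]
    param_optimizers = [s if isinstance(s, list) else [s] for s in param_optimizers]
    param_optimizers = list(itertools.chain(*param_optimizers))  # -> [opt, opt, opt]

    # Exactly one distinct optimizer must remain, otherwise ConcatScheduler raises
    # "schedulers should be related to same optimizer".
    assert len(set(param_optimizers)) == 1
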
@@ -530,8 +545,6 @@ def _setup_scheduler(self):
self._current_duration = (
self.durations[self._scheduler_index] if self._scheduler_index < len(self.durations) else -1
)
self.param_name = self._current_scheduler.param_name
self.optimizer = self._current_scheduler.optimizer

def __call__(self, engine, name=None):
if self._current_duration == 0:
@@ -581,14 +594,24 @@ def simulate_values(cls, num_events, schedulers, durations, param_names=None, **
"Argument param_names should be list or tuple of strings, but given {}".format(param_names)
)

param_optimizers = [s.optimizer for s in schedulers]
param_optimizers = [s if isinstance(s, list) else [s] for s in param_optimizers]
param_optimizers = list(itertools.chain(*param_optimizers))

optimizer = list(set(param_optimizers))
if len(optimizer) != 1:
raise ValueError("schedulers should be related to same optimizer")

optimizer = optimizer[0]

# This scheduler uses `ParamScheduler` which
# should be replicated in order to simulate LR values and
# not perturb original scheduler.
with tempfile.TemporaryDirectory() as tmpdirname:
cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
objs = {"lr_scheduler_{}".format(i): s.state_dict() for i, s in enumerate(schedulers)}
# all schedulers should be related to the same optimizer
objs["optimizer"] = schedulers[0].optimizer.state_dict()
objs["optimizer"] = optimizer.state_dict()

torch.save(objs, cache_filepath.as_posix())

@@ -611,7 +634,7 @@ def simulate_values(cls, num_events, schedulers, durations, param_names=None, **
objs = torch.load(cache_filepath.as_posix())
for i, s in enumerate(schedulers):
s.load_state_dict(objs["lr_scheduler_{}".format(i)])
s.optimizer.load_state_dict(objs["optimizer"])
optimizer.load_state_dict(objs["optimizer"])

return output

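simulate_values applies the same single-optimizer rule before caching and restoring the optimizer state around the simulation. A usage sketch under an assumed single-optimizer setup (the tensors and values are illustrative, not from the diff):

    import torch
    from ignite.contrib.handlers.param_scheduler import (
        ConcatScheduler,
        CosineAnnealingScheduler,
        LinearCyclicalScheduler,
    )

    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.1)

    # Both schedulers target the same optimizer and the same param_name ("lr"),
    # so the checks added above pass.
    scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.1, end_value=0.5, cycle_size=60)
    scheduler_2 = CosineAnnealingScheduler(optimizer, "lr", start_value=0.5, end_value=0.01, cycle_size=60)

    # Each returned row pairs an event index with the simulated "lr" value; the original
    # schedulers and optimizer are left untouched thanks to the state caching above.
    values = ConcatScheduler.simulate_values(num_events=100, schedulers=[scheduler_1, scheduler_2], durations=[30])
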
@@ -640,7 +663,7 @@ class LRScheduler(ParamScheduler):
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
"""

def __init__(self, lr_scheduler, save_history=False, **kwargs):
def __init__(self, lr_scheduler, save_history=False):

if not isinstance(lr_scheduler, _LRScheduler):
raise TypeError(
@@ -927,7 +950,7 @@ def get_param(self):
return start_value + (end_value - start_value) * (self.event_index - start_index) / (end_index - start_index)


class ParamGroupScheduler(ParamScheduler):
class ParamGroupScheduler:
"""
Scheduler helper to group multiple schedulers into one.

@@ -981,20 +1004,21 @@ def __init__(self, schedulers: List[ParamScheduler], names: Optional[List[str]]
self.schedulers = schedulers
self.names = names

self.optimizer = self.schedulers[0].optimizer
if not (all(id(s.optimizer) == id(self.optimizer) for s in schedulers)):
raise ValueError("schedulers should be related to same optimizer")

        # child schedulers should keep save_history in sync with this ParamGroupScheduler
for s in schedulers:
s.save_history = save_history

super(ParamGroupScheduler, self).__init__(optimizer=self.optimizer, param_name="lr", save_history=save_history)
self.optimizer = [s.optimizer for s in self.schedulers]
self.param_name = [s.param_name for s in self.schedulers]

def __call__(self, engine, name=None):
for scheduler, name in zip(self.schedulers, self.names):
scheduler(engine, name)

@property
def optimizer_param_groups(self):
return [pg for scheduler in self.schedulers for pg in scheduler.optimizer_param_groups]

@property
def save_history(self):
return self.schedulers[0].save_history
@@ -1004,9 +1028,6 @@ def save_history(self, value):
for s in self.schedulers:
s.save_history = value

def get_param(self) -> Union[List[float], float]:
return [scheduler.get_param() for scheduler in self.schedulers]

def state_dict(self):
"""Returns a dictionary containing a whole state of ParamGroupScheduler.

@@ -1077,7 +1098,7 @@ def simulate_values(cls, num_events, schedulers, **kwargs):
values = []
scheduler = cls(schedulers=schedulers, **kwargs)
for i in range(num_events):
params = scheduler.get_param()
params = [scheduler.get_param() for scheduler in schedulers]
values.append([i] + params)
scheduler(engine=None)

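Because ParamGroupScheduler no longer exposes a single get_param(), simulate_values now queries each child scheduler directly, producing rows of the form [event_index, value_of_scheduler_1, value_of_scheduler_2, ...]. A usage sketch under an assumed two-param-group setup (not taken from the diff):

    import torch
    from ignite.contrib.handlers.param_scheduler import LinearCyclicalScheduler, ParamGroupScheduler

    t1 = torch.zeros([1], requires_grad=True)
    t2 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}])

    scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", param_group_index=0, start_value=1.0, end_value=0.0, cycle_size=10)
    scheduler_2 = LinearCyclicalScheduler(optimizer, "lr", param_group_index=1, start_value=0.5, end_value=0.0, cycle_size=10)

    # One value per child scheduler is collected on every simulated event.
    values = ParamGroupScheduler.simulate_values(
        num_events=50, schedulers=[scheduler_1, scheduler_2], names=["lr_group_0", "lr_group_1"]
    )
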
45 changes: 30 additions & 15 deletions tests/ignite/contrib/handlers/test_param_scheduler.py
@@ -339,6 +339,14 @@ def test_concat_scheduler_asserts():
with pytest.raises(ValueError, match=r"schedulers should be related to same optimizer"):
ConcatScheduler([scheduler_1, scheduler_3], durations=[30,])

scheduler_4 = CosineAnnealingScheduler(optimizer, "lr2", start_value=0.0, end_value=1.0, cycle_size=10)

with pytest.raises(ValueError, match=r"schedulers should be related to same param_name"):
ConcatScheduler([scheduler_1, scheduler_4], durations=[30,])

with pytest.raises(ValueError, match=r"schedulers should be related to same optimizer"):
ConcatScheduler.simulate_values(3, [scheduler_1, scheduler_3], durations=[30,])


def test_concat_scheduler_state_dict():
tensor = torch.zeros([1], requires_grad=True)
@@ -814,8 +822,7 @@ def _test(scheduler_cls, **scheduler_kwargs):
optimizer = None
event = Events.ITERATION_STARTED
if scheduler_cls == LRScheduler:
scheduler_kwargs["optimizer"] = scheduler_kwargs["lr_scheduler"].optimizer
optimizer = scheduler_kwargs["optimizer"]
optimizer = scheduler_kwargs["lr_scheduler"].optimizer
event = Events.ITERATION_COMPLETED
elif scheduler_cls == ConcatScheduler:
optimizer = scheduler_kwargs["optimizer"]
@@ -1157,16 +1164,9 @@ def test_param_group_scheduler_asserts():
):
scheduler.load_state_dict({"schedulers": [("a", lr_scheduler1.state_dict()), ("bad_name", {})]})

optimizer2 = torch.optim.SGD([{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}])
lr_scheduler3 = LinearCyclicalScheduler(
optimizer2, "lr", param_group_index=0, start_value=1.0, end_value=0.0, cycle_size=10
)
with pytest.raises(ValueError, match=r"schedulers should be related to same optimizer"):
ParamGroupScheduler(schedulers=[lr_scheduler1, lr_scheduler3])


def test_param_group_scheduler():
def _test(lr_schedulers, optimizer):
def _test(lr_schedulers, save_lr):
num_iterations = 10
max_epochs = 20

@@ -1175,17 +1175,15 @@ def _test(lr_schedulers, optimizer):

trainer = Engine(lambda engine, batch: None)

@trainer.on(Events.ITERATION_COMPLETED)
def save_lr(engine):
lrs.append((optimizer.param_groups[0]["lr"], optimizer.param_groups[1]["lr"]))

trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

data = [0] * num_iterations

for _ in range(2):
lrs = []
trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr, lrs)
trainer.run(data, max_epochs=max_epochs)
trainer.remove_event_handler(save_lr, Events.ITERATION_COMPLETED)
assert [lr[0] for lr in lrs] == pytest.approx([lr[1] for lr in lrs])
scheduler.load_state_dict(state_dict)

@@ -1203,7 +1201,24 @@ def save_lr(engine):
lr_scheduler2 = LinearCyclicalScheduler(
optimizer, "lr", param_group_index=1, start_value=1.0, end_value=0.0, cycle_size=10
)
_test([lr_scheduler1, lr_scheduler2], optimizer)

def save_lr_one_optimizer(engine, lrs):
lrs.append((optimizer.param_groups[0]["lr"], optimizer.param_groups[1]["lr"]))

_test([lr_scheduler1, lr_scheduler2], save_lr_one_optimizer)

t1 = torch.zeros([1], requires_grad=True)
optimizer_1 = torch.optim.SGD(params=[t1], lr=0.1)
t2 = torch.zeros([1], requires_grad=True)
optimizer_2 = torch.optim.SGD(params=[t2], lr=0.1)

lr_scheduler1 = LinearCyclicalScheduler(optimizer_1, "lr", start_value=1.0, end_value=0.0, cycle_size=10)
lr_scheduler2 = LinearCyclicalScheduler(optimizer_2, "lr", start_value=1.0, end_value=0.0, cycle_size=10)

    def save_lr_multiple_optimizers(engine, lrs):
lrs.append((optimizer_1.param_groups[0]["lr"], optimizer_2.param_groups[0]["lr"]))

    _test([lr_scheduler1, lr_scheduler2], save_lr_multiple_optimizers)


def test_scheduler_with_param_groups():