This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 4add1b3

prigoyal authored and facebook-github-bot committed
Support LARC for SGD optimizer only in classy vision (#408)
Summary:
Pull Request resolved: #408
Pull Request resolved: fairinternal/ClassyVision#64

In an attempt to implement SimCLR for contrastive losses, I needed LARC to enable large-batch training. mannatsingh had already done work on this during the Classy Vision open-source release: https://our.intern.facebook.com/intern/diff/D18542126/

I initially tried using that diff to build a standalone LARC that works with any optimizer, but it turned out to be tricky to set up correctly, since we need to wrap a given optimizer in LARC (the `getattr` and `setattr` functions were not working). I talked to vreis about it, and we decided to support LARC for SGD only for now and to file a task to support other optimizers later, after discussing with mannatsingh once he's back.

Differential Revision: D20139718

fbshipit-source-id: c8cf4d545e6ce94cca8e646f68d519197856f675
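For context on why LARC helps large-batch training: LARS/LARC adapts the learning rate per layer using the ratio of the weight norm to the gradient norm, so layers with small gradients relative to their weights keep making progress while exploding layers are damped. Below is a rough sketch of that trust-ratio idea, paraphrased from the LARS/LARC papers rather than taken from apex's code; the trust_coefficient, eps, and clip entries of larc_config in the diff correspond to these knobs.

import torch


def layerwise_trust_ratio(weights: torch.Tensor, grads: torch.Tensor,
                          trust_coefficient: float = 0.02,
                          weight_decay: float = 0.0,
                          eps: float = 1e-08) -> float:
    """Rough LARS/LARC-style local learning-rate multiplier for one layer."""
    w_norm = weights.norm().item()
    g_norm = grads.norm().item()
    if w_norm == 0.0 or g_norm == 0.0:
        return 1.0  # fall back to the base learning rate
    return trust_coefficient * w_norm / (g_norm + weight_decay * w_norm + eps)


# Example: a layer whose gradients are tiny relative to its weights gets a
# larger local rate (multiplied against the base learning rate).
w = torch.ones(512) * 0.5
g = torch.ones(512) * 1e-3
print(layerwise_trust_ratio(w, g))  # ~10, i.e. roughly 10x the base rate

In apex's wrapper, clip=True is documented to take the minimum of the base learning rate and the layer-wise rate instead of always scaling, and eps guards the division when the gradient norm is near zero.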
1 parent db87bb8 commit 4add1b3

File tree: 2 files changed (+23, -1 lines)


classy_vision/optim/classy_optimizer.py (4 additions, 1 deletion)

@@ -278,7 +278,10 @@ def step(self, closure: Optional[Callable] = None):
         Args:
             closure: A closure that re-evaluates the model and returns the loss
         """
-        self.optimizer.step(closure)
+        if closure is None:
+            self.optimizer.step()
+        else:
+            self.optimizer.step(closure)
 
     def zero_grad(self):
         """

classy_vision/optim/sgd.py (19 additions, 0 deletions)

@@ -15,17 +15,21 @@
 class SGD(ClassyOptimizer):
     def __init__(
         self,
+        larc_config: Dict[str, Any] = None,
         lr: float = 0.1,
         momentum: float = 0.0,
         weight_decay: float = 0.0,
         nesterov: bool = False,
+        use_larc: bool = False,
     ):
         super().__init__()
 
         self.parameters.lr = lr
         self.parameters.momentum = momentum
         self.parameters.weight_decay = weight_decay
         self.parameters.nesterov = nesterov
+        self.parameters.use_larc = use_larc
+        self.larc_config = larc_config
 
     def init_pytorch_optimizer(self, model, **kwargs):
         super().init_pytorch_optimizer(model, **kwargs)

@@ -36,6 +40,12 @@ def init_pytorch_optimizer(self, model, **kwargs):
             momentum=self.parameters.momentum,
             weight_decay=self.parameters.weight_decay,
         )
+        if self.parameters.use_larc:
+            try:
+                from apex.parallel.LARC import LARC
+            except ImportError:
+                raise RuntimeError("Apex needed for LARC")
+            self.optimizer = LARC(optimizer=self.optimizer, **self.larc_config)
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "SGD":

@@ -53,6 +63,10 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
         config.setdefault("momentum", 0.0)
         config.setdefault("weight_decay", 0.0)
         config.setdefault("nesterov", False)
+        config.setdefault("use_larc", False)
+        config.setdefault(
+            "larc_config", {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02}
+        )
 
         assert (
             config["momentum"] >= 0.0

@@ -62,10 +76,15 @@ def from_config(cls, config: Dict[str, Any]) -> "SGD":
         assert isinstance(
             config["nesterov"], bool
         ), "Config must contain a boolean 'nesterov' param for SGD optimizer"
+        assert isinstance(
+            config["use_larc"], bool
+        ), "Config must contain a boolean 'use_larc' param for SGD optimizer"
 
         return cls(
+            larc_config=config["larc_config"],
             lr=config["lr"],
             momentum=config["momentum"],
             weight_decay=config["weight_decay"],
             nesterov=config["nesterov"],
+            use_larc=config["use_larc"],
         )
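For reference, a rough end-to-end sketch of how the new keys flow through from_config; values other than the keys added in this diff are illustrative, and apex must be installed for the LARC path to work.

from classy_vision.optim.sgd import SGD

# Illustrative config: "use_larc" and "larc_config" are the new keys; if
# "larc_config" is omitted, from_config falls back to
# {"clip": True, "eps": 1e-08, "trust_coefficient": 0.02}.
config = {
    "lr": 0.1,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "nesterov": False,
    "use_larc": True,
}

sgd = SGD.from_config(config)

# init_pytorch_optimizer builds torch.optim.SGD and, because use_larc is True,
# wraps it with apex.parallel.LARC.LARC before any step() is taken.
# sgd.init_pytorch_optimizer(model)  # model supplied by the surrounding trainer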
