
Commit d875c8c

aparna-aketi authored and facebook-github-bot committed
FSDP2 with Ghost Clipping and Fast Gradient Clipping prototyping (#761)
Summary:
Integrating FSDP2 with Opacus. First prototype:
1. FSDP is supported only if all layers with trainable parameters are supported by ghost clipping or fast gradient clipping.
2. No freezing/unfreezing of parameters in between training.

Design Doc: [Opacus Ghost Clipping and FSDP2](https://docs.google.com/document/d/1MHqIMKBAXhkUZYQ9kkHCmUs3iLq_G5Q25uS1u3g7Asw/edit?tab=t.0#heading=h.eqambyjwzqsu)

Differential Revision: D70533184
1 parent cb6284f commit d875c8c
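As a quick orientation before the file-by-file diff, here is a minimal usage sketch, not part of the commit, showing how the new "ghost_fsdp" grad sample mode is resolved through the helpers touched below. It assumes an Opacus build that includes this commit; the toy nn.Linear model and hyperparameter values are placeholders, and real FSDP2 training would additionally need an initialized process group with FSDP2-sharded submodules, which the snippet omits.

import torch.nn as nn

from opacus.grad_sample.utils import get_gsm_class
from opacus.optimizers import get_optimizer_class

# Resolve the classes registered for the new mode (see the utils.py and
# optimizers/__init__.py changes below).
gsm_cls = get_gsm_class("ghost_fsdp")  # GradSampleModuleFastGradientClippingFSDP
opt_cls = get_optimizer_class(
    clipping="flat", distributed=True, grad_sample_mode="ghost_fsdp"
)  # FSDPOptimizerFastGradientClipping

# Wrapping a (placeholder) model: the wrapper computes per-sample gradient norms
# via ghost clipping instead of materializing per-sample gradients. Signature
# taken from the new GradSampleModuleFastGradientClippingFSDP below.
dp_model = gsm_cls(
    nn.Linear(16, 4), batch_first=True, loss_reduction="mean", max_grad_norm=1.0
)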

File tree

9 files changed, +498 -11 lines changed


opacus/grad_sample/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,9 @@
 from .grad_sample_module_fast_gradient_clipping import (  # noqa
     GradSampleModuleFastGradientClipping,
 )
+from .grad_sample_module_fast_gradient_clipping_fsdp import (  # noqa
+    GradSampleModuleFastGradientClippingFSDP,
+)
 from .group_norm import compute_group_norm_grad_sample  # noqa
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
@@ -41,6 +44,7 @@
 __all__ = [
     "GradSampleModule",
     "GradSampleModuleFastGradientClipping",
+    "GradSampleModuleFastGradientClippingFSDP",
     "GradSampleModuleExpandedWeights",
     "GradSampleModuleNoOp",
     "AbstractGradSampleModule",

opacus/grad_sample/grad_sample_module.py

Lines changed: 8 additions & 3 deletions
@@ -20,8 +20,6 @@
 from functools import partial
 from typing import Iterable, List, Tuple
 
-import torch
-import torch.nn as nn
 from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer
 from opacus.grad_sample.gsm_base import AbstractGradSampleModule
 from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear
@@ -32,6 +30,9 @@
     trainable_parameters,
 )
 
+import torch
+import torch.nn as nn
+
 
 logger = logging.getLogger(__name__)
 logger.disabled = True
@@ -199,7 +200,11 @@ def add_hooks(
             if type(module) in [DPRNN, DPLSTM, DPGRU]:
                 continue
 
-            if force_functorch or not type(module) in self.GRAD_SAMPLERS:
+            module_type = next(
+                (i for i in self.GRAD_SAMPLERS.keys() if isinstance(module, i)),
+                type(module),
+            )
+            if force_functorch or not (module_type in self.GRAD_SAMPLERS):
                 prepare_layer(module, batch_first=batch_first)
 
             self.autograd_grad_sample_hooks.append(
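The change above swaps the exact type(module) lookup for an isinstance-based one, so a module whose class is subclassed or swapped out by a wrapper still resolves to a registered grad sampler. A small self-contained illustration of that lookup pattern; the sampler dict and ShardedLinear class below are stand-ins, not Opacus code:

import torch.nn as nn

# Stand-in registry mapping layer types to (hypothetical) sampler names.
GRAD_SAMPLERS = {
    nn.Linear: "compute_linear_grad_sample",
    nn.LayerNorm: "compute_layer_norm_grad_sample",
}

class ShardedLinear(nn.Linear):
    """Stand-in for a Linear whose class was subclassed by a sharding wrapper."""
    pass

module = ShardedLinear(8, 4)

# Exact-type lookup misses the subclass...
assert type(module) not in GRAD_SAMPLERS

# ...while the isinstance-based fallback used in add_hooks() resolves it to nn.Linear.
module_type = next(
    (cls for cls in GRAD_SAMPLERS if isinstance(module, cls)),
    type(module),
)
assert module_type is nn.Linear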
opacus/grad_sample/grad_sample_module_fast_gradient_clipping_fsdp.py (new file)

Lines changed: 219 additions & 0 deletions

@@ -0,0 +1,219 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import logging
from typing import List

from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
    GradSampleModuleFastGradientClipping,
)
from opacus.utils.module_utils import requires_grad, trainable_parameters

import torch
import torch.nn as nn


logger = logging.getLogger(__name__)
logger.disabled = True


class GradSampleModuleFastGradientClippingFSDP(GradSampleModuleFastGradientClipping):
    """
    Hooks-based implementation of GradSampleModule with Fast Gradient and Ghost Clipping

    Computes norms of gradients without gradient instantiation
    """

    def __init__(
        self,
        m: nn.Module,
        *,
        batch_first=True,
        loss_reduction="mean",
        strict: bool = True,
        max_grad_norm=1,
    ):
        """

        Args:
            m: nn.Module to be wrapped
            batch_first: Flag to indicate if the input tensor to the corresponding module
                has the first dimension representing the batch. If set to True, dimensions on
                input tensor are expected be ``[batch_size, ...]``, otherwise
                ``[K, batch_size, ...]``
            loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
                is a sum or a mean operation. Can take values "sum" or "mean"
            max_grad_norm: The value at which gradients are to be clipped.
            strict: If set to True, the input module will be validated to make sure that
                it does not have buffers in all its submodules.

        Raises:
            NotImplementedError
                If ``strict`` is set to ``True`` and module ``m`` (or any of its
                submodules) includes a buffer.
        """

        super().__init__(
            m,
            batch_first=batch_first,
            loss_reduction=loss_reduction,
            strict=strict,
            force_functorch=False,
            max_grad_norm=max_grad_norm,
            use_ghost_clipping=True,
        )

        self.sampler_classes = list(self.GRAD_SAMPLERS.keys()) + list(
            self.NORM_SAMPLERS.keys()
        )
        self.module_types = {}
        for m in self.iterate_submodules(m):
            if type(m) not in self.module_types:
                module_type = next(
                    (i for i in self.sampler_classes if isinstance(m, i)),
                    type(m),
                )
                self.module_types[type(m)] = module_type

    def get_clipping_coef(self) -> torch.Tensor:
        """Get per-example gradient scaling factor for clipping."""
        norm_sample = self.get_norm_sample()
        return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)

    def get_norm_sample(self) -> torch.Tensor:
        """Get per-example gradient norms."""
        norm_sample = torch.stack(
            [
                per_param_norm
                for module in self.iterate_submodules(self._module)
                for per_param_norm in module.norm_sample
            ],
            dim=0,
        ).norm(2, dim=0)

        self.per_sample_gradient_norms = norm_sample
        return norm_sample

    def capture_activations_hook(
        self,
        module: nn.Module,
        forward_input: List[torch.Tensor],
        _forward_output: torch.Tensor,
    ):
        if (
            not requires_grad(module)
            or not module.training
            or not torch.is_grad_enabled()
            or not self.hooks_enabled
        ):
            return

        if not hasattr(module, "activations"):
            module.activations = []
        module.activations.append([t.detach() for t in forward_input])  # pyre-ignore

        if not hasattr(module, "forward_counter"):
            module.forward_counter = 0

        module.forward_counter += 1
        if self.use_ghost_clipping and module.forward_counter > 1:
            raise NotImplementedError("Parameter tying is not supported with FSDP")

    def capture_backprops_hook(
        self,
        module: nn.Module,
        _forward_input: torch.Tensor,
        forward_output: torch.Tensor,
        loss_reduction: str,
        batch_first: bool,
    ):
        """
        Computes norms of per sample gradient given the current backprops and activations
        stored by the associated forward hook. Computed per sample gradient norms are
        stored in ``norm_sample`` field in each parameter.

        Args:
            module: nn.Module,
            _forward_input: torch.Tensor,
            forward_output: torch.Tensor,
            loss_reduction: str,
            batch_first: bool,
        """
        if not self.hooks_enabled:
            return

        backprops = forward_output[0].detach()

        activations, backprops = self.rearrange_grad_samples(
            module=module,
            backprops=backprops,
            loss_reduction=loss_reduction,
            batch_first=batch_first,
        )

        if not hasattr(module, "norm_sample"):
            # currently, we don't support freezing and unfreezing params in between training. Making this a dictionary and mapping with param names might fix this.
            module.norm_sample = []
            for _, param in trainable_parameters(module):
                module.norm_sample.append(
                    torch.zeros(
                        torch.Size([module.max_batch_len, 1]),
                        device=param.device,
                        dtype=param.dtype,
                    )
                )

        module_type = self.module_types[type(module)]
        module.forward_counter -= 1
        if self.use_ghost_clipping and module_type in self.NORM_SAMPLERS:
            norm_sampler_fn = self.NORM_SAMPLERS[module_type]
            norm_samples = norm_sampler_fn(module, activations, backprops)

            for idx, (_, ns) in enumerate(
                (item for item in norm_samples.items() if item[0].requires_grad)
            ):
                module.norm_sample[idx] = ns
        else:
            if not self.force_functorch and module_type in self.GRAD_SAMPLERS:
                grad_sampler_fn = self.GRAD_SAMPLERS[module_type]
            else:
                grad_sampler_fn = ft_compute_per_sample_gradient

            grad_samples = grad_sampler_fn(module, activations, backprops)

            for idx, (_, gs) in enumerate((item for item in grad_samples.items())):
                module.norm_sample[idx] = gs.reshape(len(gs), -1).norm(2, dim=-1)
            del grad_samples

        if len(module.activations) == 0:
            if hasattr(module, "max_batch_len"):
                del module.max_batch_len

    @property
    def per_sample_gradient_norms(self) -> torch.Tensor:
        """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings"""
        if self._per_sample_gradient_norms is not None:
            return self._per_sample_gradient_norms
        else:
            raise AttributeError(
                "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property."
            )

    @per_sample_gradient_norms.setter
    def per_sample_gradient_norms(self, value):
        self._per_sample_gradient_norms = value
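To make the aggregation above concrete, here is a small numeric sketch (illustrative values, not part of the commit) of what get_norm_sample() and get_clipping_coef() compute from per-parameter, per-sample norms:

import torch

max_grad_norm = 1.0

# Per-parameter, per-sample gradient norms for two parameters and a batch of three.
per_param_norms = [torch.tensor([0.6, 3.0, 0.1]), torch.tensor([0.8, 4.0, 0.2])]

# get_norm_sample(): stack per-parameter norms and take the 2-norm across parameters.
norm_sample = torch.stack(per_param_norms, dim=0).norm(2, dim=0)
# approximately tensor([1.0000, 5.0000, 0.2236])

# get_clipping_coef(): per-sample scale factor, capped at 1 so small gradients stay untouched.
clip_coef = (max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
# approximately tensor([1.0000, 0.2000, 1.0000])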

opacus/grad_sample/utils.py

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,9 @@
 from .grad_sample_module_fast_gradient_clipping import (
     GradSampleModuleFastGradientClipping,
 )
+from .grad_sample_module_fast_gradient_clipping_fsdp import (
+    GradSampleModuleFastGradientClippingFSDP,
+)
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
 from .gsm_no_op import GradSampleModuleNoOp
@@ -102,6 +105,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
         return GradSampleModuleExpandedWeights
     elif grad_sample_mode == "ghost":
         return GradSampleModuleFastGradientClipping
+    elif grad_sample_mode == "ghost_fsdp":
+        return GradSampleModuleFastGradientClippingFSDP
     elif grad_sample_mode == "no_op":
         return GradSampleModuleNoOp
     else:

opacus/optimizers/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -18,6 +18,7 @@
 from .ddpoptimizer_fast_gradient_clipping import (
     DistributedDPOptimizerFastGradientClipping,
 )
+from .fsdpoptimizer_fast_gradient_clipping import FSDPOptimizerFastGradientClipping
 from .optimizer import DPOptimizer
 from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping
 from .perlayeroptimizer import DPPerLayerOptimizer
@@ -29,7 +30,8 @@
     "DPOptimizer",
     "DPOptimizerFastGradientClipping",
     "DistributedDPOptimizerFastGradientClipping",
-    "DPPerLayerOptimizer",
+    "FSDPOptimizerFastGradientClipping",
+    "DPPerLayerOptimizer",
     "SimpleDistributedPerLayerOptimizer",
 ]
 
@@ -44,6 +45,13 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str
             raise ValueError(
                 f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
             )
+    elif grad_sample_mode == "ghost_fsdp":
+        if clipping == "flat" and distributed is True:
+            return FSDPOptimizerFastGradientClipping
+        else:
+            raise ValueError(
+                f"Unsupported combination of parameters. Clipping: {clipping}, distributed: {distributed}, and grad_sample_mode: {grad_sample_mode}"
+            )
     elif clipping == "flat" and distributed is False:
         return DPOptimizer
     elif clipping == "flat" and distributed is True:
