
Commit 0aa8971

aparna-aketi authored and facebook-github-bot committed
FSDP2 with Ghost Clipping and Fast Gradient Clipping prototyping (#761)
Summary: Integrating FSDP2 with Opacus, first prototype:
1. FSDP is supported only if every layer with trainable parameters is supported by ghost clipping or fast gradient clipping.
2. Freezing/unfreezing of parameters during training is not supported.

Design Doc: [Opacus Ghost Clipping and FSDP2](https://docs.google.com/document/d/1MHqIMKBAXhkUZYQ9kkHCmUs3iLq_G5Q25uS1u3g7Asw/edit?tab=t.0#heading=h.eqambyjwzqsu)

Differential Revision: D70533184
1 parent f3752c3 commit 0aa8971
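
Concretely, the constraints above translate into a wrapping pattern along the following lines. This is a minimal sketch rather than code from the commit: it assumes torch.distributed is already initialized (e.g. via torchrun), that fully_shard is importable from torch.distributed.fsdp (older PyTorch releases expose it under torch.distributed._composable.fsdp), and that the loss/optimizer plumbing is set up the same way as for the existing "ghost" mode.

import torch.nn as nn
from torch.distributed.fsdp import fully_shard

from opacus.grad_sample import GradSampleModuleFastGradientClippingFSDP

# Shard the model with FSDP2 first, then wrap it so that per-sample gradient
# norms are computed via the ghost clipping / fast gradient clipping hooks.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 2))
fully_shard(model)

dp_model = GradSampleModuleFastGradientClippingFSDP(
    model,
    batch_first=True,
    loss_reduction="mean",
    max_grad_norm=1.0,
)

# The matching optimizer wrapper is FSDPOptimizerFastGradientClipping; it is
# resolved through get_optimizer_class(clipping="flat", distributed=True,
# grad_sample_mode="ghost_fsdp") added in opacus/optimizers/__init__.py below.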

9 files changed: +491 -10 lines changed

opacus/grad_sample/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,9 @@
 from .grad_sample_module_fast_gradient_clipping import (  # noqa
     GradSampleModuleFastGradientClipping,
 )
+from .grad_sample_module_fast_gradient_clipping_fsdp import (  # noqa
+    GradSampleModuleFastGradientClippingFSDP,
+)
 from .group_norm import compute_group_norm_grad_sample  # noqa
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
@@ -41,6 +44,7 @@
 __all__ = [
     "GradSampleModule",
     "GradSampleModuleFastGradientClipping",
+    "GradSampleModuleFastGradientClippingFSDP",
     "GradSampleModuleExpandedWeights",
     "GradSampleModuleNoOp",
     "AbstractGradSampleModule",

opacus/grad_sample/grad_sample_module.py

Lines changed: 9 additions & 3 deletions
@@ -20,8 +20,6 @@
 from functools import partial
 from typing import Iterable, List, Tuple

-import torch
-import torch.nn as nn
 from opacus.grad_sample.functorch import ft_compute_per_sample_gradient, prepare_layer
 from opacus.grad_sample.gsm_base import AbstractGradSampleModule
 from opacus.layers.dp_rnn import DPGRU, DPLSTM, DPRNN, RNNLinear
@@ -32,6 +30,9 @@
     trainable_parameters,
 )

+import torch
+import torch.nn as nn
+

 logger = logging.getLogger(__name__)
 logger.disabled = True
@@ -199,7 +200,12 @@ def add_hooks(
             if type(module) in [DPRNN, DPLSTM, DPGRU]:
                 continue

-            if force_functorch or not type(module) in self.GRAD_SAMPLERS:
+            module_type = (
+                module.__class__.__bases__[1]
+                if isinstance(module, torch.distributed.fsdp.FSDPModule)
+                else type(module)
+            )
+            if force_functorch or not (module_type in self.GRAD_SAMPLERS):
                 prepare_layer(module, batch_first=batch_first)

             self.autograd_grad_sample_hooks.append(
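
The module_type indirection above exists because FSDP2's fully_shard() swaps the wrapped module's class for a dynamically created subclass whose bases are (an FSDP mixin, the original class), so the original class must be recovered before the GRAD_SAMPLERS lookup. A small illustration of that behavior, assuming an initialized process group and a PyTorch build that exports fully_shard and FSDPModule from torch.distributed.fsdp:

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule, fully_shard

# fully_shard() replaces the module's class with a subclass whose second base
# is the original class, which is what GRAD_SAMPLERS / NORM_SAMPLERS are keyed on.
layer = nn.Linear(16, 16)
fully_shard(layer)

assert isinstance(layer, FSDPModule)
original_cls = layer.__class__.__bases__[1]
assert original_cls is nn.Linear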
opacus/grad_sample/grad_sample_module_fast_gradient_clipping_fsdp.py

Lines changed: 215 additions & 0 deletions
@@ -0,0 +1,215 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+from typing import List
+
+from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
+from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
+    GradSampleModuleFastGradientClipping,
+)
+from opacus.utils.module_utils import requires_grad, trainable_parameters
+
+import torch
+import torch.nn as nn
+
+
+logger = logging.getLogger(__name__)
+logger.disabled = True
+
+
+class GradSampleModuleFastGradientClippingFSDP(GradSampleModuleFastGradientClipping):
+    """
+    Hooks-based implementation of GradSampleModule with Fast Gradient and Ghost Clipping and FSDP support
+
+    Computes norms of gradients without gradient instantiation
+    """
+
+    def __init__(
+        self,
+        m: nn.Module,
+        *,
+        batch_first=True,
+        loss_reduction="mean",
+        strict: bool = True,
+        max_grad_norm=1,
+    ):
+        """
+
+        Args:
+            m: nn.Module to be wrapped
+            batch_first: Flag to indicate if the input tensor to the corresponding module
+                has the first dimension representing the batch. If set to True, dimensions on
+                input tensor are expected be ``[batch_size, ...]``, otherwise
+                ``[K, batch_size, ...]``
+            loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
+                is a sum or a mean operation. Can take values "sum" or "mean"
+            max_grad_norm: The value at which gradients are to be clipped.
+            strict: If set to True, the input module will be validated to make sure that
+                it does not have buffers in all its submodules.
+
+        Raises:
+            NotImplementedError
+                If ``strict`` is set to ``True`` and module ``m`` (or any of its
+                submodules) includes a buffer.
+        """
+
+        super().__init__(
+            m,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            strict=strict,
+            force_functorch=False,
+            max_grad_norm=max_grad_norm,
+            use_ghost_clipping=True,
+        )
+
+        self.sampler_classes = list(self.GRAD_SAMPLERS.keys()) + list(
+            self.NORM_SAMPLERS.keys()
+        )
+
+    def get_clipping_coef(self) -> torch.Tensor:
+        """Get per-example gradient scaling factor for clipping."""
+        norm_sample = self.get_norm_sample()
+        return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
+
+    def get_norm_sample(self) -> torch.Tensor:
+        """Get per-example gradient norms."""
+        norm_sample = torch.stack(
+            [
+                per_param_norm
+                for module in self.iterate_submodules(self._module)
+                for per_param_norm in module.norm_sample
+            ],
+            dim=0,
+        ).norm(2, dim=0)
+
+        self.per_sample_gradient_norms = norm_sample
+        return norm_sample
+
+    def capture_activations_hook(
+        self,
+        module: nn.Module,
+        forward_input: List[torch.Tensor],
+        _forward_output: torch.Tensor,
+    ):
+        if (
+            not requires_grad(module)
+            or not module.training
+            or not torch.is_grad_enabled()
+            or not self.hooks_enabled
+        ):
+            return
+
+        if not hasattr(module, "activations"):
+            module.activations = []
+        module.activations.append([t.detach() for t in forward_input])  # pyre-ignore
+
+        if not hasattr(module, "forward_counter"):
+            module.forward_counter = 0
+
+        module.forward_counter += 1
+        if self.use_ghost_clipping and module.forward_counter > 1:
+            raise NotImplementedError("Parameter tying is not supported with FSDP")
+
+    def capture_backprops_hook(
+        self,
+        module: nn.Module,
+        _forward_input: torch.Tensor,
+        forward_output: torch.Tensor,
+        loss_reduction: str,
+        batch_first: bool,
+    ):
+        """
+        Computes norms of per sample gradient given the current backprops and activations
+        stored by the associated forward hook. Computed per sample gradient norms are
+        stored in ``norm_sample`` field in each parameter.
+
+        Args:
+            module: nn.Module,
+            _forward_input: torch.Tensor,
+            forward_output: torch.Tensor,
+            loss_reduction: str,
+            batch_first: bool,
+        """
+        if not self.hooks_enabled:
+            return
+
+        backprops = forward_output[0].detach()
+
+        activations, backprops = self.rearrange_grad_samples(
+            module=module,
+            backprops=backprops,
+            loss_reduction=loss_reduction,
+            batch_first=batch_first,
+        )
+
+        if not hasattr(module, "norm_sample"):
+            # currently, we don't support freezing and unfreezing params in between training. Making this a dictionary and mapping with param names might fix this.
+            module.norm_sample = []
+            for _, param in trainable_parameters(module):
+                module.norm_sample.append(
+                    torch.zeros(
+                        torch.Size([module.max_batch_len, 1]),
+                        device=param.device,
+                        dtype=param.dtype,
+                    )
+                )
+
+        module_type = (
+            module.__class__.__bases__[1]
+            if isinstance(module, torch.distributed.fsdp.FSDPModule)
+            else type(module)
+        )
+        module.forward_counter -= 1
+        if self.use_ghost_clipping and module_type in self.NORM_SAMPLERS:
+            norm_sampler_fn = self.NORM_SAMPLERS[module_type]
+            norm_samples = norm_sampler_fn(module, activations, backprops)
+
+            for idx, (_, ns) in enumerate(
+                (item for item in norm_samples.items() if item[0].requires_grad)
+            ):
+                module.norm_sample[idx] = ns
+        else:
+            if not self.force_functorch and module_type in self.GRAD_SAMPLERS:
+                grad_sampler_fn = self.GRAD_SAMPLERS[module_type]
+            else:
+                grad_sampler_fn = ft_compute_per_sample_gradient
+
+            grad_samples = grad_sampler_fn(module, activations, backprops)
+
+            for idx, (_, gs) in enumerate((item for item in grad_samples.items())):
+                module.norm_sample[idx] = gs.reshape(len(gs), -1).norm(2, dim=-1)
+            del grad_samples
+
+        if len(module.activations) == 0:
+            if hasattr(module, "max_batch_len"):
+                del module.max_batch_len
+
+    @property
+    def per_sample_gradient_norms(self) -> torch.Tensor:
+        """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings"""
+        if self._per_sample_gradient_norms is not None:
+            return self._per_sample_gradient_norms
+        else:
+            raise AttributeError(
+                "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property."
+            )
+
+    @per_sample_gradient_norms.setter
+    def per_sample_gradient_norms(self, value):
+        self._per_sample_gradient_norms = value
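
For intuition, get_norm_sample() stacks the per-parameter per-sample norms collected by the hooks and reduces them to one 2-norm per example, and get_clipping_coef() turns that into a clipping factor clamped at 1. A standalone sketch of the same arithmetic with made-up per-parameter norms (no model or FSDP involved):

import torch

max_grad_norm = 1.0
# Hypothetical per-parameter, per-example gradient norms for a batch of 4
# (e.g. one entry for a weight and one for a bias).
per_param_norms = [
    torch.tensor([0.5, 2.0, 1.0, 3.0]),
    torch.tensor([0.1, 0.5, 0.2, 1.0]),
]

# Per-example total norm: 2-norm across parameters of the per-parameter norms.
norm_sample = torch.stack(per_param_norms, dim=0).norm(2, dim=0)

# Per-example clipping coefficient; examples already within the bound keep 1.0.
clipping_coef = (max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
print(norm_sample)    # tensor([0.5099, 2.0616, 1.0198, 3.1623])
print(clipping_coef)  # tensor([1.0000, 0.4851, 0.9806, 0.3162])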

opacus/grad_sample/utils.py

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,9 @@
 from .grad_sample_module_fast_gradient_clipping import (
     GradSampleModuleFastGradientClipping,
 )
+from .grad_sample_module_fast_gradient_clipping_fsdp import (
+    GradSampleModuleFastGradientClippingFSDP,
+)
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
 from .gsm_no_op import GradSampleModuleNoOp
@@ -102,6 +105,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
         return GradSampleModuleExpandedWeights
     elif grad_sample_mode == "ghost":
         return GradSampleModuleFastGradientClipping
+    elif grad_sample_mode == "ghost_fsdp":
+        return GradSampleModuleFastGradientClippingFSDP
     elif grad_sample_mode == "no_op":
         return GradSampleModuleNoOp
     else:

opacus/optimizers/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,7 @@
 from .ddpoptimizer_fast_gradient_clipping import (
     DistributedDPOptimizerFastGradientClipping,
 )
+from .fsdpoptimizer_fast_gradient_clipping import FSDPOptimizerFastGradientClipping
 from .optimizer import DPOptimizer
 from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping
 from .perlayeroptimizer import DPPerLayerOptimizer
@@ -29,6 +30,7 @@
     "DPOptimizer",
     "DPOptimizerFastGradientClipping",
     "DistributedDPOptimizerFastGradientlipping",
+    "FSDPOptimizerFastGradientClipping",
     "DPPerLayerOptimizer",
     "SimpleDistributedPerLayerOptimizer",
 ]
@@ -44,6 +46,13 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str
             raise ValueError(
                 f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
             )
+    elif grad_sample_mode == "ghost_fsdp":
+        if clipping == "flat" and distributed is True:
+            return FSDPOptimizerFastGradientClipping
+        else:
+            raise ValueError(
+                f"Unsupported combination of parameters. Clipping: {clipping}, distributed: {distributed}, and grad_sample_mode: {grad_sample_mode}"
+            )
     elif clipping == "flat" and distributed is False:
         return DPOptimizer
     elif clipping == "flat" and distributed is True:
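
Put together, the dispatch for the new mode resolves as below; flat clipping with distributed=True is the only accepted combination for "ghost_fsdp", and everything else raises ValueError. A short sketch using the two helpers touched by this diff:

from opacus.grad_sample.utils import get_gsm_class
from opacus.optimizers import get_optimizer_class

# GradSampleModule class for the new mode.
gsm_cls = get_gsm_class("ghost_fsdp")
print(gsm_cls.__name__)  # GradSampleModuleFastGradientClippingFSDP

# Optimizer class: only flat clipping in a distributed run is accepted.
opt_cls = get_optimizer_class(
    clipping="flat", distributed=True, grad_sample_mode="ghost_fsdp"
)
print(opt_cls.__name__)  # FSDPOptimizerFastGradientClipping

# Any other combination is rejected, e.g. a non-distributed run:
try:
    get_optimizer_class(clipping="flat", distributed=False, grad_sample_mode="ghost_fsdp")
except ValueError as err:
    print(err)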
