
Commit c8fe1fe

aparna-aketi authored and facebook-github-bot committed
FSDP2 with Ghost Clipping and Fast Gradient Clipping prototyping (#761)
Summary: Integrates FSDP2 with Opacus (first prototype).
1. FSDP2 is supported only if every layer with trainable parameters is supported by ghost clipping or fast gradient clipping (see the pre-flight sketch below).
2. Freezing or unfreezing parameters mid-training is not supported.
Design Doc: [Opacus Ghost Clipping and FSDP2](https://docs.google.com/document/d/1MHqIMKBAXhkUZYQ9kkHCmUs3iLq_G5Q25uS1u3g7Asw/edit?tab=t.0#heading=h.eqambyjwzqsu)
Differential Revision: D70533184
1 parent f3752c3 commit c8fe1fe
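
To make constraint 1 concrete, here is a hypothetical pre-flight check (not part of this commit; the helper name and structure are illustrative) that a caller could run on the plain, not-yet-sharded model to verify that every trainable layer has a registered grad sampler or norm sampler:

import torch.nn as nn
from opacus.grad_sample import GradSampleModuleFastGradientClipping as GSM


def all_trainable_layers_supported(model: nn.Module) -> bool:
    # Check the un-sharded model: FSDP2's fully_shard swaps each wrapped
    # module's class, which is exactly what this commit has to work around.
    supported = set(GSM.GRAD_SAMPLERS) | set(GSM.NORM_SAMPLERS)
    return all(
        type(m) in supported
        for m in model.modules()
        if any(p.requires_grad for p in m.parameters(recurse=False))
    )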

File tree: 9 files changed, +482 −6 lines


opacus/grad_sample/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,9 @@
 from .grad_sample_module_fast_gradient_clipping import (  # noqa
     GradSampleModuleFastGradientClipping,
 )
+from .grad_sample_module_fast_gradient_clipping_fsdp import (  # noqa
+    GradSampleModuleFastGradientClippingFSDP,
+)
 from .group_norm import compute_group_norm_grad_sample  # noqa
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
@@ -41,6 +44,7 @@
 __all__ = [
     "GradSampleModule",
     "GradSampleModuleFastGradientClipping",
+    "GradSampleModuleFastGradientClippingFSDP",
     "GradSampleModuleExpandedWeights",
     "GradSampleModuleNoOp",
     "AbstractGradSampleModule",

opacus/grad_sample/grad_sample_module.py

Lines changed: 6 additions & 1 deletion
@@ -199,7 +199,12 @@ def add_hooks(
             if type(module) in [DPRNN, DPLSTM, DPGRU]:
                 continue

-            if force_functorch or not type(module) in self.GRAD_SAMPLERS:
+            module_type = (
+                module.__class__.__bases__[1]
+                if isinstance(module, torch.distributed.fsdp.FSDPModule)
+                else type(module)
+            )
+            if force_functorch or not (module_type in self.GRAD_SAMPLERS):
                 prepare_layer(module, batch_first=batch_first)

             self.autograd_grad_sample_hooks.append(
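
Background for the added lines (not part of the diff): FSDP2's fully_shard replaces a wrapped module's class with a dynamically created subclass, roughly type(f"FSDP{cls.__name__}", (FSDPModule, cls), ...), so type(module) no longer matches the keys registered in GRAD_SAMPLERS. Taking __bases__[1] recovers the original layer class. A minimal sketch of that resolution, assuming a PyTorch build that exposes FSDPModule under torch.distributed.fsdp:

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule


def resolve_layer_type(module: nn.Module) -> type:
    # Hypothetical helper mirroring the logic added above: for an FSDP2-wrapped
    # module, __bases__ is (FSDPModule, <original class>), so index 1 is the
    # class the grad/norm samplers were registered for.
    if isinstance(module, FSDPModule):
        return module.__class__.__bases__[1]
    return type(module)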
opacus/grad_sample/grad_sample_module_fast_gradient_clipping_fsdp.py

Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+from typing import List
+
+import torch
+import torch.nn as nn
+from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
+from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
+    GradSampleModuleFastGradientClipping,
+)
+from opacus.utils.module_utils import requires_grad, trainable_parameters
+
+
+logger = logging.getLogger(__name__)
+logger.disabled = True
+
+
+class GradSampleModuleFastGradientClippingFSDP(GradSampleModuleFastGradientClipping):
+    """
+    Hooks-based implementation of GradSampleModule with Fast Gradient and Ghost Clipping and FSDP support
+
+    Computes norms of gradients without gradient instantiation
+    """
+
+    def __init__(
+        self,
+        m: nn.Module,
+        *,
+        batch_first=True,
+        loss_reduction="mean",
+        strict: bool = True,
+        max_grad_norm=1,
+    ):
+        """
+
+        Args:
+            m: nn.Module to be wrapped
+            batch_first: Flag to indicate if the input tensor to the corresponding module
+                has the first dimension representing the batch. If set to True, dimensions on
+                input tensor are expected be ``[batch_size, ...]``, otherwise
+                ``[K, batch_size, ...]``
+            loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
+                is a sum or a mean operation. Can take values "sum" or "mean"
+            max_grad_norm: The value at which gradients are to be clipped.
+            strict: If set to True, the input module will be validated to make sure that
+                it does not have buffers in all its submodules.
+
+        Raises:
+            NotImplementedError
+                If ``strict`` is set to ``True`` and module ``m`` (or any of its
+                submodules) includes a buffer.
+        """
+
+        super().__init__(
+            m,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            strict=strict,
+            force_functorch=False,
+            max_grad_norm=max_grad_norm,
+            use_ghost_clipping=True,
+        )
+
+        self.sampler_classes = list(self.GRAD_SAMPLERS.keys()) + list(
+            self.NORM_SAMPLERS.keys()
+        )
+
+    def get_clipping_coef(self) -> torch.Tensor:
+        """Get per-example gradient scaling factor for clipping."""
+        norm_sample = self.get_norm_sample()
+        return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
+
+    def get_norm_sample(self) -> torch.Tensor:
+        """Get per-example gradient norms."""
+        norm_sample = torch.stack(
+            [
+                per_param_norm
+                for module in self.iterate_submodules(self._module)
+                for per_param_norm in module.norm_sample
+            ],
+            dim=0,
+        ).norm(2, dim=0)
+
+        self.per_sample_gradient_norms = norm_sample
+        return norm_sample
+
+    def capture_activations_hook(
+        self,
+        module: nn.Module,
+        forward_input: List[torch.Tensor],
+        _forward_output: torch.Tensor,
+    ):
+        if (
+            not requires_grad(module)
+            or not module.training
+            or not torch.is_grad_enabled()
+            or not self.hooks_enabled
+        ):
+            return
+
+        if not hasattr(module, "activations"):
+            module.activations = []
+        module.activations.append([t.detach() for t in forward_input])  # pyre-ignore
+
+        if not hasattr(module, "forward_counter"):
+            module.forward_counter = 0
+
+        module.forward_counter += 1
+        if self.use_ghost_clipping and module.forward_counter > 1:
+            raise NotImplementedError("Parameter tying is not supported with FSDP")
+
+    def capture_backprops_hook(
+        self,
+        module: nn.Module,
+        _forward_input: torch.Tensor,
+        forward_output: torch.Tensor,
+        loss_reduction: str,
+        batch_first: bool,
+    ):
+        """
+        Computes norms of per sample gradient given the current backprops and activations
+        stored by the associated forward hook. Computed per sample gradient norms are
+        stored in ``norm_sample`` field in each parameter.
+
+        Args:
+            module: nn.Module,
+            _forward_input: torch.Tensor,
+            forward_output: torch.Tensor,
+            loss_reduction: str,
+            batch_first: bool,
+        """
+        if not self.hooks_enabled:
+            return
+
+        backprops = forward_output[0].detach()
+
+        activations, backprops = self.rearrange_grad_samples(
+            module=module,
+            backprops=backprops,
+            loss_reduction=loss_reduction,
+            batch_first=batch_first,
+        )
+
+        if not hasattr(module, "norm_sample"):
+            # currently, we don't support freezing and unfreezing params in between training. Making this a dictionary and mapping with param names might fix this.
+            module.norm_sample = []
+            for _, param in trainable_parameters(module):
+                module.norm_sample.append(
+                    torch.zeros(
+                        torch.Size([module.max_batch_len, 1]),
+                        device=param.device,
+                        dtype=param.dtype,
+                    )
+                )
+
+        module_type = (
+            module.__class__.__bases__[1]
+            if isinstance(module, torch.distributed.fsdp.FSDPModule)
+            else type(module)
+        )
+        module.forward_counter -= 1
+        if self.use_ghost_clipping and module_type in self.NORM_SAMPLERS:
+            norm_sampler_fn = self.NORM_SAMPLERS[module_type]
+            norm_samples = norm_sampler_fn(module, activations, backprops)
+
+            for idx, (_, ns) in enumerate(
+                (item for item in norm_samples.items() if item[0].requires_grad)
+            ):
+                module.norm_sample[idx] = ns
+        else:
+            if not self.force_functorch and module_type in self.GRAD_SAMPLERS:
+                grad_sampler_fn = self.GRAD_SAMPLERS[module_type]
+            else:
+                grad_sampler_fn = ft_compute_per_sample_gradient
+
+            grad_samples = grad_sampler_fn(module, activations, backprops)
+
+            for idx, (_, gs) in enumerate((item for item in grad_samples.items())):
+                module.norm_sample[idx] = gs.reshape(len(gs), -1).norm(2, dim=-1)
+            del grad_samples
+
+        if len(module.activations) == 0:
+            if hasattr(module, "max_batch_len"):
+                del module.max_batch_len
+
+    @property
+    def per_sample_gradient_norms(self) -> torch.Tensor:
+        """Returns per sample gradient norms. Note that these are not privatized and should only be used for debugging purposes or in non-private settings"""
+        if self._per_sample_gradient_norms is not None:
+            return self._per_sample_gradient_norms
+        else:
+            raise AttributeError(
+                "per_sample_gradient_norms is not set. Please call forward and backward on the model before accessing this property."
+            )
+
+    @per_sample_gradient_norms.setter
+    def per_sample_gradient_norms(self, value):
+        self._per_sample_gradient_norms = value
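
For orientation (not part of the diff), a minimal sketch of the two-pass ghost-clipping loop this wrapper supports. The toy model, dummy data, and the hand-rolled loop are illustrative assumptions; in practice Opacus' fast-gradient-clipping loss/optimizer plumbing (including the FSDP optimizer wired up below) drives the second backward pass and adds noise. It assumes an initialized distributed process group, a GPU, and a PyTorch build exposing fully_shard under torch.distributed.fsdp:

import torch
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from opacus.grad_sample import GradSampleModuleFastGradientClippingFSDP

# ... torch.distributed process group / device mesh setup elided ...
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).cuda()
fully_shard(model)  # FSDP2 sharding

gsm = GradSampleModuleFastGradientClippingFSDP(model, max_grad_norm=1.0)
optimizer = torch.optim.SGD(gsm.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss(reduction="none")  # keep per-sample losses

x = torch.randn(8, 16, device="cuda")
y = torch.randint(0, 4, (8,), device="cuda")
per_sample_loss = criterion(gsm(x), y)  # shape [8]

# Pass 1: backward hooks compute per-sample gradient norms only.
per_sample_loss.mean().backward(retain_graph=True)

# Rescale each sample's loss by min(1, C / (||g_i|| + 1e-6)) and redo backward,
# so .grad ends up holding the sum of clipped per-sample gradients.
gsm.disable_hooks()
coef = gsm.get_clipping_coef()
optimizer.zero_grad()
(per_sample_loss * coef).sum().backward()
gsm.enable_hooks()
optimizer.step()  # noise addition is normally handled by the DP optimizer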

opacus/grad_sample/utils.py

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,9 @@
 from .grad_sample_module_fast_gradient_clipping import (
     GradSampleModuleFastGradientClipping,
 )
+from .grad_sample_module_fast_gradient_clipping_fsdp import (
+    GradSampleModuleFastGradientClippingFSDP,
+)
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
 from .gsm_no_op import GradSampleModuleNoOp
@@ -102,6 +105,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
         return GradSampleModuleExpandedWeights
     elif grad_sample_mode == "ghost":
         return GradSampleModuleFastGradientClipping
+    elif grad_sample_mode == "ghost_fsdp":
+        return GradSampleModuleFastGradientClippingFSDP
     elif grad_sample_mode == "no_op":
         return GradSampleModuleNoOp
     else:
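
A small usage sketch (illustrative, not from the diff) of the new dispatch value:

from opacus.grad_sample.utils import get_gsm_class

gsm_cls = get_gsm_class("ghost_fsdp")
# gsm_cls is GradSampleModuleFastGradientClippingFSDP; wrap an FSDP2-sharded
# model with it, e.g. gsm = gsm_cls(model, max_grad_norm=1.0).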

opacus/optimizers/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,7 @@
 from .ddpoptimizer_fast_gradient_clipping import (
     DistributedDPOptimizerFastGradientClipping,
 )
+from .fsdpoptimizer_fast_gradient_clipping import FSDPOptimizerFastGradientClipping
 from .optimizer import DPOptimizer
 from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping
 from .perlayeroptimizer import DPPerLayerOptimizer
@@ -29,6 +30,7 @@
     "DPOptimizer",
     "DPOptimizerFastGradientClipping",
     "DistributedDPOptimizerFastGradientlipping",
+    "FSDPOptimizerFastGradientClipping",
     "DPPerLayerOptimizer",
     "SimpleDistributedPerLayerOptimizer",
 ]
@@ -44,6 +46,13 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str
             raise ValueError(
                 f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
             )
+    elif grad_sample_mode == "ghost_fsdp":
+        if clipping == "flat" and distributed is True:
+            return FSDPOptimizerFastGradientClipping
+        else:
+            raise ValueError(
+                f"Unsupported combination of parameters. Clipping: {clipping}, distributed: {distributed}, and grad_sample_mode: {grad_sample_mode}"
+            )
     elif clipping == "flat" and distributed is False:
         return DPOptimizer
     elif clipping == "flat" and distributed is True:
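
And the matching optimizer dispatch, again as an illustrative sketch; the constructor of FSDPOptimizerFastGradientClipping is not shown in this excerpt, so only class selection is demonstrated:

from opacus.optimizers import get_optimizer_class

opt_cls = get_optimizer_class(
    clipping="flat", distributed=True, grad_sample_mode="ghost_fsdp"
)
# opt_cls is FSDPOptimizerFastGradientClipping; any other clipping/distributed
# combination with grad_sample_mode="ghost_fsdp" raises ValueError.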
