Commit f6413b3

aparna-aketi authored and facebook-github-bot committed
FSDP2 with Ghost Clipping and Fast Gradient Clipping prototyping (#761)
Summary: Integrating FSDP2 with Opacus. First prototype:
1. FSDP is supported only if all the layers with trainable parameters are supported by ghost clipping or fast gradient clipping.
2. No freezing/unfreezing of parameters in between training.

Design Doc: [Opacus Ghost Clipping and FSDP2](https://docs.google.com/document/d/1MHqIMKBAXhkUZYQ9kkHCmUs3iLq_G5Q25uS1u3g7Asw/edit?tab=t.0#heading=h.eqambyjwzqsu)

Differential Revision: D70533184
1 parent f3752c3 commit f6413b3
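
For orientation, here is a minimal usage sketch (not part of the commit) of the wrapper class this change introduces. The constructor arguments follow the signature added in grad_sample_module_fast_gradient_clipping_fsdp.py below; in real use the model's submodules would first be sharded with FSDP2 (torch.distributed.fsdp.fully_shard) under an initialized process group, and the optimizer would be wrapped with the new FSDPOptimizerFastGradientClipping selected via grad_sample_mode="ghost_fsdp".

import torch.nn as nn

from opacus.grad_sample import GradSampleModuleFastGradientClippingFSDP

# Hypothetical toy model. Per the constraints above, every layer with trainable
# parameters must be supported by ghost clipping or fast gradient clipping, and
# parameters must not be frozen/unfrozen mid-training.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

dp_model = GradSampleModuleFastGradientClippingFSDP(
    model,
    batch_first=True,
    loss_reduction="mean",
    max_grad_norm=1.0,
)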

File tree

9 files changed: +477 -8 lines changed

opacus/grad_sample/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,9 @@
 from .grad_sample_module_fast_gradient_clipping import (  # noqa
     GradSampleModuleFastGradientClipping,
 )
+from .grad_sample_module_fast_gradient_clipping_fsdp import (  # noqa
+    GradSampleModuleFastGradientClippingFSDP,
+)
 from .group_norm import compute_group_norm_grad_sample  # noqa
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
@@ -41,6 +44,7 @@
 __all__ = [
     "GradSampleModule",
     "GradSampleModuleFastGradientClipping",
+    "GradSampleModuleFastGradientClippingFSDP",
     "GradSampleModuleExpandedWeights",
     "GradSampleModuleNoOp",
     "AbstractGradSampleModule",

opacus/grad_sample/grad_sample_module.py

Lines changed: 10 additions & 3 deletions
@@ -163,6 +163,9 @@ def iterate_submodules(self, module: nn.Module) -> Iterable[nn.Module]:
         for m in module.children():
             yield from self.iterate_submodules(m)
 
+    def get_module_type(self, module: nn.Module) -> str:
+        return type(module)
+
     def add_hooks(
         self,
         *,
@@ -199,7 +202,8 @@ def add_hooks(
             if type(module) in [DPRNN, DPLSTM, DPGRU]:
                 continue
 
-            if force_functorch or not type(module) in self.GRAD_SAMPLERS:
+            module_type = self.get_module_type(module)
+            if force_functorch or not (module_type in self.GRAD_SAMPLERS):
                 prepare_layer(module, batch_first=batch_first)
 
             self.autograd_grad_sample_hooks.append(
@@ -330,8 +334,11 @@ def capture_backprops_hook(
             loss_reduction=loss_reduction,
             batch_first=batch_first,
         )
-        if not self.force_functorch and type(module) in self.GRAD_SAMPLERS:
-            grad_sampler_fn = self.GRAD_SAMPLERS[type(module)]
+        if (
+            not self.force_functorch
+            and self.get_module_type(module) in self.GRAD_SAMPLERS
+        ):
+            grad_sampler_fn = self.GRAD_SAMPLERS[self.get_module_type(module)]
         else:
             grad_sampler_fn = ft_compute_per_sample_gradient
 
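The get_module_type indirection above exists because FSDP2's fully_shard swaps a wrapped module's class for a dynamically created subclass whose bases are (FSDPModule, original class), so a plain type(module) lookup no longer matches the keys registered in GRAD_SAMPLERS/NORM_SAMPLERS. A minimal sketch of that behaviour, emulating the class swap rather than calling fully_shard (which needs an initialized process group), and assuming a PyTorch version that exposes torch.distributed.fsdp.FSDPModule as this commit does:

import torch.nn as nn
from torch.distributed.fsdp import FSDPModule

linear = nn.Linear(4, 2)

# Emulate what fully_shard does internally: replace the module's class with a
# dynamically created subclass whose bases are (FSDPModule, original class).
linear.__class__ = type("FSDPLinear", (FSDPModule, nn.Linear), {})

assert isinstance(linear, FSDPModule)
assert type(linear) is not nn.Linear               # plain type() lookup would miss
assert linear.__class__.__bases__[1] is nn.Linear  # what the FSDP subclass recovers
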
opacus/grad_sample/grad_sample_module_fast_gradient_clipping_fsdp.py

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import logging
+from typing import List
+
+import torch
+import torch.nn as nn
+from opacus.grad_sample.functorch import ft_compute_per_sample_gradient
+from opacus.grad_sample.grad_sample_module_fast_gradient_clipping import (
+    GradSampleModuleFastGradientClipping,
+)
+from opacus.utils.module_utils import requires_grad, trainable_parameters
+
+
+logger = logging.getLogger(__name__)
+logger.disabled = True
+
+
+class GradSampleModuleFastGradientClippingFSDP(GradSampleModuleFastGradientClipping):
+    """
+    Hooks-based implementation of GradSampleModule with Fast Gradient Clipping, Ghost Clipping, and FSDP support.
+
+    Computes per-sample gradient norms without instantiating per-sample gradients.
+    """
+
+    def __init__(
+        self,
+        m: nn.Module,
+        *,
+        batch_first=True,
+        loss_reduction="mean",
+        strict: bool = True,
+        max_grad_norm=1,
+    ):
+        """
+
+        Args:
+            m: nn.Module to be wrapped
+            batch_first: Flag to indicate if the input tensor to the corresponding module
+                has the first dimension representing the batch. If set to True, dimensions on
+                input tensor are expected to be ``[batch_size, ...]``, otherwise
+                ``[K, batch_size, ...]``
+            loss_reduction: Indicates if the loss reduction (for aggregating the gradients)
+                is a sum or a mean operation. Can take values "sum" or "mean"
+            max_grad_norm: The value at which gradients are to be clipped.
+            strict: If set to True, the input module will be validated to make sure that
+                neither it nor any of its submodules has buffers.
+
+        Raises:
+            NotImplementedError
+                If ``strict`` is set to ``True`` and module ``m`` (or any of its
+                submodules) includes a buffer.
+        """
+
+        super().__init__(
+            m,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+            strict=strict,
+            force_functorch=False,
+            max_grad_norm=max_grad_norm,
+            use_ghost_clipping=True,
+        )
+
+    def get_module_type(self, module: nn.Module) -> str:
+        module_type = (
+            module.__class__.__bases__[1]
+            if isinstance(module, torch.distributed.fsdp.FSDPModule)
+            else type(module)
+        )
+        return module_type
+
+    def get_clipping_coef(self) -> torch.Tensor:
+        """Get per-example gradient scaling factor for clipping."""
+        norm_sample = self.get_norm_sample()
+        return (self.max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)
+
+    def get_norm_sample(self) -> torch.Tensor:
+        """Get per-example gradient norms."""
+        norm_sample = torch.stack(
+            [
+                per_param_norm
+                for module in self.iterate_submodules(self._module)
+                for per_param_norm in module.norm_sample
+            ],
+            dim=0,
+        ).norm(2, dim=0)
+
+        self.per_sample_gradient_norms = norm_sample
+        return norm_sample
+
+    def capture_activations_hook(
+        self,
+        module: nn.Module,
+        forward_input: List[torch.Tensor],
+        _forward_output: torch.Tensor,
+    ):
+        """Captures activations for the given module.
+
+        This function is similar to ``capture_activations_hook`` in the parent class
+        (GradSampleModuleFastGradientClipping), except that it attaches ``_forward_counter``
+        to the module instead of to the parameter variable. Another difference is that
+        GradSampleModuleFastGradientClipping does not support tied parameters only with
+        Ghost Clipping, whereas this class does not support tied parameters with either
+        Fast Gradient Clipping or Ghost Clipping.
+        """
+        if (
+            not requires_grad(module)
+            or not module.training
+            or not torch.is_grad_enabled()
+            or not self.hooks_enabled
+        ):
+            return
+
+        if not hasattr(module, "activations"):
+            module.activations = []
+        module.activations.append([t.detach() for t in forward_input])  # pyre-ignore
+
+        if not hasattr(module, "_forward_counter"):
+            module._forward_counter = 0
+
+        module._forward_counter += 1
+        if self.use_ghost_clipping and module._forward_counter > 1:
+            raise NotImplementedError("Parameter tying is not supported with FSDP")
+
+    def capture_backprops_hook(
+        self,
+        module: nn.Module,
+        _forward_input: torch.Tensor,
+        forward_output: torch.Tensor,
+        loss_reduction: str,
+        batch_first: bool,
+    ):
+        """
+        Computes norms of per sample gradients given the current backprops and activations
+        stored by the associated forward hook. Computed per sample gradient norms are
+        stored in the ``norm_sample`` field of each module.
+
+        Args:
+            module: nn.Module,
+            _forward_input: torch.Tensor,
+            forward_output: torch.Tensor,
+            loss_reduction: str,
+            batch_first: bool,
+        """
+        if not self.hooks_enabled:
+            return
+
+        backprops = forward_output[0].detach()
+
+        activations, backprops = self.rearrange_grad_samples(
+            module=module,
+            backprops=backprops,
+            loss_reduction=loss_reduction,
+            batch_first=batch_first,
+        )
+
+        if not hasattr(module, "norm_sample"):
+            # Currently, we don't support freezing and unfreezing params in between
+            # training. Making this a dictionary keyed by param names might fix this.
+            module.norm_sample = []
+            for _, param in trainable_parameters(module):
+                module.norm_sample.append(
+                    torch.zeros(
+                        torch.Size([module.max_batch_len, 1]),
+                        device=param.device,
+                        dtype=param.dtype,
+                    )
+                )
+
+        module_type = self.get_module_type(module)
+        module._forward_counter -= 1
+        if self.use_ghost_clipping and module_type in self.NORM_SAMPLERS:
+            norm_sampler_fn = self.NORM_SAMPLERS[module_type]
+            norm_samples = norm_sampler_fn(module, activations, backprops)
+
+            for idx, (_, ns) in enumerate(
+                (item for item in norm_samples.items() if item[0].requires_grad)
+            ):
+                module.norm_sample[idx] = ns
+        else:
+            if not self.force_functorch and module_type in self.GRAD_SAMPLERS:
+                grad_sampler_fn = self.GRAD_SAMPLERS[module_type]
+            else:
+                grad_sampler_fn = ft_compute_per_sample_gradient
+
+            grad_samples = grad_sampler_fn(module, activations, backprops)
+
+            for idx, (_, gs) in enumerate((item for item in grad_samples.items())):
+                module.norm_sample[idx] = gs.reshape(len(gs), -1).norm(2, dim=-1)
+            del grad_samples
+
+        if len(module.activations) == 0:
+            if hasattr(module, "max_batch_len"):
+                del module.max_batch_len
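
To make the clipping arithmetic in get_norm_sample()/get_clipping_coef() concrete, here is a small numeric toy (values are illustrative, not from the commit): each module holds one [batch_size, 1] tensor of per-sample norms per trainable parameter, the overall per-sample gradient norm is the 2-norm across all of them, and the clipping coefficient is min(1, max_grad_norm / (norm + 1e-6)).

import torch

max_grad_norm = 1.0
# Hypothetical per-parameter, per-sample norms for a batch of 2 samples.
per_param_norms = [
    torch.tensor([[3.0], [0.3]]),  # parameter 1
    torch.tensor([[4.0], [0.4]]),  # parameter 2
]

# Same aggregation as get_norm_sample(): stack and take the 2-norm across params.
norm_sample = torch.stack(per_param_norms, dim=0).norm(2, dim=0)   # [[5.0], [0.5]]
# Same scaling as get_clipping_coef(): clip only samples whose norm exceeds the bound.
clip_coef = (max_grad_norm / (norm_sample + 1e-6)).clamp(max=1.0)  # [[0.2], [1.0]]
print(norm_sample.squeeze(), clip_coef.squeeze())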

opacus/grad_sample/utils.py

Lines changed: 5 additions & 0 deletions
@@ -21,6 +21,9 @@
 from .grad_sample_module_fast_gradient_clipping import (
     GradSampleModuleFastGradientClipping,
 )
+from .grad_sample_module_fast_gradient_clipping_fsdp import (
+    GradSampleModuleFastGradientClippingFSDP,
+)
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
 from .gsm_no_op import GradSampleModuleNoOp
@@ -102,6 +105,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
         return GradSampleModuleExpandedWeights
     elif grad_sample_mode == "ghost":
         return GradSampleModuleFastGradientClipping
+    elif grad_sample_mode == "ghost_fsdp":
+        return GradSampleModuleFastGradientClippingFSDP
     elif grad_sample_mode == "no_op":
         return GradSampleModuleNoOp
     else:
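
With this change, grad_sample_mode="ghost_fsdp" resolves to the new FSDP-aware wrapper. A quick check, assuming this commit is applied:

from opacus.grad_sample import GradSampleModuleFastGradientClippingFSDP
from opacus.grad_sample.utils import get_gsm_class

assert get_gsm_class("ghost_fsdp") is GradSampleModuleFastGradientClippingFSDP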

opacus/optimizers/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,7 @@
 from .ddpoptimizer_fast_gradient_clipping import (
     DistributedDPOptimizerFastGradientClipping,
 )
+from .fsdpoptimizer_fast_gradient_clipping import FSDPOptimizerFastGradientClipping
 from .optimizer import DPOptimizer
 from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping
 from .perlayeroptimizer import DPPerLayerOptimizer
@@ -29,6 +30,7 @@
     "DPOptimizer",
     "DPOptimizerFastGradientClipping",
     "DistributedDPOptimizerFastGradientlipping",
+    "FSDPOptimizerFastGradientClipping",
     "DPPerLayerOptimizer",
     "SimpleDistributedPerLayerOptimizer",
 ]
@@ -44,6 +46,13 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str
             raise ValueError(
                 f"Unsupported combination of parameters. Clipping: {clipping} and grad_sample_mode: {grad_sample_mode}"
             )
+    elif grad_sample_mode == "ghost_fsdp":
+        if clipping == "flat" and distributed is True:
+            return FSDPOptimizerFastGradientClipping
+        else:
+            raise ValueError(
+                f"Unsupported combination of parameters. Clipping: {clipping}, distributed: {distributed}, and grad_sample_mode: {grad_sample_mode}"
+            )
     elif clipping == "flat" and distributed is False:
         return DPOptimizer
     elif clipping == "flat" and distributed is True:
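
Correspondingly, the optimizer factory accepts "ghost_fsdp" only with flat clipping in the distributed setting and rejects other combinations. A short sketch, assuming this commit is applied:

from opacus.optimizers import FSDPOptimizerFastGradientClipping, get_optimizer_class

cls = get_optimizer_class(clipping="flat", distributed=True, grad_sample_mode="ghost_fsdp")
assert cls is FSDPOptimizerFastGradientClipping

try:
    # Non-distributed "ghost_fsdp" is rejected.
    get_optimizer_class(clipping="flat", distributed=False, grad_sample_mode="ghost_fsdp")
except ValueError as err:
    print(err)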
