This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit af5f89b

karthikprasad authored and facebook-github-bot committed
Support DP in PyText (#1366)
Summary:
Pull Request resolved: #1366
Pull Request resolved: #1355

as titled

Reviewed By: snisarg, ashkan-software

Differential Revision: D20844321

fbshipit-source-id: 0825df81462a76b192e06d1e13bcfc8cf64155b8
1 parent b0a9d80 commit af5f89b

File tree

7 files changed: +100 -2 lines changed

docs_requirements.txt
pytext/config/component.py
pytext/optimizer/__init__.py
pytext/optimizer/privacy_engine.py
pytext/trainers/trainer.py
pytext/trainers/training_state.py
requirements.txt


docs_requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@ sentencepiece
 torchtext
 tensorboard==1.14
 pandas
+pytorch-dp

pytext/config/component.py

Lines changed: 7 additions & 0 deletions
@@ -32,6 +32,7 @@ class ComponentType(enum.Enum):
     METRIC_REPORTER = "metric_reporter"
     SPARSIFIER = "sparsifier"
     MASKING_FUNCTION = "masking_function"
+    PRIVACY_ENGINE = "privacy_engine"
 
 
 class RegistryError(Exception):
@@ -247,6 +248,12 @@ def create_sparsifier(sparsifier_config, *args, **kwargs):
     )
 
 
+def create_privacy_engine(privacy_engine_config, *args, **kwargs):
+    return create_component(
+        ComponentType.PRIVACY_ENGINE, privacy_engine_config, *args, **kwargs
+    )
+
+
 def create_predictor(predictor_config, *args, **kwargs):
     return create_component(ComponentType.PREDICTOR, predictor_config, *args, **kwargs)
 

pytext/optimizer/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -15,5 +15,6 @@
     Optimizer,
     learning_rates,
 )
+from pytext.optimizer.privacy_engine import PrivacyEngine  # noqa
 from pytext.optimizer.radam import RAdam  # noqa
 from pytext.optimizer.swa import StochasticWeightAveraging  # noqa

pytext/optimizer/privacy_engine.py

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+from typing import List, Optional
+
+import torchdp
+from pytext.config import ConfigBase
+from pytext.config.component import Component, ComponentType
+
+
+class PrivacyEngine(Component):
+    """
+    A wrapper around PrivacyEngine of pytorch-dp
+    """
+
+    __COMPONENT_TYPE__ = ComponentType.PRIVACY_ENGINE
+    __EXPANSIBLE__ = False
+
+    class Config(ConfigBase):
+        noise_multiplier: float
+        max_grad_norm: float
+        batch_size: float
+        dataset_size: float
+        target_delta: Optional[float] = 0.000001
+        alphas: Optional[List[float]] = [1 + x / 10.0 for x in range(1, 100)] + list(
+            range(12, 64)
+        )
+
+    def __init__(
+        self,
+        model,
+        optimizer,
+        noise_multiplier,
+        max_grad_norm,
+        batch_size,
+        dataset_size,
+        target_delta,
+        alphas,
+    ):
+        self.noise_multiplier = noise_multiplier
+        self.max_grad_norm = max_grad_norm
+        self.batch_size = batch_size
+        self.dataset_size = dataset_size
+        self.target_delta = target_delta
+        self.alphas = alphas
+
+        self._privacy_engine = torchdp.PrivacyEngine(
+            model,
+            self.batch_size,
+            self.dataset_size,
+            self.alphas,
+            noise_multiplier=self.noise_multiplier,
+            max_grad_norm=self.max_grad_norm,
+            target_delta=self.target_delta,
+        )
+        self._privacy_engine.attach(optimizer)
+
+    @classmethod
+    def from_config(cls, config: Config, model, optimizer):
+        return cls(
+            model=model,
+            optimizer=optimizer,
+            noise_multiplier=config.noise_multiplier,
+            max_grad_norm=config.max_grad_norm,
+            batch_size=config.batch_size,
+            dataset_size=config.dataset_size,
+            target_delta=config.target_delta,
+            alphas=config.alphas,
+        )
+
+    def attach(self, optimizer):
+        self._privacy_engine.attach(optimizer)
+
+    def detach(self):
+        self._privacy_engine.detach()
+
+    def get_privacy_spent(self):
+        return self._privacy_engine.get_privacy_spent()
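
For context, here is a minimal usage sketch of the wrapper added above. It is not part of this commit; it assumes PrivacyEngine.Config can be built directly with keyword arguments (leaving target_delta and alphas at their defaults) and that torchdp's get_privacy_spent returns an (epsilon, best_alpha) pair.

import torch

from pytext.optimizer.privacy_engine import PrivacyEngine

# Toy model and optimizer; any torch.nn.Module / torch.optim.Optimizer works.
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# from_config builds the wrapper; __init__ immediately attaches the
# underlying torchdp.PrivacyEngine to the optimizer.
engine = PrivacyEngine.from_config(
    PrivacyEngine.Config(
        noise_multiplier=1.0,
        max_grad_norm=1.0,
        batch_size=32,
        dataset_size=10000,
    ),
    model,
    optimizer,
)

# Train as usual: the attached engine clips per-sample gradients and adds
# noise on every optimizer.step(). Afterwards, inspect the privacy budget.
epsilon, best_alpha = engine.get_privacy_spent()
engine.detach()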

pytext/trainers/trainer.py

Lines changed: 10 additions & 1 deletion
@@ -13,6 +13,7 @@
     Component,
     ComponentType,
     create_optimizer,
+    create_privacy_engine,
     create_scheduler,
     create_sparsifier,
 )
@@ -21,7 +22,7 @@
 from pytext.metric_reporters import MetricReporter
 from pytext.models.distributed_model import DistributedModel
 from pytext.models.model import Model
-from pytext.optimizer import Adam, Optimizer, learning_rates
+from pytext.optimizer import Adam, Optimizer, PrivacyEngine, learning_rates
 from pytext.optimizer.fp16_optimizer import FP16Optimizer, FP16OptimizerFairseq
 from pytext.optimizer.scheduler import Scheduler
 from pytext.optimizer.sparsifiers.sparsifier import Sparsifier
@@ -119,6 +120,8 @@ class Config(ConfigBase):
         #: backward and master weight will be maintained on original optimizer.
         #: https://arxiv.org/abs/1710.03740
         fp16_args: FP16Optimizer.Config = FP16OptimizerFairseq.Config()
+        # PrivacyEngine related args
+        privacy_engine: Optional[PrivacyEngine.Config] = None
 
     def __init__(self, config: Config, model: torch.nn.Module):
         if config.early_stop_after > 0:
@@ -135,6 +138,11 @@ def __init__(self, config: Config, model: torch.nn.Module):
         self.optimizer: torch.optim.Optimizer = create_optimizer(
             config.optimizer, model
         )
+        self.privacy_engine: PrivacyEngine = (
+            create_privacy_engine(config.privacy_engine, model, self.optimizer)
+            if config.privacy_engine
+            else None
+        )
 
         self.scheduler: torch.optim.lr_scheduler = (
             create_scheduler(config.scheduler, self.optimizer)
@@ -370,6 +378,7 @@ def train(
             optimizer=self.optimizer,
             scheduler=self.scheduler,
             sparsifier=self.sparsifier,
+            privacy_engine=self.privacy_engine,
             rank=rank,
         )
         return self.train_from_state(
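
A hypothetical sketch of turning this on from the trainer side, using the privacy_engine field added to Trainer.Config above. The import path for Trainer and building Trainer.Config with a single keyword argument are assumptions for illustration; the field names come from this diff.

from pytext.optimizer import PrivacyEngine
from pytext.trainers import Trainer  # assumed export path

trainer_config = Trainer.Config(
    privacy_engine=PrivacyEngine.Config(
        noise_multiplier=1.0,
        max_grad_norm=1.0,
        batch_size=32,
        dataset_size=10000,
    )
)

# With privacy_engine left as None (the default), training is unchanged.
# When it is set, Trainer.__init__ calls create_privacy_engine(...) to build
# and attach the engine, and train() passes it along on TrainingState.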

pytext/trainers/training_state.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,7 @@
 from pytext.common.constants import Stage
 from pytext.data.tensorizers import Tensorizer
 from pytext.models.model import Model
-from pytext.optimizer import Optimizer
+from pytext.optimizer import Optimizer, PrivacyEngine
 from pytext.optimizer.scheduler import Scheduler
 from pytext.optimizer.sparsifiers.sparsifier import Sparsifier
 
@@ -18,6 +18,7 @@ class TrainingState:
     scheduler: Scheduler
     sparsifier: Sparsifier
     start_time: float
+    privacy_engine: PrivacyEngine
     # epoch counter
     epoch: int = 0
     # step counter: each optimizer.step() increments step_counter

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ onnx>=1.6.0
 python-dateutil==2.8.0
 pandas
 pytorch-pretrained-bert
+pytorch-dp
 regex==2019.11.1
 requests
 scipy
