This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 4b406e8

mannatsingh authored and facebook-github-bot committed
Add support for Sync BN (#423)
Summary: Pull Request resolved: #423

Added support for sync batch normalization, using either PyTorch's implementation or Apex's. Also plugged the model complexity hook into `classy_train.py`; this helps test a bug I encountered and fixed, which needs the profiler together with sync batch norm.

Reviewed By: vreis

Differential Revision: D20307435

fbshipit-source-id: 8e3ccb3f55802f54f215f06610c2a01566fca1b0
1 parent d2016ef commit 4b406e8

File tree

4 files changed: +71 additions, -4 deletions

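For context before the per-file diffs: the new behavior is driven from the task config. Below is a minimal, hypothetical config fragment based on the `from_config` change in `classy_vision/tasks/classification_task.py` further down; everything except the two keys this commit reads is elided.

```python
# Hypothetical fragment of a classification task config (not part of the diff);
# only the keys read by from_config in this commit are shown. Values are matched
# case-insensitively because from_config calls .upper() before the enum lookup.
config = {
    # ... rest of the task config ...
    "broadcast_buffers": "disabled",      # -> BroadcastBuffersMode.DISABLED
    "batch_norm_sync_mode": "pytorch",    # -> BatchNormSyncMode.PYTORCH ("apex" requires apex installed)
}
```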

classy_train.py

Lines changed: 2 additions & 1 deletion

@@ -50,6 +50,7 @@
 from classy_vision.hooks import (
     CheckpointHook,
     LossLrMeterLoggingHook,
+    ModelComplexityHook,
     ProfilerHook,
     ProgressBarHook,
     TensorboardPlotHook,
@@ -118,7 +119,7 @@ def main(args, config):


 def configure_hooks(args, config):
-    hooks = [LossLrMeterLoggingHook(args.log_freq)]
+    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]

     # Make a folder to store checkpoints and tensorboard logging outputs
     suffix = datetime.now().isoformat()
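As a quick illustration of the change above, this is roughly what `configure_hooks` now builds by default; `args.log_freq` is replaced with a literal here, and the checkpoint and tensorboard hooks added later in the function are omitted.

```python
# Sketch of the default hook list after this change (simplified; not the full
# configure_hooks implementation from classy_train.py).
from classy_vision.hooks import LossLrMeterLoggingHook, ModelComplexityHook

log_freq = 100  # stand-in for args.log_freq
hooks = [LossLrMeterLoggingHook(log_freq), ModelComplexityHook()]
```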

classy_vision/tasks/classification_task.py

Lines changed: 41 additions & 2 deletions

@@ -11,6 +11,7 @@
 from typing import Any, Dict, List, NamedTuple, Optional, Union

 import torch
+import torch.nn as nn
 from classy_vision.dataset import ClassyDataset, build_dataset
 from classy_vision.generic.distributed_util import (
     all_reduce_mean,
@@ -53,6 +54,12 @@ class BroadcastBuffersMode(enum.Enum):
     BEFORE_EVAL = enum.auto()


+class BatchNormSyncMode(enum.Enum):
+    DISABLED = enum.auto()  # No Synchronized Batch Normalization
+    PYTORCH = enum.auto()  # Use torch.nn.SyncBatchNorm
+    APEX = enum.auto()  # Use apex.parallel.SyncBatchNorm, needs apex to be installed
+
+
 class LastBatchInfo(NamedTuple):
     loss: torch.Tensor
     output: torch.Tensor
@@ -133,6 +140,7 @@ def __init__(self):
         self.amp_opt_level = None
         self.perf_log = []
         self.last_batch = None
+        self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED

     def set_checkpoint(self, checkpoint):
         """Sets checkpoint on task.
@@ -204,14 +212,35 @@ def set_meters(self, meters: List["ClassyMeter"]):
         self.meters = meters
         return self

-    def set_distributed_options(self, broadcast_buffers_mode: BroadcastBuffersMode):
+    def set_distributed_options(
+        self,
+        broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.DISABLED,
+        batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED,
+    ):
         """Set distributed options.

         Args:
             broadcast_buffers_mode: Broadcast buffers mode. See
                 :class:`BroadcastBuffersMode` for options.
+            batch_norm_sync_mode: Batch normalization synchronization mode. See
+                :class:`BatchNormSyncMode` for options.
+
+        Raises:
+            RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex
+                is not installed.
         """
         self.broadcast_buffers_mode = broadcast_buffers_mode
+
+        if batch_norm_sync_mode == BatchNormSyncMode.DISABLED:
+            logging.info("Synchronized Batch Normalization is disabled")
+        else:
+            if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available:
+                raise RuntimeError("apex is not installed")
+            logging.info(
+                f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}"
+            )
+        self.batch_norm_sync_mode = batch_norm_sync_mode
+
         return self

     def set_hooks(self, hooks: List["ClassyHook"]):
@@ -317,7 +346,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
             .set_meters(meters)
             .set_amp_opt_level(amp_opt_level)
             .set_distributed_options(
-                BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")]
+                broadcast_buffers_mode=BroadcastBuffersMode[
+                    config.get("broadcast_buffers", "disabled").upper()
+                ],
+                batch_norm_sync_mode=BatchNormSyncMode[
+                    config.get("batch_norm_sync_mode", "disabled").upper()
+                ],
             )
         )
         for phase_type in phase_types:
@@ -494,6 +528,11 @@ def prepare(
             multiprocessing_context=dataloader_mp_context,
         )

+        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
+            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
+        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
+            self.base_model = apex.parallel.convert_syncbn_model(self.base_model)
+
         # move the model and loss to the right device
         if use_gpu:
             self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
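For readers wiring this up in code rather than through a config, a minimal sketch of the new builder call is shown below. It assumes `ClassificationTask` can be constructed with no arguments (as its `__init__` signature above suggests) and skips the rest of the task setup (dataset, model, loss, optimizer).

```python
# Minimal sketch (assumptions noted above): enable PyTorch sync batch norm on a
# task via the new set_distributed_options signature.
from classy_vision.tasks.classification_task import (
    BatchNormSyncMode,
    BroadcastBuffersMode,
    ClassificationTask,
)

task = ClassificationTask().set_distributed_options(
    broadcast_buffers_mode=BroadcastBuffersMode.DISABLED,
    batch_norm_sync_mode=BatchNormSyncMode.PYTORCH,  # BatchNormSyncMode.APEX needs apex installed
)
# During prepare(), the task converts the model's BatchNorm layers with
# nn.SyncBatchNorm.convert_sync_batchnorm (or apex.parallel.convert_syncbn_model).
```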

test/generic/config_utils.py

Lines changed: 2 additions & 1 deletion

@@ -178,7 +178,7 @@ def get_test_mlp_task_config():
             "class_ratio": 0.5,
             "num_samples": 10,
             "seed": 0,
-            "batchsize_per_replica": 3,
+            "batchsize_per_replica": 4,
             "use_augmentation": False,
             "use_shuffle": True,
             "transforms": [
@@ -228,6 +228,7 @@ def get_test_mlp_task_config():
             "input_dim": 1200,
             "output_dim": 1000,
             "hidden_dims": [10],
+            "use_batchnorm": True,  # used for testing sync batchnorm
         },
         "meters": {"accuracy": {"topk": [1]}},
         "optimizer": {

test/trainer_distributed_trainer_test.py

Lines changed: 26 additions & 0 deletions

@@ -22,10 +22,13 @@ def setUp(self):
         config = get_test_mlp_task_config()
         invalid_config = copy.deepcopy(config)
         invalid_config["name"] = "invalid_task"
+        sync_bn_config = copy.deepcopy(config)
+        sync_bn_config["sync_batch_norm_mode"] = "pytorch"
         self.config_files = {}
         for config_key, config in [
             ("config", config),
             ("invalid_config", invalid_config),
+            ("sync_bn_config", sync_bn_config),
         ]:
             with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
                 json.dump(config, f)
@@ -63,3 +66,26 @@ def test_training(self):
         result = subprocess.run(cmd, shell=True)
         success = result.returncode == 0
         self.assertEqual(success, expected_success)
+
+    @unittest.skipUnless(torch.cuda.is_available(), "This test needs a gpu to run")
+    def test_sync_batch_norm(self):
+        """Test that sync batch norm training doesn't hang."""
+
+        num_processes = 2
+        device = "gpu"
+
+        cmd = f"""{sys.executable} -m torch.distributed.launch \
+            --nnodes=1 \
+            --nproc_per_node={num_processes} \
+            --master_addr=localhost \
+            --master_port=29500 \
+            --use_env \
+            {self.path}/../classy_train.py \
+            --device={device} \
+            --config={self.config_files["sync_bn_config"]} \
+            --num_workers=4 \
+            --log_freq=100 \
+            --distributed_backend=ddp
+        """
+        result = subprocess.run(cmd, shell=True)
+        self.assertEqual(result.returncode, 0)
