This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 9f1d91c

mannatsingh authored and facebook-github-bot committed
Add support for Sync BN (#423)
Summary: Pull Request resolved: #423

Added support for synchronized batch normalization, using either PyTorch's implementation or Apex's. Also plugged the model complexity hook into `classy_train.py`; it exercises the bug I encountered and fixed, which requires the profiler together with sync batch norm.

Differential Revision: D20307435

fbshipit-source-id: 82010fc2ed41ac4bdaedf5d8b1b1ddca49bf8d12
1 parent 5eb850b commit 9f1d91c
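
For orientation, here is a minimal sketch of how the new option is intended to be used, based only on the changes below. The builder-style call and the enum names come from `classification_task.py`; building a complete task (dataset, model, loss, optimizer, ...) is omitted, and the bare `ClassificationTask()` constructor is an assumption of this sketch rather than part of the commit.

    # Hedged sketch: enabling synchronized batch norm on a ClassificationTask.
    # Only set_distributed_options and the two enums come from this commit.
    from classy_vision.tasks import ClassificationTask
    from classy_vision.tasks.classification_task import (
        BatchNormSyncMode,
        BroadcastBuffersMode,
    )

    task = ClassificationTask().set_distributed_options(
        broadcast_buffers_mode=BroadcastBuffersMode.DISABLED,
        batch_norm_sync_mode=BatchNormSyncMode.PYTORCH,  # or APEX, if apex is installed
    )

    # Equivalent JSON config keys read by ClassificationTask.from_config (see the diff below):
    #   "broadcast_buffers": "disabled"
    #   "batch_norm_sync_mode": "pytorch"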

File tree

4 files changed: +69 −3 lines changed

classy_train.py
classy_vision/tasks/classification_task.py
test/generic/config_utils.py
test/trainer_distributed_trainer_test.py


classy_train.py

Lines changed: 2 additions & 1 deletion

@@ -50,6 +50,7 @@
 from classy_vision.hooks import (
     CheckpointHook,
     LossLrMeterLoggingHook,
+    ModelComplexityHook,
     ProfilerHook,
     ProgressBarHook,
     TensorboardPlotHook,
@@ -118,7 +119,7 @@ def main(args, config):
 
 
 def configure_hooks(args, config):
-    hooks = [LossLrMeterLoggingHook(args.log_freq)]
+    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]
 
     # Make a folder to store checkpoints and tensorboard logging outputs
     suffix = datetime.now().isoformat()
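
The change above simply appends `ModelComplexityHook` to the default hook list built by `configure_hooks`. As a small, hypothetical sketch of the same idea outside `classy_train.py` (the hook classes are the ones imported above; the helper function name is made up here):

    # Hypothetical helper mirroring configure_hooks above; only the two hook
    # constructors are taken from classy_train.py in this commit.
    from classy_vision.hooks import LossLrMeterLoggingHook, ModelComplexityHook

    def build_default_hooks(log_freq: int = 100):
        # ModelComplexityHook relies on the profiler to report model complexity,
        # which is the profiler + sync-BN code path this commit exercises and fixes.
        return [LossLrMeterLoggingHook(log_freq), ModelComplexityHook()]

    hooks = build_default_hooks(log_freq=100)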

classy_vision/tasks/classification_task.py

Lines changed: 40 additions & 2 deletions

@@ -11,6 +11,7 @@
 from typing import Any, Dict, List, NamedTuple, Optional, Union
 
 import torch
+import torch.nn as nn
 from classy_vision.dataset import ClassyDataset, build_dataset
 from classy_vision.generic.distributed_util import (
     all_reduce_mean,
@@ -53,6 +54,12 @@ class BroadcastBuffersMode(enum.Enum):
     BEFORE_EVAL = enum.auto()
 
 
+class BatchNormSyncMode(enum.Enum):
+    DISABLED = enum.auto()  # No Synchronized Batch Normalization
+    PYTORCH = enum.auto()  # Use torch.nn.SyncBatchNorm
+    APEX = enum.auto()  # Use apex.parallel.SyncBatchNorm, needs apex to be installed
+
+
 class LastBatchInfo(NamedTuple):
     loss: torch.Tensor
     output: torch.Tensor
@@ -204,14 +211,35 @@ def set_meters(self, meters: List["ClassyMeter"]):
         self.meters = meters
         return self
 
-    def set_distributed_options(self, broadcast_buffers_mode: BroadcastBuffersMode):
+    def set_distributed_options(
+        self,
+        broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.DISABLED,
+        batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED,
+    ):
         """Set distributed options.
 
         Args:
            broadcast_buffers_mode: Broadcast buffers mode. See
                :class:`BroadcastBuffersMode` for options.
+           batch_norm_sync_mode: Batch normalization synchronization mode. See
+               :class:`BatchNormSyncMode` for options.
+
+        Raises:
+           RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex
+               is not installed.
         """
         self.broadcast_buffers_mode = broadcast_buffers_mode
+
+        if batch_norm_sync_mode == BatchNormSyncMode.DISABLED:
+            logging.info("Synchronized Batch Normalization is disabled")
+        else:
+            if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available:
+                raise RuntimeError("apex is not installed")
+            logging.info(
+                f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}"
+            )
+        self.batch_norm_sync_mode = batch_norm_sync_mode
+
         return self
 
     def set_hooks(self, hooks: List["ClassyHook"]):
@@ -317,7 +345,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
             .set_meters(meters)
             .set_amp_opt_level(amp_opt_level)
             .set_distributed_options(
-                BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")]
+                broadcast_buffers_mode=BroadcastBuffersMode[
+                    config.get("broadcast_buffers", "disabled").upper()
+                ],
+                batch_norm_sync_mode=BatchNormSyncMode[
+                    config.get("batch_norm_sync_mode", "disabled").upper()
+                ],
             )
         )
         for phase_type in phase_types:
@@ -494,6 +527,11 @@ def prepare(
                 multiprocessing_context=dataloader_mp_context,
             )
 
+        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
+            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
+        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
+            self.base_model = apex.parallel.convert_syncbn_model(self.base_model)
+
         # move the model and loss to the right device
         if use_gpu:
             self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
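
For reference, a self-contained sketch of what the `BatchNormSyncMode.PYTORCH` branch in `prepare()` does: `nn.SyncBatchNorm.convert_sync_batchnorm` walks the module tree and swaps every `BatchNorm*d` layer for `SyncBatchNorm`. The toy model below is made up for illustration, and in real training a distributed process group must be initialized before the converted model runs; the `APEX` branch performs the analogous swap via `apex.parallel.convert_syncbn_model`.

    import torch.nn as nn

    # Hypothetical toy model; any nn.Module containing BatchNorm layers behaves the same.
    model = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )

    # Same call as in prepare() for BatchNormSyncMode.PYTORCH: each BatchNorm*d
    # submodule in the tree is replaced by nn.SyncBatchNorm.
    sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    print(type(sync_model[1]).__name__)  # SyncBatchNorm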

test/generic/config_utils.py

Lines changed: 1 addition & 0 deletions

@@ -228,6 +228,7 @@ def get_test_mlp_task_config():
             "input_dim": 1200,
             "output_dim": 1000,
             "hidden_dims": [10],
+            "use_batchnorm": True,  # used for testing sync batchnorm
         },
         "meters": {"accuracy": {"topk": [1]}},
         "optimizer": {

test/trainer_distributed_trainer_test.py

Lines changed: 26 additions & 0 deletions

@@ -22,10 +22,13 @@ def setUp(self):
         config = get_test_mlp_task_config()
         invalid_config = copy.deepcopy(config)
         invalid_config["name"] = "invalid_task"
+        sync_bn_config = copy.deepcopy(config)
+        sync_bn_config["sync_batch_norm_mode"] = "pytorch"
         self.config_files = {}
         for config_key, config in [
             ("config", config),
             ("invalid_config", invalid_config),
+            ("sync_bn_config", sync_bn_config),
         ]:
             with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
                 json.dump(config, f)
@@ -63,3 +66,26 @@ def test_training(self):
         result = subprocess.run(cmd, shell=True)
         success = result.returncode == 0
         self.assertEqual(success, expected_success)
+
+    @unittest.skipUnless(torch.cuda.is_available(), "This test needs a gpu to run")
+    def test_sync_batch_norm(self):
+        """Test that sync batch norm training doesn't hang."""
+
+        num_processes = 2
+        device = "gpu"
+
+        cmd = f"""{sys.executable} -m torch.distributed.launch \
+            --nnodes=1 \
+            --nproc_per_node={num_processes} \
+            --master_addr=localhost \
+            --master_port=29500 \
+            --use_env \
+            {self.path}/../classy_train.py \
+            --device={device} \
+            --config={self.config_files["sync_bn_config"]} \
+            --num_workers=4 \
+            --log_freq=100 \
+            --distributed_backend=ddp
+        """
+        result = subprocess.run(cmd, shell=True)
+        self.assertEqual(result.returncode, 0)
