Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit bd9888b

Browse files
mannatsinghfacebook-github-bot
authored andcommitted
Add support for Sync BN
Summary: Added support for using sync batch normalization using PyTorch's implementation or Apex's. Plugged in the model complexity hook to `classy_train.py`. It helps test the bug I encountered and fixed which needs the profiler + sync batch norm. Differential Revision: D20307435 fbshipit-source-id: cf93dd50fff06c1d809f97c4267d4af9934564bb
1 parent 636740b commit bd9888b

File tree

3 files changed

+59
-1
lines changed

3 files changed

+59
-1
lines changed

classy_vision/tasks/classification_task.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(self):
134134
self.amp_opt_level = None
135135
self.perf_log = []
136136
self.last_batch = None
137+
self.sync_batch_norm = False
137138

138139
def set_checkpoint(self, checkpoint):
139140
"""Sets checkpoint on task.
@@ -282,6 +283,29 @@ def set_amp_opt_level(self, opt_level: Optional[str]):
282283
logging.info(f"AMP enabled with opt_level {opt_level}")
283284
return self
284285

286+
def set_sync_batch_norm(self, sync_batch_norm: bool) -> "ClassificationTask":
287+
"""Enable / disable sync batch norm.
288+
289+
Args:
290+
sync_batch_norm: Set to True to enable and False otherwise.
291+
Raises:
292+
RuntimeError: If sync_batch_norm is True and apex is not installed.
293+
294+
Warning: apex needs to be installed to utilize this feature.
295+
"""
296+
self.sync_batch_norm = sync_batch_norm
297+
if sync_batch_norm:
298+
"""
299+
if not apex_available:
300+
raise RuntimeError(
301+
"apex is not installed, cannot enable sync_batch_norm"
302+
)
303+
"""
304+
logging.info("Using Synchronized Batch Normalization")
305+
else:
306+
logging.info("Synchronized Batch Normalization is disabled")
307+
return self
308+
285309
@classmethod
286310
def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
287311
"""Instantiates a ClassificationTask from a configuration.
@@ -303,6 +327,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
303327
loss = build_loss(config["loss"])
304328
test_only = config.get("test_only", False)
305329
amp_opt_level = config.get("amp_opt_level")
330+
sync_batch_norm = config.get("sync_batch_norm", False)
306331
meters = build_meters(config.get("meters", {}))
307332
model = build_model(config["model"])
308333
# put model in eval mode in case any hooks modify model states, it'll
@@ -320,6 +345,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
320345
.set_optimizer(optimizer)
321346
.set_meters(meters)
322347
.set_amp_opt_level(amp_opt_level)
348+
.set_sync_batch_norm(sync_batch_norm)
323349
.set_distributed_options(
324350
BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")]
325351
)
@@ -498,6 +524,10 @@ def prepare(
498524
multiprocessing_context=dataloader_mp_context,
499525
)
500526

527+
if self.sync_batch_norm:
528+
# self.base_model = apex.parallel.convert_syncbn_model(self.base_model)
529+
self.base_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
530+
501531
# move the model and loss to the right device
502532
if use_gpu:
503533
self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
@@ -585,6 +615,8 @@ def get_classy_state(self, deep_copy: bool = False):
585615
Args:
586616
deep_copy: If true, does a deep copy of state before returning.
587617
"""
618+
from classy_vision.generic.distributed_util import get_world_size
619+
print("World size", get_world_size())
588620
classy_state_dict = {
589621
"train": self.train,
590622
"base_model": self.base_model.get_classy_state(),

test/generic/config_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def get_test_mlp_task_config():
228228
"input_dim": 1200,
229229
"output_dim": 1000,
230230
"hidden_dims": [10],
231+
"use_batchnorm": True, # used for testing sync batchnorm
231232
},
232233
"meters": {"accuracy": {"topk": [1]}},
233234
"optimizer": {

test/trainer_distributed_trainer_test.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ def setUp(self):
2222
config = get_test_mlp_task_config()
2323
invalid_config = copy.deepcopy(config)
2424
invalid_config["name"] = "invalid_task"
25+
sync_bn_config = copy.deepcopy(config)
26+
sync_bn_config["sync_batch_norm"] = True
2527
self.config_files = {}
2628
for config_key, config in [
2729
("config", config),
2830
("invalid_config", invalid_config),
31+
("sync_bn_config", sync_bn_config),
2932
]:
3033
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
3134
json.dump(config, f)
@@ -37,7 +40,7 @@ def tearDown(self):
3740
for config_file in self.config_files.values():
3841
os.unlink(config_file)
3942

40-
def test_training(self):
43+
def _test_training(self):
4144
"""Checks we can train a small MLP model."""
4245

4346
num_processes = 2
@@ -63,3 +66,25 @@ def test_training(self):
6366
result = subprocess.run(cmd, shell=True)
6467
success = result.returncode == 0
6568
self.assertEqual(success, expected_success)
69+
70+
def test_sync_batch_norm(self):
71+
"""Test that sync batch norm training doesn't hang."""
72+
73+
num_processes = 2
74+
device = "gpu" if torch.cuda.is_available() else "cpu"
75+
76+
cmd = f"""{sys.executable} -m torch.distributed.launch \
77+
--nnodes=1 \
78+
--nproc_per_node={num_processes} \
79+
--master_addr=localhost \
80+
--master_port=29500 \
81+
--use_env \
82+
{self.path}/../classy_train.py \
83+
--device={device} \
84+
--config={self.config_files["sync_bn_config"]} \
85+
--num_workers=4 \
86+
--log_freq=100 \
87+
--distributed_backend=ddp
88+
"""
89+
result = subprocess.run(cmd, shell=True)
90+
self.assertEqual(result.returncode, 0)

0 commit comments

Comments
 (0)