Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Add support for Sync BN #423

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion classy_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from classy_vision.hooks import (
CheckpointHook,
LossLrMeterLoggingHook,
ModelComplexityHook,
ProfilerHook,
ProgressBarHook,
TensorboardPlotHook,
Expand Down Expand Up @@ -118,7 +119,7 @@ def main(args, config):


def configure_hooks(args, config):
hooks = [LossLrMeterLoggingHook(args.log_freq)]
hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]

# Make a folder to store checkpoints and tensorboard logging outputs
suffix = datetime.now().isoformat()
Expand Down
17 changes: 11 additions & 6 deletions classy_vision/generic/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@

import torch
import torch.nn as nn
from classy_vision.generic.util import get_model_dummy_input, is_leaf, is_on_gpu
from classy_vision.generic.util import (
eval_model,
get_model_dummy_input,
is_leaf,
is_on_gpu,
)
from torch.cuda import cudart


Expand All @@ -24,7 +29,6 @@ def profile(
"""
Performs CPU or GPU profiling of the specified model on the specified input.
"""

# assertions:
if use_nvprof:
raise NotImplementedError
Expand All @@ -41,8 +45,8 @@ def profile(
batchsize=batchsize_per_replica,
non_blocking=False,
)
# perform profiling:
with torch.no_grad():
# perform profiling in eval mode
with eval_model(model), torch.no_grad():
model(input) # warm up CUDA memory allocator and profiler
if use_nvprof: # nvprof profiling (TODO: Can we infer this?)
cudart().cudaProfilerStart()
Expand Down Expand Up @@ -376,7 +380,6 @@ def compute_complexity(model, compute_fn, input_shape, input_key=None):
"""
Compute the complexity of a forward pass.
"""

# assertions, input, and upvalue in which we will perform the count:
assert isinstance(model, nn.Module)
if not isinstance(input_shape, abc.Sequence):
Expand All @@ -387,7 +390,9 @@ def compute_complexity(model, compute_fn, input_shape, input_key=None):
# measure FLOPs:
modify_forward(model, compute_list, compute_fn)
try:
model.forward(input)
# compute complexity in eval mode
with eval_model(model), torch.no_grad():
model.forward(input)
except NotImplementedError as err:
raise err
finally:
Expand Down
30 changes: 30 additions & 0 deletions classy_vision/generic/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
import sys
import traceback
from functools import partial
from typing import Dict, Optional

import numpy as np
Expand Down Expand Up @@ -774,3 +775,32 @@ def get_model_dummy_input(
if input_key:
input = {input_key: input}
return input


@contextlib.contextmanager
def _train_mode(model: nn.Module, train_mode: bool):
"""Context manager which sets the train mode of a model. After returning, it
restores the state of every sub-module individually."""
train_modes = {}
for name, module in model.named_modules():
train_modes[name] = module.training
try:
model.train(train_mode)
yield
finally:
for name, module in model.named_modules():
module.training = train_modes[name]


train_model = partial(_train_mode, train_mode=True)
train_model.__doc__ = """Context manager which puts the model in train mode.

After returning, it restores the state of every sub-module individually.
"""


eval_model = partial(_train_mode, train_mode=False)
eval_model.__doc__ = """Context manager which puts the model in eval mode.

After returning, it restores the state of every sub-module individually.
"""
47 changes: 41 additions & 6 deletions classy_vision/tasks/classification_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Any, Dict, List, NamedTuple, Optional, Union

import torch
import torch.nn as nn
from classy_vision.dataset import ClassyDataset, build_dataset
from classy_vision.generic.distributed_util import (
all_reduce_mean,
Expand All @@ -23,7 +24,6 @@
recursive_copy_to_gpu,
update_classy_state,
)
from classy_vision.hooks import ClassyHookFunctions
from classy_vision.losses import ClassyLoss, build_loss
from classy_vision.meters import build_meters
from classy_vision.models import ClassyModel, build_model
Expand Down Expand Up @@ -54,6 +54,12 @@ class BroadcastBuffersMode(enum.Enum):
BEFORE_EVAL = enum.auto()


class BatchNormSyncMode(enum.Enum):
DISABLED = enum.auto() # No Synchronized Batch Normalization
PYTORCH = enum.auto() # Use torch.nn.SyncBatchNorm
APEX = enum.auto() # Use apex.parallel.SyncBatchNorm, needs apex to be installed


class LastBatchInfo(NamedTuple):
loss: torch.Tensor
output: torch.Tensor
Expand Down Expand Up @@ -134,6 +140,7 @@ def __init__(self):
self.amp_opt_level = None
self.perf_log = []
self.last_batch = None
self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED

def set_checkpoint(self, checkpoint):
"""Sets checkpoint on task.
Expand Down Expand Up @@ -205,14 +212,35 @@ def set_meters(self, meters: List["ClassyMeter"]):
self.meters = meters
return self

def set_distributed_options(self, broadcast_buffers_mode: BroadcastBuffersMode):
def set_distributed_options(
self,
broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.DISABLED,
batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED,
):
"""Set distributed options.

Args:
broadcast_buffers_mode: Broadcast buffers mode. See
:class:`BroadcastBuffersMode` for options.
batch_norm_sync_mode: Batch normalization synchronization mode. See
:class:`BatchNormSyncMode` for options.

Raises:
RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex
is not installed.
"""
self.broadcast_buffers_mode = broadcast_buffers_mode

if batch_norm_sync_mode == BatchNormSyncMode.DISABLED:
logging.info("Synchronized Batch Normalization is disabled")
else:
if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available:
raise RuntimeError("apex is not installed")
logging.info(
f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}"
)
self.batch_norm_sync_mode = batch_norm_sync_mode

return self

def set_hooks(self, hooks: List["ClassyHook"]):
Expand Down Expand Up @@ -305,9 +333,6 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
amp_opt_level = config.get("amp_opt_level")
meters = build_meters(config.get("meters", {}))
model = build_model(config["model"])
# put model in eval mode in case any hooks modify model states, it'll
# be reset to train mode before training
model.eval()
optimizer = build_optimizer(optimizer_config)

task = (
Expand All @@ -321,7 +346,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
.set_meters(meters)
.set_amp_opt_level(amp_opt_level)
.set_distributed_options(
BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")]
broadcast_buffers_mode=BroadcastBuffersMode[
config.get("broadcast_buffers", "disabled").upper()
],
batch_norm_sync_mode=BatchNormSyncMode[
config.get("batch_norm_sync_mode", "disabled").upper()
],
)
)
for phase_type in phase_types:
Expand Down Expand Up @@ -498,6 +528,11 @@ def prepare(
multiprocessing_context=dataloader_mp_context,
)

if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
self.base_model = apex.parallel.convert_syncbn_model(self.base_model)

# move the model and loss to the right device
if use_gpu:
self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
Expand Down
5 changes: 3 additions & 2 deletions test/generic/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,9 @@ def get_test_mlp_task_config():
"num_classes": 2,
"crop_size": 20,
"class_ratio": 0.5,
"num_samples": 10,
"num_samples": 20,
"seed": 0,
"batchsize_per_replica": 3,
"batchsize_per_replica": 6,
"use_augmentation": False,
"use_shuffle": True,
"transforms": [
Expand Down Expand Up @@ -228,6 +228,7 @@ def get_test_mlp_task_config():
"input_dim": 1200,
"output_dim": 1000,
"hidden_dims": [10],
"use_batchnorm": True, # used for testing sync batchnorm
},
"meters": {"accuracy": {"topk": [1]}},
"optimizer": {
Expand Down
52 changes: 52 additions & 0 deletions test/generic_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import copy
import shutil
import tempfile
import unittest
Expand All @@ -14,6 +15,7 @@

import classy_vision.generic.util as util
import torch
import torch.nn as nn
from classy_vision.generic.util import (
CHECKPOINT_FILE,
load_checkpoint,
Expand Down Expand Up @@ -368,6 +370,56 @@ def test_get_model_dummy_input(self):
)
self.assertEqual(result.size(), tuple([batchsize] + input_shape))

def _compare_model_train_mode(self, model_1, model_2):
for name_1, module_1 in model_1.named_modules():
found = False
for name_2, module_2 in model_2.named_modules():
if name_1 == name_2:
found = True
if module_1.training != module_2.training:
return False
if not found:
return False
return True

def _check_model_train_mode(self, model, expected_mode):
for module in model.modules():
if module.training != expected_mode:
return False
return True

def test_train_model_eval_model(self):
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(1, 2)
self.dropout = nn.Dropout()
self.seq = nn.Sequential(
nn.ReLU(), nn.Conv2d(1, 2, 3), nn.BatchNorm2d(1, 2)
)

test_model = TestModel()
for train in [True, False]:
test_model.train(train)

# flip some of the modes
test_model.dropout.train(not train)
test_model.seq[1].train(not train)

orig_model = copy.deepcopy(test_model)

with util.train_model(test_model):
self._check_model_train_mode(test_model, True)
# the modes should be different inside the context manager
self.assertFalse(self._compare_model_train_mode(orig_model, test_model))
self.assertTrue(self._compare_model_train_mode(orig_model, test_model))

with util.eval_model(test_model):
self._check_model_train_mode(test_model, False)
# the modes should be different inside the context manager
self.assertFalse(self._compare_model_train_mode(orig_model, test_model))
self.assertTrue(self._compare_model_train_mode(orig_model, test_model))


class TestUpdateStateFunctions(unittest.TestCase):
def _compare_states(self, state_1, state_2, check_heads=True):
Expand Down
2 changes: 1 addition & 1 deletion test/hooks_loss_lr_meter_logging_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def scheduler_mock(where):
mock_lr_scheduler.update_interval = UpdateInterval.STEP
config = get_test_mlp_task_config()
config["num_epochs"] = 3
config["dataset"]["train"]["batchsize_per_replica"] = 5
config["dataset"]["train"]["batchsize_per_replica"] = 10
config["dataset"]["test"]["batchsize_per_replica"] = 5
task = build_task(config)
task.optimizer.param_schedulers["lr"] = mock_lr_scheduler
Expand Down
4 changes: 2 additions & 2 deletions test/manual/hooks_tensorboard_plot_hook_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def flush(self):

config = get_test_mlp_task_config()
config["num_epochs"] = 3
config["dataset"]["train"]["batchsize_per_replica"] = 5
config["dataset"]["train"]["batchsize_per_replica"] = 10
config["dataset"]["test"]["batchsize_per_replica"] = 5
task = build_task(config)

Expand All @@ -152,7 +152,7 @@ def flush(self):
trainer = LocalTrainer()
trainer.train(task)

# We have 10 samples, batch size is 5. Each epoch is done in two steps.
# We have 20 samples, batch size is 10. Each epoch is done in two steps.
self.assertEqual(
writer.scalar_logs["train_learning_rate_updates"],
[0, 1 / 6, 2 / 6, 3 / 6, 4 / 6, 5 / 6],
Expand Down
26 changes: 26 additions & 0 deletions test/trainer_distributed_trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ def setUp(self):
config = get_test_mlp_task_config()
invalid_config = copy.deepcopy(config)
invalid_config["name"] = "invalid_task"
sync_bn_config = copy.deepcopy(config)
sync_bn_config["sync_batch_norm_mode"] = "pytorch"
self.config_files = {}
for config_key, config in [
("config", config),
("invalid_config", invalid_config),
("sync_bn_config", sync_bn_config),
]:
with tempfile.NamedTemporaryFile(mode="w", delete=False) as f:
json.dump(config, f)
Expand Down Expand Up @@ -63,3 +66,26 @@ def test_training(self):
result = subprocess.run(cmd, shell=True)
success = result.returncode == 0
self.assertEqual(success, expected_success)

@unittest.skipUnless(torch.cuda.is_available(), "This test needs a gpu to run")
def test_sync_batch_norm(self):
"""Test that sync batch norm training doesn't hang."""

num_processes = 2
device = "gpu"

cmd = f"""{sys.executable} -m torch.distributed.launch \
--nnodes=1 \
--nproc_per_node={num_processes} \
--master_addr=localhost \
--master_port=29500 \
--use_env \
{self.path}/../classy_train.py \
--device={device} \
--config={self.config_files["sync_bn_config"]} \
--num_workers=4 \
--log_freq=100 \
--distributed_backend=ddp
"""
result = subprocess.run(cmd, shell=True)
self.assertEqual(result.returncode, 0)