|
11 | 11 | from typing import Any, Dict, List, NamedTuple, Optional, Union
|
12 | 12 |
|
13 | 13 | import torch
|
| 14 | +import torch.nn as nn |
14 | 15 | from classy_vision.dataset import ClassyDataset, build_dataset
|
15 | 16 | from classy_vision.generic.distributed_util import (
|
16 | 17 | all_reduce_mean,
|
@@ -53,6 +54,12 @@ class BroadcastBuffersMode(enum.Enum):
|
53 | 54 | BEFORE_EVAL = enum.auto()
|
54 | 55 |
|
55 | 56 |
|
class BatchNormSyncMode(enum.Enum):
    """How (and whether) batch-normalization statistics are synchronized
    across distributed workers during training.

    DISABLED: no synchronized batch normalization.
    PYTORCH:  use ``torch.nn.SyncBatchNorm``.
    APEX:     use ``apex.parallel.SyncBatchNorm``; requires apex to be installed.
    """

    DISABLED = enum.auto()
    PYTORCH = enum.auto()
    APEX = enum.auto()
56 | 63 | class LastBatchInfo(NamedTuple):
|
57 | 64 | loss: torch.Tensor
|
58 | 65 | output: torch.Tensor
|
@@ -133,6 +140,7 @@ def __init__(self):
|
133 | 140 | self.amp_opt_level = None
|
134 | 141 | self.perf_log = []
|
135 | 142 | self.last_batch = None
|
| 143 | + self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED |
136 | 144 |
|
137 | 145 | def set_checkpoint(self, checkpoint):
|
138 | 146 | """Sets checkpoint on task.
|
@@ -204,14 +212,35 @@ def set_meters(self, meters: List["ClassyMeter"]):
|
204 | 212 | self.meters = meters
|
205 | 213 | return self
|
206 | 214 |
|
207 |
| - def set_distributed_options(self, broadcast_buffers_mode: BroadcastBuffersMode): |
| 215 | + def set_distributed_options( |
| 216 | + self, |
| 217 | + broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.DISABLED, |
| 218 | + batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED, |
| 219 | + ): |
208 | 220 | """Set distributed options.
|
209 | 221 |
|
210 | 222 | Args:
|
211 | 223 | broadcast_buffers_mode: Broadcast buffers mode. See
|
212 | 224 | :class:`BroadcastBuffersMode` for options.
|
| 225 | + batch_norm_sync_mode: Batch normalization synchronization mode. See |
| 226 | + :class:`BatchNormSyncMode` for options. |
| 227 | +
|
| 228 | + Raises: |
| 229 | + RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and apex |
| 230 | + is not installed. |
213 | 231 | """
|
214 | 232 | self.broadcast_buffers_mode = broadcast_buffers_mode
|
| 233 | + |
| 234 | + if batch_norm_sync_mode == BatchNormSyncMode.DISABLED: |
| 235 | + logging.info("Synchronized Batch Normalization is disabled") |
| 236 | + else: |
| 237 | + if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available: |
| 238 | + raise RuntimeError("apex is not installed") |
| 239 | + logging.info( |
| 240 | + f"Using Synchronized Batch Normalization using {batch_norm_sync_mode}" |
| 241 | + ) |
| 242 | + self.batch_norm_sync_mode = batch_norm_sync_mode |
| 243 | + |
215 | 244 | return self
|
216 | 245 |
|
217 | 246 | def set_hooks(self, hooks: List["ClassyHook"]):
|
@@ -317,7 +346,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
|
317 | 346 | .set_meters(meters)
|
318 | 347 | .set_amp_opt_level(amp_opt_level)
|
319 | 348 | .set_distributed_options(
|
320 |
| - BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")] |
| 349 | + broadcast_buffers_mode=BroadcastBuffersMode[ |
| 350 | + config.get("broadcast_buffers", "disabled").upper() |
| 351 | + ], |
| 352 | + batch_norm_sync_mode=BatchNormSyncMode[ |
| 353 | + config.get("batch_norm_sync_mode", "disabled").upper() |
| 354 | + ], |
321 | 355 | )
|
322 | 356 | )
|
323 | 357 | for phase_type in phase_types:
|
@@ -494,6 +528,11 @@ def prepare(
|
494 | 528 | multiprocessing_context=dataloader_mp_context,
|
495 | 529 | )
|
496 | 530 |
|
| 531 | + if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH: |
| 532 | + self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model) |
| 533 | + elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX: |
| 534 | + self.base_model = apex.parallel.convert_syncbn_model(self.base_model) |
| 535 | + |
497 | 536 | # move the model and loss to the right device
|
498 | 537 | if use_gpu:
|
499 | 538 | self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
|
|
0 commit comments