@@ -11,6 +11,7 @@
11 | 11 | from typing import Any, Dict, List, NamedTuple, Optional, Union
12 | 12 |
13 | 13 | import torch
   | 14 | +import torch.nn as nn
14 | 15 | from classy_vision.dataset import ClassyDataset, build_dataset
15 | 16 | from classy_vision.generic.distributed_util import (
16 | 17 |     all_reduce_mean,
@@ -53,6 +54,12 @@ class BroadcastBuffersMode(enum.Enum):
53 | 54 |     BEFORE_EVAL = enum.auto()
54 | 55 |
55 | 56 |
   | 57 | +class BatchNormSyncMode(enum.Enum):
   | 58 | +    DISABLED = enum.auto()  # No Synchronized Batch Normalization
   | 59 | +    PYTORCH = enum.auto()  # Use torch.nn.SyncBatchNorm
   | 60 | +    APEX = enum.auto()  # Use apex.parallel.SyncBatchNorm; requires apex to be installed
   | 61 | +
   | 62 | +
56 | 63 | class LastBatchInfo(NamedTuple):
57 | 64 |     loss: torch.Tensor
58 | 65 |     output: torch.Tensor
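Note: `apex_available`, used in the `set_distributed_options` hunk below, is not defined in this diff; a minimal sketch of the guarded import the file presumably carries near its other imports:

```python
# Optional apex dependency; the flag lets set_distributed_options raise a
# clear RuntimeError later instead of failing at import time.
try:
    import apex

    apex_available = True
except ImportError:
    apex_available = False
```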
@@ -204,14 +211,35 @@ def set_meters(self, meters: List["ClassyMeter"]):
204 | 211 |         self.meters = meters
205 | 212 |         return self
206 | 213 |
207 |     | -    def set_distributed_options(self, broadcast_buffers_mode: BroadcastBuffersMode):
    | 214 | +    def set_distributed_options(
    | 215 | +        self,
    | 216 | +        broadcast_buffers_mode: BroadcastBuffersMode = BroadcastBuffersMode.DISABLED,
    | 217 | +        batch_norm_sync_mode: BatchNormSyncMode = BatchNormSyncMode.DISABLED,
    | 218 | +    ):
208 | 219 |         """Set distributed options.
209 | 220 |
210 | 221 |         Args:
211 | 222 |             broadcast_buffers_mode: Broadcast buffers mode. See
212 | 223 |                 :class:`BroadcastBuffersMode` for options.
    | 224 | +            batch_norm_sync_mode: Batch normalization synchronization mode. See
    | 225 | +                :class:`BatchNormSyncMode` for options.
    | 226 | +
    | 227 | +        Raises:
    | 228 | +            RuntimeError: If batch_norm_sync_mode is `BatchNormSyncMode.APEX` and
    | 229 | +                apex is not installed.
213 | 230 |         """
214 | 231 |         self.broadcast_buffers_mode = broadcast_buffers_mode
    | 232 | +
    | 233 | +        if batch_norm_sync_mode == BatchNormSyncMode.DISABLED:
    | 234 | +            logging.info("Synchronized Batch Normalization is disabled")
    | 235 | +        else:
    | 236 | +            if batch_norm_sync_mode == BatchNormSyncMode.APEX and not apex_available:
    | 237 | +                raise RuntimeError("apex is not installed")
    | 238 | +            logging.info(
    | 239 | +                f"Using Synchronized Batch Normalization with {batch_norm_sync_mode}"
    | 240 | +            )
    | 241 | +        self.batch_norm_sync_mode = batch_norm_sync_mode
    | 242 | +
215 | 243 |         return self
216 | 244 |
217 | 245 |     def set_hooks(self, hooks: List["ClassyHook"]):
@@ -317,7 +345,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
317 | 345 |             .set_meters(meters)
318 | 346 |             .set_amp_opt_level(amp_opt_level)
319 | 347 |             .set_distributed_options(
320 |     | -                BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")]
    | 348 | +                broadcast_buffers_mode=BroadcastBuffersMode[
    | 349 | +                    config.get("broadcast_buffers", "disabled").upper()
    | 350 | +                ],
    | 351 | +                batch_norm_sync_mode=BatchNormSyncMode[
    | 352 | +                    config.get("batch_norm_sync_mode", "disabled").upper()
    | 353 | +                ],
321 | 354 |             )
322 | 355 |         )
323 | 356 |         for phase_type in phase_types:
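With this hunk, both options are read from the task config and upper-cased before the enum lookup, so lowercase strings work. A sketch of the relevant config fragment (all other required task keys omitted):

```python
# Hypothetical task config fragment; dataset/model/loss/optimizer keys omitted.
config = {
    # ...
    "broadcast_buffers": "disabled",
    "batch_norm_sync_mode": "pytorch",  # "disabled" | "pytorch" | "apex"
}
```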
@@ -494,6 +527,11 @@ def prepare(
494 | 527 |             multiprocessing_context=dataloader_mp_context,
495 | 528 |         )
496 | 529 |
    | 530 | +        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
    | 531 | +            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
    | 532 | +        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
    | 533 | +            self.base_model = apex.parallel.convert_syncbn_model(self.base_model)
    | 534 | +
497 | 535 |         # move the model and loss to the right device
498 | 536 |         if use_gpu:
499 | 537 |             self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)