@@ -134,6 +134,7 @@ def __init__(self):
         self.amp_opt_level = None
         self.perf_log = []
         self.last_batch = None
+        self.sync_batch_norm = False
 
     def set_checkpoint(self, checkpoint):
         """Sets checkpoint on task.
@@ -282,6 +283,29 @@ def set_amp_opt_level(self, opt_level: Optional[str]):
             logging.info(f"AMP enabled with opt_level {opt_level}")
         return self
 
+    def set_sync_batch_norm(self, sync_batch_norm: bool) -> "ClassificationTask":
+        """Enable / disable sync batch norm.
+
+        Args:
+            sync_batch_norm: Set to True to enable and False otherwise.
+        Raises:
+            RuntimeError: If sync_batch_norm is True and apex is not installed.
+
+        Warning: apex needs to be installed to utilize this feature.
+        """
+        self.sync_batch_norm = sync_batch_norm
+        if sync_batch_norm:
+            """
+            if not apex_available:
+                raise RuntimeError(
+                    "apex is not installed, cannot enable sync_batch_norm"
+                )
+            """
+            logging.info("Using Synchronized Batch Normalization")
+        else:
+            logging.info("Synchronized Batch Normalization is disabled")
+        return self
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
         """Instantiates a ClassificationTask from a configuration.
@@ -303,6 +327,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
         loss = build_loss(config["loss"])
         test_only = config.get("test_only", False)
         amp_opt_level = config.get("amp_opt_level")
+        sync_batch_norm = config.get("sync_batch_norm", False)
         meters = build_meters(config.get("meters", {}))
         model = build_model(config["model"])
         # put model in eval mode in case any hooks modify model states, it'll
@@ -320,6 +345,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
             .set_optimizer(optimizer)
             .set_meters(meters)
             .set_amp_opt_level(amp_opt_level)
+            .set_sync_batch_norm(sync_batch_norm)
             .set_distributed_options(
                 BroadcastBuffersMode[config.get("broadcast_buffers", "DISABLED")]
             )
@@ -498,6 +524,10 @@ def prepare(
                 multiprocessing_context=dataloader_mp_context,
             )
 
+        if self.sync_batch_norm:
+            # self.base_model = apex.parallel.convert_syncbn_model(self.base_model)
+            self.base_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.base_model)
+
         # move the model and loss to the right device
         if use_gpu:
             self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
@@ -585,6 +615,8 @@ def get_classy_state(self, deep_copy: bool = False):
         Args:
             deep_copy: If true, does a deep copy of state before returning.
         """
+        from classy_vision.generic.distributed_util import get_world_size
+        print("World size", get_world_size())
         classy_state_dict = {
             "train": self.train,
             "base_model": self.base_model.get_classy_state(),