@@ -53,14 +53,19 @@ def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=Non
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
-       self._prev_grad_norm = None
        self._wrapped_criterion = None
        self._wrapped_model = None

-       # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
-       # It is less flexible and syncs only the default stats.
-       self._all_reduce_list = [0.0] * 6
-       self.fast_stat_sync = args.fast_stat_sync
+       if self.cuda and args.distributed_world_size > 1:
+           self._grad_norm_buf = torch.cuda.DoubleTensor(args.distributed_world_size)
+       else:
+           self._grad_norm_buf = None
+
+       if args.fast_stat_sync:
+           utils.deprecation_warning(
+               '--fast-stat-sync is deprecated. If needed, please update your '
+               'Criterion to define the logging_outputs_can_be_summed() method.'
+           )


        self.init_meters(args)
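Note: the deprecation message above points users at the criterion-level hook that the new _aggregate_logging_outputs path (added further down) checks. A minimal sketch of how a custom criterion might opt into the fast scalar sync is shown below; the class name MyCriterion is hypothetical and forward()/aggregate_logging_outputs() are elided.

from fairseq.criterions import FairseqCriterion

class MyCriterion(FairseqCriterion):

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        # True tells the trainer that every value in this criterion's
        # logging outputs is a scalar that can simply be summed across
        # workers, so the fast all_reduce path may be used instead of
        # all_gather_list.
        return True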
@@ -294,7 +299,7 @@ def train_step(self, samples, dummy_batch=False, raise_oom=False):
        self.meters["train_wall"].start()

        # forward and backward pass
-       logging_outputs, sample_sizes, ooms = [], [], 0
+       logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:
@@ -323,22 +328,13 @@ def maybe_no_sync():
            try:
                with maybe_no_sync():
                    # forward and backward
-                   loss, sample_size, logging_output = self.task.train_step(
+                   loss, sample_size_i, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer, ignore_grad
                    )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
-                   sample_sizes.append(sample_size)
-
-                   if self.fast_stat_sync:
-                       self._all_reduce_list[0] += sample_size
-                       self._all_reduce_list[1] += logging_output.get(
-                           "nsentences", 0.0
-                       )
-                       self._all_reduce_list[2] += logging_output.get("loss", 0.0)
-                       self._all_reduce_list[3] += logging_output.get("nll_loss", 0.0)
-                       self._all_reduce_list[4] += logging_output.get("ntokens", 0.0)
+                   sample_size += sample_size_i
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
@@ -353,71 +349,28 @@ def maybe_no_sync():
                else:
                    raise e

-           if self.fast_stat_sync:
-               self._all_reduce_list[5] += ooms
-
        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
-       if self.fast_stat_sync:
-           # rework all_gather_list
-           all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
-           if self._sync_stats():
-               torch.distributed.all_reduce(all_reduce_list_tensor)
-               # Normalize loss and nll_loss by "sample_size"
-               # and convert to log base 2
-               all_reduce_list_tensor[2:4].div_(
-                   (all_reduce_list_tensor[0:1] * torch.log(torch.cuda.DoubleTensor([2])))
+       if self._sync_stats():
+           logging_outputs, sample_size, ooms = self._aggregate_logging_outputs(
+               logging_outputs, sample_size, ooms,
            )
-           self._all_reduce_list = all_reduce_list_tensor.tolist()
-           logging_output = {}
-           [
-               sample_size,
-               logging_output["nsentences"],
-               logging_output["loss"],
-               logging_output["nll_loss"],
-               logging_output["ntokens"],
-               ooms,
-           ] = self._all_reduce_list
-       elif self._sync_stats():
-           logging_outputs, sample_sizes, ooms, prev_norms = zip(
-               *distributed_utils.all_gather_list(
-                   [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
-                   max_size=getattr(self.args, 'all_gather_list_size', 16384),
-               )
-           )
-           logging_outputs = list(chain.from_iterable(logging_outputs))
-           sample_sizes = list(chain.from_iterable(sample_sizes))
-           ooms = sum(ooms)
-
-           if not self.args.use_bmuf:
-               norms = [norm for norm in prev_norms if norm is not None]
-               if not (
-                   all(norm == norms[0] for norm in norms)
-                   or all(math.isnan(norm) or math.isinf(norm) for norm in norms)
-               ):
-                   raise RuntimeError(
-                       "Fatal error: gradients are inconsistent between workers. "
-                       "Try --ddp-backend=no_c10d, which is a more robust but "
-                       "slightly slower DDP implementation."
-                   )

        self.meters["oom"].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print("| WARNING: OOM in all workers, skipping update")
            self.zero_grad()
            return None

-       if not self.fast_stat_sync:
-           # aggregate logging outputs and sample sizes
-           logging_output = self.task.aggregate_logging_outputs(
-               logging_outputs, self.get_criterion()
-           )
-           sample_size = sum(sample_sizes)
+       # aggregate logging outputs and sample sizes
+       logging_output = self.task.aggregate_logging_outputs(
+           logging_outputs, self.get_criterion()
+       )

        if not all(k in logging_output for k in ["ntokens", "nsentences"]):
            raise Exception(
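For illustration, the key-wise summation that _fast_stat_sync_sum (defined below) performs locally before the cross-worker all_reduce can be sketched in plain Python; the keys and values here are made up:

# Per-batch logging outputs collected in the inner loop (made-up numbers).
logging_outputs = [
    {"loss": 2.3, "nll_loss": 2.1, "ntokens": 512, "nsentences": 16},
    {"loss": 2.5, "nll_loss": 2.2, "ntokens": 480, "nsentences": 16},
]
extra_stats = (992, 0)  # e.g. sample_size and ooms, passed positionally

# Flatten everything into one list of scalars: extra stats first, then the
# per-key sums in a deterministic (sorted-key) order.
sorted_keys = sorted(logging_outputs[0].keys())
stats = list(extra_stats) + [
    sum(log.get(k, 0) for log in logging_outputs) for k in sorted_keys
]

# In the trainer this list becomes a torch.cuda.DoubleTensor and is
# all_reduce'd across workers; only the local flattening is shown here.
print(dict(zip(["sample_size", "ooms"] + sorted_keys, stats)))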
@@ -442,7 +395,9 @@ def maybe_no_sync():

        # clip grads
        grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
-       self._prev_grad_norm = grad_norm
+
+       # check that grad norms are consistent across workers
+       self._check_grad_norms(grad_norm)

        # take an optimization step
        self.optimizer.step()
@@ -679,3 +634,56 @@ def _log_oom(self, exc):
        for device_idx in range(torch.cuda.device_count()):
            print(torch.cuda.memory_summary(device=device_idx), file=sys.stderr)
        sys.stderr.flush()
+
+   def _aggregate_logging_outputs(self, logging_outputs, *extra_stats_to_sum):
+       if self.get_criterion().__class__.logging_outputs_can_be_summed():
+           return self._fast_stat_sync_sum(logging_outputs, *extra_stats_to_sum)
+       else:
+           return self._all_gather_list_sync(logging_outputs, *extra_stats_to_sum)
+
+   def _all_gather_list_sync(self, logging_outputs, *extra_stats_to_sum):
+       """
+       Sync logging outputs across workers. all_gather_list_sync is
+       suitable when logging outputs are complex types.
+       """
+       results = list(zip(
+           *distributed_utils.all_gather_list(
+               [logging_outputs] + list(extra_stats_to_sum),
+               max_size=getattr(self.args, 'all_gather_list_size', 16384),
+           )
+       ))
+       logging_outputs, extra_stats_to_sum = results[0], results[1:]
+       logging_outputs = list(chain.from_iterable(logging_outputs))
+       extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
+       return [logging_outputs] + extra_stats_to_sum
+
+   def _fast_stat_sync_sum(self, logging_outputs, *extra_stats_to_sum):
+       """
+       Sync logging outputs across workers. fast_stat_sync_sum is
+       faster than all_gather_list_sync, but is only suitable when
+       logging outputs are scalars and can be summed.
+       """
+       sorted_keys = sorted(logging_outputs[0].keys())
+       num_extra = len(extra_stats_to_sum)
+       stats = list(extra_stats_to_sum) + [
+           sum(log.get(k, 0) for log in logging_outputs)
+           for k in sorted_keys
+       ]
+       buf = torch.cuda.DoubleTensor(stats)
+       distributed_utils.all_reduce(buf)
+       buf = buf.tolist()
+       extra_stats_to_sum, stats = buf[:num_extra], buf[num_extra:]
+       stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
+       return [stats] + extra_stats_to_sum
+
+   def _check_grad_norms(self, grad_norm):
+       """Check that grad norms are consistent across workers."""
+       if self._grad_norm_buf is not None:
+           self._grad_norm_buf.zero_()
+           self._grad_norm_buf[self.args.distributed_rank] = grad_norm
+           distributed_utils.all_reduce(self._grad_norm_buf)
+           if not (self._grad_norm_buf == self._grad_norm_buf[0]).all():
+               raise RuntimeError(
+                   "Fatal error: gradients are inconsistent between workers. "
+                   "Try --ddp-backend=no_c10d."
+               )
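As a rough, standalone illustration of what _check_grad_norms guards against, the consistency test can be sketched without a distributed backend; the per-worker norms below are simulated with a plain CPU tensor rather than the trainer's CUDA buffer:

import torch

def check_grad_norms(grad_norm_buf):
    # After each worker writes its own norm into its rank's slot and the
    # buffer is all_reduce'd, every entry should hold the same value; a
    # mismatch means the workers' gradients have diverged.
    if not (grad_norm_buf == grad_norm_buf[0]).all():
        raise RuntimeError("gradients are inconsistent between workers")

check_grad_norms(torch.DoubleTensor([5.125, 5.125, 5.125, 5.125]))   # passes
# check_grad_norms(torch.DoubleTensor([5.125, 5.125, 4.75, 5.125]))  # would raise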