
Commit fe6c2ed

Myle Ott authored and facebook-github-bot committed
Deprecate --fast-stat-sync and replace with Criterion.logging_outputs_can_be_summed
Summary: Pull Request resolved: fairinternal/fairseq-py#980
Differential Revision: D19351116
Pulled By: myleott
fbshipit-source-id: a67b10637f53a80c37b0ce90eb27ced9709871db
1 parent c9a7c06 commit fe6c2ed

12 files changed (+167, -70 lines)
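
In short, the per-trainer --fast-stat-sync flag is replaced by a per-criterion hook: a criterion whose logging outputs are all summable scalars declares this by returning True from logging_outputs_can_be_summed(), and the trainer then all-reduces summed stats instead of calling all_gather_list. A minimal sketch of a user-defined criterion opting in follows; only the logging_outputs_can_be_summed() hook comes from this commit, while the class name, registration key, and loss computation are illustrative placeholders.

# A minimal sketch (hypothetical criterion) of the new opt-in.
import math

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('example_summable_loss')
class ExampleSummableLossCriterion(FairseqCriterion):

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample['net_input'])
        loss = net_output[0].float().sum()  # placeholder loss for the sketch
        sample_size = sample['ntokens']
        logging_output = {
            'loss': loss.item() if reduce else loss.detach(),
            'ntokens': sample['ntokens'],
            'nsentences': sample['target'].size(0),
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        return {
            'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.,
            'ntokens': sum(log.get('ntokens', 0) for log in logging_outputs),
            'nsentences': sum(log.get('nsentences', 0) for log in logging_outputs),
            'sample_size': sample_size,
        }

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        # Every value in logging_output above is a scalar, so summing the
        # dicts across workers before aggregate_logging_outputs() gives the
        # same statistics and lets the trainer take the all-reduce fast path.
        return True

Criteria that cannot satisfy this (for example, logging outputs containing lists or other structured objects) leave the hook alone and inherit the False default added to FairseqCriterion below.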

fairseq/criterions/adaptive_loss.py

Lines changed: 9 additions & 0 deletions
@@ -90,3 +90,12 @@ def aggregate_logging_outputs(logging_outputs):
        if sample_size != ntokens:
            agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.
        return agg_output
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True

fairseq/criterions/cross_entropy.py

Lines changed: 9 additions & 0 deletions
@@ -65,3 +65,12 @@ def aggregate_logging_outputs(logging_outputs):
        if sample_size != ntokens:
            agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
        return agg_output
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True

fairseq/criterions/fairseq_criterion.py

Lines changed: 9 additions & 0 deletions
@@ -37,3 +37,12 @@ def forward(self, model, sample, reduce=True):
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        raise NotImplementedError
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return False
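
Note that the base class defaults to False, so existing third-party criteria keep the general (slower) all_gather_list path unless they explicitly opt in. A tiny standalone check of that default, using a hypothetical stand-in subclass:

# Hypothetical subclass used only to illustrate the inherited default.
from fairseq.criterions.fairseq_criterion import FairseqCriterion


class UnmodifiedCriterion(FairseqCriterion):
    """A criterion that does not override the new hook."""
    pass


# The inherited default is False, so the trainer's _aggregate_logging_outputs
# (added in fairseq/trainer.py below) routes this criterion through
# _all_gather_list_sync, which can ship arbitrary Python objects, rather than
# the scalar-only _fast_stat_sync_sum.
assert UnmodifiedCriterion.logging_outputs_can_be_summed() is False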

fairseq/criterions/label_smoothed_cross_entropy.py

Lines changed: 9 additions & 0 deletions
@@ -88,3 +88,12 @@ def aggregate_logging_outputs(logging_outputs):
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True

fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py

Lines changed: 9 additions & 0 deletions
@@ -88,3 +88,12 @@ def aggregate_logging_outputs(logging_outputs):
            'nsentences': nsentences,
            'sample_size': sample_size,
        }
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True

fairseq/criterions/legacy_masked_lm.py

Lines changed: 9 additions & 0 deletions
@@ -145,3 +145,12 @@ def aggregate_logging_outputs(logging_outputs):
            'sample_size': sample_size,
        }
        return agg_output
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True

fairseq/criterions/masked_lm.py

Lines changed: 9 additions & 0 deletions
@@ -79,3 +79,12 @@ def aggregate_logging_outputs(logging_outputs):
            'sample_size': sample_size,
        }
        return agg_output
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True

fairseq/criterions/nat_loss.py

Lines changed: 9 additions & 0 deletions
@@ -167,3 +167,12 @@ def aggregate_logging_outputs(logging_outputs):
                )

        return results
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True

fairseq/criterions/sentence_prediction.py

Lines changed: 9 additions & 0 deletions
@@ -96,3 +96,12 @@ def aggregate_logging_outputs(logging_outputs):
        if sample_size != ntokens:
            agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
        return agg_output
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True

fairseq/criterions/sentence_ranking.py

Lines changed: 9 additions & 0 deletions
@@ -115,3 +115,12 @@ def aggregate_logging_outputs(logging_outputs):
        if sample_size != ntokens:
            agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
        return agg_output
+
+    @staticmethod
+    def logging_outputs_can_be_summed() -> bool:
+        """
+        Whether the logging outputs returned by `forward` can be summed
+        across workers prior to calling `aggregate_logging_outputs`.
+        Setting this to True will improve distributed training speed.
+        """
+        return True

fairseq/options.py

Lines changed: 1 addition & 2 deletions
@@ -345,8 +345,7 @@ def add_distributed_training_args(parser):
                       help='disable unused parameter detection (not applicable to '
                            'no_c10d ddp-backend')
    group.add_argument('--fast-stat-sync', default=False, action='store_true',
-                       help='Enable fast sync of stats between nodes, this hardcodes to '
-                            'sync only some default stats from logging_output.')
+                       help='[deprecated] this is now defined per Criterion')
    # fmt: on
    return group

fairseq/trainer.py

Lines changed: 76 additions & 68 deletions
@@ -53,14 +53,19 @@ def __init__(self, args, task, model, criterion, dummy_batch=None, oom_batch=None):
        self._num_updates = 0
        self._optim_history = None
        self._optimizer = None
-        self._prev_grad_norm = None
        self._wrapped_criterion = None
        self._wrapped_model = None

-        # Fast stats sync avoids memcpy and is 7% faster when tested on 16 nodes.
-        # It is less flexible and syncs only the default stats.
-        self._all_reduce_list = [0.0] * 6
-        self.fast_stat_sync = args.fast_stat_sync
+        if self.cuda and args.distributed_world_size > 1:
+            self._grad_norm_buf = torch.cuda.DoubleTensor(args.distributed_world_size)
+        else:
+            self._grad_norm_buf = None
+
+        if args.fast_stat_sync:
+            utils.deprecation_warning(
+                '--fast-stat-sync is deprecated. If needed, please update your '
+                'Criterion to define the logging_outputs_can_be_summed() method.'
+            )

        self.init_meters(args)

@@ -294,7 +299,7 @@ def train_step(self, samples, dummy_batch=False, raise_oom=False):
        self.meters["train_wall"].start()

        # forward and backward pass
-        logging_outputs, sample_sizes, ooms = [], [], 0
+        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample = self._prepare_sample(sample)
            if sample is None:

@@ -323,22 +328,13 @@ def maybe_no_sync():
            try:
                with maybe_no_sync():
                    # forward and backward
-                    loss, sample_size, logging_output = self.task.train_step(
+                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample, self.model, self.criterion, self.optimizer, ignore_grad
                    )

                if not ignore_grad:
                    logging_outputs.append(logging_output)
-                    sample_sizes.append(sample_size)
-
-                    if self.fast_stat_sync:
-                        self._all_reduce_list[0] += sample_size
-                        self._all_reduce_list[1] += logging_output.get(
-                            "nsentences", 0.0
-                        )
-                        self._all_reduce_list[2] += logging_output.get("loss", 0.0)
-                        self._all_reduce_list[3] += logging_output.get("nll_loss", 0.0)
-                        self._all_reduce_list[4] += logging_output.get("ntokens", 0.0)
+                    sample_size += sample_size_i
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)

@@ -353,71 +349,28 @@ def maybe_no_sync():
                else:
                    raise e

-        if self.fast_stat_sync:
-            self._all_reduce_list[5] += ooms
-
        if ooms > 0 and self._oom_batch is not None:
            self.handle_ooms(ooms)

        if dummy_batch:
            return None

        # gather logging outputs from all replicas
-        if self.fast_stat_sync:
-            # rework all_gather_list
-            all_reduce_list_tensor = torch.cuda.DoubleTensor(self._all_reduce_list)
-            if self._sync_stats():
-                torch.distributed.all_reduce(all_reduce_list_tensor)
-                # Normalize loss and nll_loss by "sample_size"
-                # and convert to log base 2
-                all_reduce_list_tensor[2:4].div_(
-                    (all_reduce_list_tensor[0:1] * torch.log(torch.cuda.DoubleTensor([2])))
+        if self._sync_stats():
+            logging_outputs, sample_size, ooms = self._aggregate_logging_outputs(
+                logging_outputs, sample_size, ooms,
            )
-            self._all_reduce_list = all_reduce_list_tensor.tolist()
-            logging_output = {}
-            [
-                sample_size,
-                logging_output["nsentences"],
-                logging_output["loss"],
-                logging_output["nll_loss"],
-                logging_output["ntokens"],
-                ooms,
-            ] = self._all_reduce_list
-        elif self._sync_stats():
-            logging_outputs, sample_sizes, ooms, prev_norms = zip(
-                *distributed_utils.all_gather_list(
-                    [logging_outputs, sample_sizes, ooms, self._prev_grad_norm],
-                    max_size=getattr(self.args, 'all_gather_list_size', 16384),
-                )
-            )
-            logging_outputs = list(chain.from_iterable(logging_outputs))
-            sample_sizes = list(chain.from_iterable(sample_sizes))
-            ooms = sum(ooms)
-
-            if not self.args.use_bmuf:
-                norms = [norm for norm in prev_norms if norm is not None]
-                if not (
-                    all(norm == norms[0] for norm in norms)
-                    or all(math.isnan(norm) or math.isinf(norm) for norm in norms)
-                ):
-                    raise RuntimeError(
-                        "Fatal error: gradients are inconsistent between workers. "
-                        "Try --ddp-backend=no_c10d, which is a more robust but "
-                        "slightly slower DDP implementation."
-                    )

        self.meters["oom"].update(ooms, len(samples))
        if ooms == self.args.distributed_world_size * len(samples):
            print("| WARNING: OOM in all workers, skipping update")
            self.zero_grad()
            return None

-        if not self.fast_stat_sync:
-            # aggregate logging outputs and sample sizes
-            logging_output = self.task.aggregate_logging_outputs(
-                logging_outputs, self.get_criterion()
-            )
-            sample_size = sum(sample_sizes)
+        # aggregate logging outputs and sample sizes
+        logging_output = self.task.aggregate_logging_outputs(
+            logging_outputs, self.get_criterion()
+        )

        if not all(k in logging_output for k in ["ntokens", "nsentences"]):
            raise Exception(

@@ -442,7 +395,9 @@ def maybe_no_sync():
            # clip grads
            grad_norm = self.optimizer.clip_grad_norm(self.args.clip_norm)
-            self._prev_grad_norm = grad_norm
+
+            # check that grad norms are consistent across workers
+            self._check_grad_norms(grad_norm)

            # take an optimization step
            self.optimizer.step()

@@ -679,3 +634,56 @@ def _log_oom(self, exc):
            for device_idx in range(torch.cuda.device_count()):
                print(torch.cuda.memory_summary(device=device_idx), file=sys.stderr)
        sys.stderr.flush()
+
+    def _aggregate_logging_outputs(self, logging_outputs, *extra_stats_to_sum):
+        if self.get_criterion().__class__.logging_outputs_can_be_summed():
+            return self._fast_stat_sync_sum(logging_outputs, *extra_stats_to_sum)
+        else:
+            return self._all_gather_list_sync(logging_outputs, *extra_stats_to_sum)
+
+    def _all_gather_list_sync(self, logging_outputs, *extra_stats_to_sum):
+        """
+        Sync logging outputs across workers. all_gather_list_sync is
+        suitable when logging outputs are complex types.
+        """
+        results = list(zip(
+            *distributed_utils.all_gather_list(
+                [logging_outputs] + list(extra_stats_to_sum),
+                max_size=getattr(self.args, 'all_gather_list_size', 16384),
+            )
+        ))
+        logging_outputs, extra_stats_to_sum = results[0], results[1:]
+        logging_outputs = list(chain.from_iterable(logging_outputs))
+        extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum]
+        return [logging_outputs] + extra_stats_to_sum
+
+    def _fast_stat_sync_sum(self, logging_outputs, *extra_stats_to_sum):
+        """
+        Sync logging outputs across workers. fast_stat_sync_sum is
+        faster than all_gather_list_sync, but is only suitable when
+        logging outputs are scalars and can be summed.
+        """
+        sorted_keys = sorted(logging_outputs[0].keys())
+        num_extra = len(extra_stats_to_sum)
+        stats = list(extra_stats_to_sum) + [
+            sum(log.get(k, 0) for log in logging_outputs)
+            for k in sorted_keys
+        ]
+        buf = torch.cuda.DoubleTensor(stats)
+        distributed_utils.all_reduce(buf)
+        buf = buf.tolist()
+        extra_stats_to_sum, stats = buf[:num_extra], buf[num_extra:]
+        stats = [{k: stats[i] for i, k in enumerate(sorted_keys)}]
+        return [stats] + extra_stats_to_sum
+
+    def _check_grad_norms(self, grad_norm):
+        """Check that grad norms are consistent across workers."""
+        if self._grad_norm_buf is not None:
+            self._grad_norm_buf.zero_()
+            self._grad_norm_buf[self.args.distributed_rank] = grad_norm
+            distributed_utils.all_reduce(self._grad_norm_buf)
+            if not (self._grad_norm_buf == self._grad_norm_buf[0]).all():
+                raise RuntimeError(
+                    "Fatal error: gradients are inconsistent between workers. "
+                    "Try --ddp-backend=no_c10d."
+                )
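
To make the contract behind logging_outputs_can_be_summed() concrete, here is a small self-contained sketch (not part of the commit; the worker outputs and the aggregate() helper are made up) showing that for scalar logging outputs, pre-summing per key on each worker and adding the sums across workers, which is what _fast_stat_sync_sum all-reduces, yields the same statistics as gathering every logging output with all_gather_list and aggregating once.

import math

# hypothetical logging outputs from two workers (one or more batches each)
worker_outputs = [
    [{'loss': 4.0, 'ntokens': 10, 'sample_size': 10}],             # worker 0
    [{'loss': 6.0, 'ntokens': 20, 'sample_size': 20},
     {'loss': 2.0, 'ntokens': 5, 'sample_size': 5}],               # worker 1
]


def aggregate(logging_outputs):
    """Stand-in for a typical Criterion.aggregate_logging_outputs()."""
    loss_sum = sum(log['loss'] for log in logging_outputs)
    sample_size = sum(log['sample_size'] for log in logging_outputs)
    return {
        'loss': loss_sum / sample_size / math.log(2),
        'sample_size': sample_size,
    }


def presum(logs):
    """Per-key scalar sum, mirroring what _fast_stat_sync_sum all-reduces."""
    keys = sorted(logs[0].keys())
    return {k: sum(log.get(k, 0) for log in logs) for k in keys}


# general path: gather every logging output, then aggregate (all_gather_list)
slow = aggregate([log for worker in worker_outputs for log in worker])

# fast path: each worker pre-sums locally, the per-worker sums are added
# together (the all_reduce), and aggregation sees one summed dict per update
fast = aggregate([presum([presum(worker) for worker in worker_outputs])])

assert math.isclose(slow['loss'], fast['loss'])
assert slow['sample_size'] == fast['sample_size']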
