
Commit 46fa38d

dstaay-fb authored and facebook-github-bot committed

Overlap comms on backward pass

Summary: Resolves issues around CUDA streams / NCCL deadlocks with autograd. Basically, create separate streams per pipelined embedding arch.

Differential Revision: D58220332

1 parent 0cfae1f commit 46fa38d
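The gist of the change, as a minimal sketch: instead of one shared stream, each pipelined embedding module gets its own CUDA stream, and the embedding-side backward (which kicks off the gradient all-to-all) is issued on that stream so the communication can overlap with the dense optimizer work on the default stream. The helper name and tensor lists below are illustrative, not the TorchRec API:

```python
# Minimal sketch of the overlap pattern this commit adopts, with hypothetical
# names (backward_on_module_streams, emb_tensors, detached) -- not the actual
# TorchRec API. Each pipelined embedding module gets its own CUDA stream; its
# backward (which triggers the gradient all-to-all) is launched there so the
# communication can overlap with dense work on the default stream.
import torch
from typing import List

def backward_on_module_streams(
    emb_tensors: List[torch.Tensor],   # module outputs, still attached to the embedding graph
    detached: List[torch.Tensor],      # detached copies the dense model consumed; .grad is already filled
    streams: List[torch.cuda.Stream],  # one stream per pipelined embedding module
) -> None:
    default_stream = torch.cuda.current_stream()
    for stream, emb, det in zip(streams, emb_tensors, detached):
        with torch.cuda.stream(stream):
            stream.wait_stream(default_stream)        # do not start before the dense backward finished
            torch.autograd.backward(emb, det.grad)    # embedding-side backward + comms on this stream
```

With one stream per module, one module's gradient all-to-all no longer serializes behind another's, and `wait_stream` keeps ordering safe relative to the dense backward that produced the gradients.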

File tree: 2 files changed (+103, -43 lines)


torchrec/distributed/train_pipeline/train_pipelines.py (72 additions, 38 deletions)

@@ -554,15 +554,8 @@ def __init__(
         self._stash_gradients = stash_gradients
 
         # use two data streams to support two concurrent batches
-        self._embedding_odd_stream: Optional[torch.cuda.streams.Stream] = (
-            (torch.cuda.Stream(priority=0)) if device.type == "cuda" else None
-        )
-        self._embedding_even_stream: Optional[torch.cuda.streams.Stream] = (
-            (torch.cuda.Stream(priority=0)) if device.type == "cuda" else None
-        )
-        self._overarch_stream: Optional[torch.cuda.streams.Stream] = (
-            (torch.cuda.Stream(priority=-1)) if device.type == "cuda" else None
-        )
+        self._embedding_odd_streams: List[Optional[torch.cuda.streams.Stream]] = []
+        self._embedding_even_streams: List[Optional[torch.cuda.streams.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}
 
     def _grad_swap(self) -> None:

@@ -572,6 +565,25 @@ def _grad_swap(self) -> None:
             self._gradients[name] = param.grad.clone()
             param.grad = grad
 
+    def _init_embedding_streams(self) -> None:
+
+        for _ in self._pipelined_modules:
+            self._embedding_odd_streams.append(
+                torch.cuda.Stream(priority=0) if self._device.type == "cuda" else None
+            )
+            self._embedding_even_streams.append(
+                torch.cuda.Stream(priority=0) if self._device.type == "cuda" else None
+            )
+
+    def _validate_optimizer(self) -> None:
+        for pipelined_module in self._pipelined_modules:
+            pipelined_params = set(pipelined_module.parameters())
+            for group in self._optimizer.param_groups:
+                if not set(group["params"]).isdisjoint(pipelined_params):
+                    logger.warning(
+                        f"SemiSync pipelined {type(pipelined_module)} and MLP optimizer share parameters"
+                    )
+
     def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         # pipeline is already filled
         if len(self.batches) >= 3:

@@ -591,7 +603,9 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
                 # pyre-ignore [6]
                 EmbeddingPipelinedForward,
             )
+            self._init_embedding_streams()
             self.wait_sparse_data_dist(self.contexts[0])
+            self._validate_optimizer()
             # pyre-ignore [6]
             self.start_embedding_lookup(self.batches[0], self.contexts[0])

@@ -645,25 +659,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             self.wait_sparse_data_dist(self.contexts[2])
 
         if self._model.training:
-            # backward would put an implicit sync point in stream called from, ideally
-            # this would different from optimizer so it could start earilier, but currently not safe to do so.
-            with torch.cuda.stream(self._overarch_stream):
-                with record_function(f"## backward {self.contexts[0].index} ##"):
-                    torch.sum(losses, dim=0).backward()
-
-                with record_function(
-                    f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    if self.is_semi_sync() and self._stash_gradients:
-                        self._grad_swap()
-                    self._mlp_optimizer_step()
+            with record_function(f"## backward {self.contexts[0].index} ##"):
+                torch.sum(losses, dim=0).backward()
+                # pyre-ignore [6]
+                self.embedding_backward(self.contexts[0])
 
-                with record_function(
-                    f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    self._optimizer.zero_grad()
+            with record_function(
+                f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                if self.is_semi_sync() and self._stash_gradients:
+                    self._grad_swap()
+                self._mlp_optimizer_step()
+
+            with record_function(
+                f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                self._optimizer.zero_grad()
 
         if len(self.batches) >= 2 and not self.is_semi_sync():
+            torch.cuda.synchronize()  # needed to avoid race condition
             # pyre-ignore [6]
             self.start_embedding_lookup(self.batches[1], self.contexts[1])

@@ -674,10 +688,26 @@ def _mlp_forward(
         self, batch: In, context: TrainPipelineContext
     ) -> Tuple[torch.Tensor, Out]:
         with record_function(f"## forward {context.index} ##"):
-            with torch.cuda.stream(self._overarch_stream):
-                _wait_for_event(batch, context.event)
-                context.event = None
-                return cast(Tuple[torch.Tensor, Out], self._model(batch))
+            _wait_for_event(batch, context.event)
+            context.event = None
+            return cast(Tuple[torch.Tensor, Out], self._model(batch))
+
+    def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
+        default_stream = torch.cuda.current_stream()
+        streams = (
+            self._embedding_even_streams
+            if cast(int, context.index) % 2 == 0
+            else self._embedding_odd_streams
+        )
+        for stream, emb_tensor, grad_tensor in zip(
+            streams,
+            context.embedding_tensors,
+            context.detached_embedding_tensors,
+        ):
+            with torch.cuda.stream(stream):
+                # pyre-ignore
+                stream.wait_stream(default_stream)
+                torch.autograd.backward(emb_tensor, grad_tensor.grad)
 
     def copy_batch_to_gpu(
         self,

@@ -722,15 +752,19 @@ def start_embedding_lookup(
         if batch is None:
             return
         with record_function(f"## start_embedding_lookup {context.index} ##"):
-            with torch.cuda.stream(
-                self._embedding_even_stream
-                if cast(int, context.index) % 2 == 0
-                else self._embedding_odd_stream
-            ):
-                _wait_for_event(batch, context.event)
-                _start_embedding_lookup(self._pipelined_modules, batch, context)
-                context.event = torch.cuda.Event()
-                context.event.record()
+            _wait_for_event(batch, context.event)
+            context.event = []
+            for i, module in enumerate(self._pipelined_modules):
+                with torch.cuda.stream(
+                    self._embedding_even_streams[i]
+                    if cast(int, context.index) % 2 == 0
+                    else self._embedding_odd_streams[i]
+                ):
+                    _start_embedding_lookup([module], batch, context)
+                    event = torch.cuda.Event()
+                    event.record()
+                    # pyre-ignore [16]
+                    context.event.append(event)
 
 
 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
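For context on the lookup side of this diff, here is a small, self-contained sketch (hypothetical helper names, assumes a CUDA device) of the per-module stream pools and per-module events that the reworked `start_embedding_lookup` relies on:

```python
# Self-contained sketch (hypothetical helper names; assumes a CUDA device) of
# per-module stream pools and per-module events: lookups for different
# pipelined modules run on different streams, and one event per module is
# recorded so a later stage can wait on all of them.
import torch
from typing import Callable, List, Optional

def make_stream_pool(num_modules: int, device: torch.device) -> List[Optional[torch.cuda.Stream]]:
    # One stream per pipelined embedding module; None when not running on CUDA.
    return [
        torch.cuda.Stream(priority=0) if device.type == "cuda" else None
        for _ in range(num_modules)
    ]

def launch_lookups(
    streams: List[Optional[torch.cuda.Stream]],
    lookups: List[Callable[[], None]],
) -> List[torch.cuda.Event]:
    # Kick off each module's lookup on its own stream and record a per-module event.
    events: List[torch.cuda.Event] = []
    for stream, lookup in zip(streams, lookups):
        with torch.cuda.stream(stream):
            lookup()
            event = torch.cuda.Event()
            event.record()
            events.append(event)
    return events

# Even and odd batches use distinct pools so two in-flight batches never share streams:
# even_streams = make_stream_pool(num_modules, device)
# odd_streams = make_stream_pool(num_modules, device)
# streams = even_streams if batch_index % 2 == 0 else odd_streams
```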

torchrec/distributed/train_pipeline/utils.py (31 additions, 5 deletions)

@@ -96,7 +96,7 @@ class TrainPipelineContext:
     fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]] = (
         field(default_factory=list)
    )
-    event: Optional[torch.cuda.Event] = None
+    event: Optional[Union[List[torch.cuda.Event], torch.cuda.Event]] = None
     index: Optional[int] = None
     version: int = (
         0  # 1 is current version, 0 is deprecated but supported for backward compatibility

@@ -114,6 +114,8 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
 @dataclass
 class EmbeddingTrainPipelineContext(TrainPipelineContext):
     embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict)
+    embedding_tensors: List[torch.Tensor] = field(default_factory=list)
+    detached_embedding_tensors: List[torch.Tensor] = field(default_factory=list)
 
 
 @dataclass

@@ -230,8 +232,25 @@ def __call__(self, *input, **kwargs) -> Awaitable:
         if self._stream is not None:
             torch.cuda.current_stream().wait_stream(self._stream)
             cur_stream = torch.cuda.current_stream()
-            ctx.record_stream(cur_stream)
-        return self._context.embedding_a2a_requests.pop(self._name)
+        awaitable = self._context.embedding_a2a_requests.pop(self._name)
+        embs = awaitable.wait()
+        if isinstance(embs, Dict):
+            for jt in embs.values():
+                tensor = jt.values()
+                new_tensor = tensor.detach().requires_grad_()
+                jt._values = new_tensor
+                # pyre-ignore [16]
+                self._context.embedding_tensors.append(tensor)
+                # pyre-ignore [16]
+                self._context.detached_embedding_tensors.append(new_tensor)
+        else:
+            tensor = embs.values()
+            new_tensor = tensor.detach().requires_grad_()
+            embs._values = new_tensor
+            self._context.embedding_tensors.append(tensor)
+            self._context.detached_embedding_tensors.append(new_tensor)
+
+        return embs
 
 
 class PrefetchPipelinedForward(BaseForward):

@@ -373,11 +392,18 @@ def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
     batch.record_stream(cur_stream)
 
 
-def _wait_for_event(batch: In, event: Optional[torch.cuda.Event]) -> None:
+def _wait_for_event(
+    batch: In, event: Optional[Union[List[torch.cuda.Event], torch.cuda.Event]]
+) -> None:
     """
     Wait for event
     """
-    if event is not None:
+
+    if event and isinstance(event, list):
+        for sub_event in event:
+            sub_event.wait()
+    elif event:
+        # pyre-ignore
         event.wait()
     cur_stream = torch.cuda.current_stream()
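The `__call__` change above splits the autograd graph at the embedding boundary: the dense model consumes a detached tensor that requires grad, so the dense backward stops there and leaves the upstream gradient in `.grad`, which `embedding_backward` later feeds into the embedding-side graph. A toy, CPU-runnable illustration of that hand-off (tensor names are mine, not TorchRec's):

```python
# Illustrative sketch of the detach hand-off (toy tensors, illustrative names;
# not TorchRec code). The dense model sees a detached view that requires grad;
# the dense backward stops at that boundary and leaves the upstream gradient in
# detached.grad, which is later fed into the embedding-side backward.
import torch

emb_weight = torch.randn(4, 8, requires_grad=True)   # stands in for embedding parameters
emb_out = emb_weight * 2.0                            # embedding-side graph

detached = emb_out.detach().requires_grad_()          # what the dense model consumes
loss = detached.sum()                                 # dense-side graph
loss.backward()                                       # fills detached.grad, does not touch emb_weight

torch.autograd.backward(emb_out, detached.grad)       # later: embedding-side backward
assert emb_weight.grad is not None                    # gradients flowed across the split
```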
