
Commit 57c6235

dstaay-fb authored and facebook-github-bot committed
Overlap comms on backward pass (#2117)
Summary:
Pull Request resolved: #2117

Resolves issues around CUDA streams / NCCL deadlocks with autograd by creating separate streams per pipelined embedding module.

Differential Revision: D58220332
1 parent b8a1c40 commit 57c6235
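The heart of the change is the new `embedding_backward` path shown in the diff below: embedding outputs are detached during the forward pass (see `EmbeddingPipelinedForward.__call__` in utils.py), and their backward, which carries the embedding communication, is replayed on per-module side streams chosen from alternating even/odd stream lists by batch index, so it overlaps with the rest of the backward pass. A minimal, self-contained sketch of that pattern, assuming hypothetical placeholders `emb_tensors`, `detached_emb_tensors`, and `streams` for the state the pipeline context actually tracks:

import torch
from typing import List, Optional


def overlapped_embedding_backward(
    emb_tensors: List[List[torch.Tensor]],           # attached embedding outputs, grouped per module
    detached_emb_tensors: List[List[torch.Tensor]],  # detached copies whose .grad was filled by the dense backward
    streams: List[Optional[torch.cuda.Stream]],      # one side stream per pipelined embedding module
) -> None:
    default_stream = torch.cuda.current_stream()
    for stream, tensors, detached in zip(streams, emb_tensors, detached_emb_tensors):
        # torch.cuda.stream(None) is a no-op, so CPU-only runs fall back to the default stream
        with torch.cuda.stream(stream):
            grads = [t.grad for t in detached]
            if stream is not None:
                # the grads were produced on the default stream; order the side stream after it
                stream.wait_stream(default_stream)
            # replay this module's backward (and its comms) on the side stream,
            # overlapping it with work still queued on the default stream
            torch.autograd.backward(tensors, grads)

In the actual diff these lists live on `EmbeddingTrainPipelineContext` (`embedding_tensors`, `detached_embedding_tensors`), and the streams come from the new `_embedding_even_streams` / `_embedding_odd_streams` lists created in `_init_embedding_streams`.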

File tree

2 files changed: +159 −78 lines changed


torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 103 additions & 45 deletions
@@ -38,7 +38,7 @@
     _start_embedding_lookup,
     _to_device,
     _wait_for_batch,
-    _wait_for_event,
+    _wait_for_events,
     DataLoadingThread,
     EmbeddingPipelinedForward,
     EmbeddingTrainPipelineContext,
@@ -590,8 +590,6 @@ def start_sparse_data_dist(
             return
         with record_function(f"## start_sparse_data_dist {context.index} ##"):
             with self._stream_context(self._data_dist_stream):
-                if context.event is not None:
-                    context.event.wait()
                 _wait_for_batch(batch, self._memcpy_stream)

                 original_contexts = [p.get_context() for p in self._pipelined_preprocs]
@@ -737,11 +735,8 @@ def __init__(
             if device.type in ["cuda", "mtia"]
             else None
         )
-        self._bwd_sync_stream: Optional[torch.Stream] = (
-            (torch.get_device_module(self._device).Stream(priority=0))
-            if device.type in ["cuda", "mtia"]
-            else None
-        )
+        self._embedding_odd_streams: List[Optional[torch.Stream]] = []
+        self._embedding_even_streams: List[Optional[torch.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}

     def _grad_swap(self) -> None:
@@ -751,6 +746,29 @@ def _grad_swap(self) -> None:
             self._gradients[name] = param.grad.clone()
             param.grad = grad

+    def _init_embedding_streams(self) -> None:
+
+        for _ in self._pipelined_modules:
+            self._embedding_odd_streams.append(
+                (torch.get_device_module(self._device).Stream(priority=0))
+                if self._device.type in ["cuda", "mtia"]
+                else None
+            )
+            self._embedding_even_streams.append(
+                (torch.get_device_module(self._device).Stream(priority=0))
+                if self._device.type in ["cuda", "mtia"]
+                else None
+            )
+
+    def _validate_optimizer(self) -> None:
+        for pipelined_module in self._pipelined_modules:
+            pipelined_params = set(pipelined_module.parameters())
+            for group in self._optimizer.param_groups:
+                if not set(group["params"]).isdisjoint(pipelined_params):
+                    logger.warning(
+                        f"SemiSync pipelined {type(pipelined_module)} and optimizer share parameters"
+                    )
+
     def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
         # pipeline is already filled
         if len(self.batches) >= 3:
@@ -770,7 +788,9 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
                 # pyre-ignore [6]
                 EmbeddingPipelinedForward,
             )
+            self._init_embedding_streams()
             self.wait_sparse_data_dist(self.contexts[0])
+            self._validate_optimizer()
             # pyre-ignore [6]
             self.start_embedding_lookup(self.batches[0], self.contexts[0])

@@ -824,26 +844,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
             self.wait_sparse_data_dist(self.contexts[2])

         if self._model.training:
-            # backward would put an implicit sync point in stream called from, ideally
-            # this would different from optimizer so it could start earilier, but currently not safe to do so.
-            with self._stream_context(self._overarch_stream):
-                with record_function(f"## backward {self.contexts[0].index} ##"):
-                    torch.sum(losses, dim=0).backward()
-
-            with self._stream_context(self._overarch_stream):
-                with record_function(
-                    f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    if self.is_semi_sync() and self._stash_gradients:
-                        self._grad_swap()
-                    self._mlp_optimizer_step()
-
-                with record_function(
-                    f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
-                ):
-                    self._optimizer.zero_grad()
+            with record_function(f"## backward {self.contexts[0].index} ##"):
+                torch.sum(losses, dim=0).backward()
+            # pyre-ignore [6]
+            self.embedding_backward(self.contexts[0])
+
+            with record_function(
+                f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                if self.is_semi_sync() and self._stash_gradients:
+                    self._grad_swap()
+                self._mlp_optimizer_step()
+
+            with record_function(
+                f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
+            ):
+                self._optimizer.zero_grad()

         if len(self.batches) >= 2 and not self.is_semi_sync():
+            torch.cuda.synchronize()  # needed to avoid race condition
             # pyre-ignore [6]
             self.start_embedding_lookup(self.batches[1], self.contexts[1])

@@ -854,10 +873,29 @@ def _mlp_forward(
         self, batch: In, context: TrainPipelineContext
     ) -> Tuple[torch.Tensor, Out]:
         with record_function(f"## forward {context.index} ##"):
-            with self._stream_context(self._overarch_stream):
-                _wait_for_event(batch, self._device, context.event)
-                context.event = None
-                return cast(Tuple[torch.Tensor, Out], self._model(batch))
+            _wait_for_events(
+                batch, context, torch.get_device_module(self._device).current_stream()
+            )
+            return cast(Tuple[torch.Tensor, Out], self._model(batch))
+
+    def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
+        default_stream = torch.get_device_module(self._device).current_stream()
+        streams = (
+            self._embedding_even_streams
+            if cast(int, context.index) % 2 == 0
+            else self._embedding_odd_streams
+        )
+        for stream, emb_tensors, detached_emb_tensors in zip(
+            streams,
+            context.embedding_tensors,
+            context.detached_embedding_tensors,
+        ):
+            with self._stream_context(stream):
+                grads = [tensor.grad for tensor in detached_emb_tensors]
+                if stream:
+                    stream.wait_stream(default_stream)
+                # pyre-ignore
+                torch.autograd.backward(emb_tensors, grads)

     def copy_batch_to_gpu(
         self,
@@ -870,8 +908,9 @@ def copy_batch_to_gpu(
                 if batch is not None:
                     batch = _to_device(batch, self._device, non_blocking=True)
                     context = self._create_context()
-                    context.event = torch.get_device_module(self._device).Event()
-                    context.event.record()
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)
                 return batch, context

     def start_sparse_data_dist(
@@ -882,9 +921,25 @@ def start_sparse_data_dist(
         """
        Waits for batch to finish getting copied to GPU, then starts the input dist. This is Event based version.
         """
-        super().start_sparse_data_dist(batch, context)
-        context.event = torch.get_device_module(self._device).Event()
-        context.event.record()
+        if batch is None:
+            return
+
+        # Temporarily set context for next iter to populate cache
+        original_contexts = [p.get_context() for p in self._pipelined_preprocs]
+        for preproc_mod in self._pipelined_preprocs:
+            preproc_mod.set_context(context)
+        _wait_for_events(
+            batch, context, torch.get_device_module(self._device).current_stream()
+        )
+        _start_data_dist(self._pipelined_modules, batch, context)
+
+        # Restore context for model fwd
+        for module, context in zip(self._pipelined_preprocs, original_contexts):
+            module.set_context(context)
+
+        event = torch.get_device_module(self._device).Event()
+        event.record()
+        context.events.append(event)

     def start_embedding_lookup(
         self,
@@ -897,17 +952,20 @@ def start_embedding_lookup(
         if batch is None:
             return
         with record_function(f"## start_embedding_lookup {context.index} ##"):
-            with self._stream_context(
-                self._embedding_even_stream
-                if cast(int, context.index) % 2 == 0
-                else self._embedding_odd_stream
-            ):
-                _wait_for_event(batch, self._device, context.event)
-                _start_embedding_lookup(
-                    self._pipelined_modules, batch, context, self._device
+            _wait_for_events(
+                batch, context, torch.get_device_module(self._device).current_stream()
+            )
+            for i, module in enumerate(self._pipelined_modules):
+                stream = (
+                    self._embedding_even_streams[i]
+                    if cast(int, context.index) % 2 == 0
+                    else self._embedding_odd_streams[i]
                 )
-            context.event = torch.get_device_module(self._device).Event()
-            context.event.record()
+                with self._stream_context(stream):
+                    _start_embedding_lookup(module, batch, context, stream)
+                    event = torch.get_device_module(self._device).Event()
+                    event.record()
+                    context.events.append(event)


 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):

torchrec/distributed/train_pipeline/utils.py

Lines changed: 56 additions & 33 deletions
@@ -43,9 +43,9 @@
 )
 from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule

-from torchrec.distributed.types import Awaitable
+from torchrec.distributed.types import Awaitable, LazyNoWait

-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
 from torchrec.streamable import Multistreamable, Pipelineable

 logger: logging.Logger = logging.getLogger(__name__)
@@ -95,10 +95,8 @@ class TrainPipelineContext:
     fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]] = (
         field(default_factory=list)
     )
-    event: Optional[torch.Event] = None
-
+    events: List[torch.Event] = field(default_factory=list)
     preproc_fwd_results: Dict[str, Any] = field(default_factory=dict)
-
     index: Optional[int] = None
     version: int = (
         0  # 1 is current version, 0 is deprecated but supported for backward compatibility
@@ -116,6 +114,8 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
 @dataclass
 class EmbeddingTrainPipelineContext(TrainPipelineContext):
     embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict)
+    embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
+    detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)


 @dataclass
@@ -369,7 +369,35 @@ def __call__(self, *input, **kwargs) -> Awaitable:
         )
         cur_stream = torch.get_device_module(self._device).current_stream()
         ctx.record_stream(cur_stream)
-        return self._context.embedding_a2a_requests.pop(self._name)
+        awaitable = self._context.embedding_a2a_requests.pop(self._name)
+        embeddings = awaitable.wait()  # trigger awaitable manually for type checking
+        tensors = []
+        detached_tensors = []
+        if isinstance(embeddings, Dict):
+            for jt in embeddings.values():
+                assert isinstance(jt, JaggedTensor)
+                tensor = jt.values()
+                detached_tensor = tensor.detach().requires_grad_()
+                detached_tensor.retain_grad()
+                jt._values = detached_tensor
+                tensors.append(tensor)
+                detached_tensors.append(detached_tensor)
+            # pyre-ignore [16]
+            self._context.embedding_tensors.append(tensors)
+            # pyre-ignore [16]
+            self._context.detached_embedding_tensors.append(detached_tensors)
+        else:
+            assert isinstance(embeddings, KeyedTensor)
+            tensor = embeddings.values()
+            detached_tensor = tensor.detach().requires_grad_()
+            detached_tensor.retain_grad()
+            embeddings._values = detached_tensor
+            tensors.append(tensor)
+            detached_tensors.append(detached_tensor)
+            self._context.embedding_tensors.append(tensors)
+            self._context.detached_embedding_tensors.append(detached_tensors)
+
+        return LazyNoWait(embeddings)


 class PrefetchPipelinedForward(BaseForward):
@@ -513,22 +541,23 @@ def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None:
     batch.record_stream(cur_stream)


-def _wait_for_event(
+def _wait_for_events(
     batch: In,
-    device: torch.device,
-    event: Optional[torch.Event],
+    context: TrainPipelineContext,
+    stream: Optional[torch.Stream],
 ) -> None:
     """
-    Wait for event
+    Wait for any outstanding events for a given context
     """
-    if event is not None:
-        event.wait()
-        cur_stream = torch.get_device_module(device).current_stream()

-        assert isinstance(
-            batch, (torch.Tensor, Multistreamable)
-        ), f"{type(batch)} must implement Multistreamable interface"
-        batch.record_stream(cur_stream)
+    for event in context.events:
+        event.wait()
+    context.events.clear()
+    if stream:
+        assert isinstance(
+            batch, (torch.Tensor, Multistreamable)
+        ), f"{type(batch)} must implement Multistreamable interface"
+        batch.record_stream(stream)


 def _start_data_dist(
@@ -569,25 +598,19 @@ def _start_data_dist(


 def _start_embedding_lookup(
-    pipelined_modules: List[ShardedModule],
+    module: ShardedModule,
     batch: In,  # not used in this function
     context: EmbeddingTrainPipelineContext,
-    device: torch.device,
+    stream: Optional[torch.Stream],
 ) -> None:
-    cur_stream = torch.get_device_module(device).current_stream()
-    kjts_per_module = []
-    for module in pipelined_modules:
-        kjts = context.input_dist_tensors_requests[module.forward.name].wait()
-        kjts.record_stream(cur_stream)
-        kjts_per_module.append(kjts)
-
-    for module, kjts in zip(pipelined_modules, kjts_per_module):
-        module_name = module.forward.name
-        module_context = context.module_contexts[module.forward.name]
-        module_context.record_stream(cur_stream)
-        a2a_awaitable = module.compute_and_output_dist(module_context, kjts)
-        # pyre-ignore[6]
-        context.embedding_a2a_requests[module_name] = a2a_awaitable
+    kjt = context.input_dist_tensors_requests[module.forward.name].wait()
+    module_context = context.module_contexts[module.forward.name]
+    if stream:
+        kjt.record_stream(stream)
+        module_context.record_stream(stream)
+    a2a_awaitable = module.compute_and_output_dist(module_context, kjt)
+    # pyre-ignore[6]
+    context.embedding_a2a_requests[module.forward.name] = a2a_awaitable


 def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
