
Overlap comms on backward pass #2117


Closed
wants to merge 1 commit
148 changes: 103 additions & 45 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -38,7 +38,7 @@
_start_embedding_lookup,
_to_device,
_wait_for_batch,
_wait_for_event,
_wait_for_events,
DataLoadingThread,
EmbeddingPipelinedForward,
EmbeddingTrainPipelineContext,
@@ -590,8 +590,6 @@ def start_sparse_data_dist(
return
with record_function(f"## start_sparse_data_dist {context.index} ##"):
with self._stream_context(self._data_dist_stream):
if context.event is not None:
context.event.wait()
_wait_for_batch(batch, self._memcpy_stream)

original_contexts = [p.get_context() for p in self._pipelined_preprocs]
@@ -737,11 +735,8 @@ def __init__(
if device.type in ["cuda", "mtia"]
else None
)
self._bwd_sync_stream: Optional[torch.Stream] = (
(torch.get_device_module(self._device).Stream(priority=0))
if device.type in ["cuda", "mtia"]
else None
)
self._embedding_odd_streams: List[Optional[torch.Stream]] = []
self._embedding_even_streams: List[Optional[torch.Stream]] = []
self._gradients: Dict[str, torch.Tensor] = {}

def _grad_swap(self) -> None:
@@ -751,6 +746,29 @@ def _grad_swap(self) -> None:
self._gradients[name] = param.grad.clone()
param.grad = grad

def _init_embedding_streams(self) -> None:

for _ in self._pipelined_modules:
self._embedding_odd_streams.append(
(torch.get_device_module(self._device).Stream(priority=0))
if self._device.type in ["cuda", "mtia"]
else None
)
self._embedding_even_streams.append(
(torch.get_device_module(self._device).Stream(priority=0))
if self._device.type in ["cuda", "mtia"]
else None
)

def _validate_optimizer(self) -> None:
for pipelined_module in self._pipelined_modules:
pipelined_params = set(pipelined_module.parameters())
for group in self._optimizer.param_groups:
if not set(group["params"]).isdisjoint(pipelined_params):
logger.warning(
f"SemiSync pipelined {type(pipelined_module)} and optimizer share parameters"
)

def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
# pipeline is already filled
if len(self.batches) >= 3:
@@ -770,7 +788,9 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
# pyre-ignore [6]
EmbeddingPipelinedForward,
)
self._init_embedding_streams()
self.wait_sparse_data_dist(self.contexts[0])
self._validate_optimizer()
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[0], self.contexts[0])

@@ -824,26 +844,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
self.wait_sparse_data_dist(self.contexts[2])

if self._model.training:
# backward would put an implicit sync point in the stream it is called from; ideally
# this would be different from the optimizer stream so it could start earlier, but it is currently not safe to do so.
with self._stream_context(self._overarch_stream):
with record_function(f"## backward {self.contexts[0].index} ##"):
torch.sum(losses, dim=0).backward()

with self._stream_context(self._overarch_stream):
with record_function(
f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
):
if self.is_semi_sync() and self._stash_gradients:
self._grad_swap()
self._mlp_optimizer_step()

with record_function(
f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
):
self._optimizer.zero_grad()
with record_function(f"## backward {self.contexts[0].index} ##"):
torch.sum(losses, dim=0).backward()
# pyre-ignore [6]
self.embedding_backward(self.contexts[0])

with record_function(
f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
):
if self.is_semi_sync() and self._stash_gradients:
self._grad_swap()
self._mlp_optimizer_step()

with record_function(
f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
):
self._optimizer.zero_grad()

if len(self.batches) >= 2 and not self.is_semi_sync():
torch.cuda.synchronize() # needed to avoid race condition
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[1], self.contexts[1])

@@ -854,10 +873,29 @@ def _mlp_forward(
self, batch: In, context: TrainPipelineContext
) -> Tuple[torch.Tensor, Out]:
with record_function(f"## forward {context.index} ##"):
with self._stream_context(self._overarch_stream):
_wait_for_event(batch, self._device, context.event)
context.event = None
return cast(Tuple[torch.Tensor, Out], self._model(batch))
_wait_for_events(
batch, context, torch.get_device_module(self._device).current_stream()
)
return cast(Tuple[torch.Tensor, Out], self._model(batch))

def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
default_stream = torch.get_device_module(self._device).current_stream()
streams = (
self._embedding_even_streams
if cast(int, context.index) % 2 == 0
else self._embedding_odd_streams
)
for stream, emb_tensors, detached_emb_tensors in zip(
streams,
context.embedding_tensors,
context.detached_embedding_tensors,
):
with self._stream_context(stream):
grads = [tensor.grad for tensor in detached_emb_tensors]
if stream:
stream.wait_stream(default_stream)
# pyre-ignore
torch.autograd.backward(emb_tensors, grads)

def copy_batch_to_gpu(
self,
@@ -870,8 +908,9 @@ def copy_batch_to_gpu(
if batch is not None:
batch = _to_device(batch, self._device, non_blocking=True)
context = self._create_context()
context.event = torch.get_device_module(self._device).Event()
context.event.record()
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)
return batch, context

def start_sparse_data_dist(
@@ -882,9 +921,25 @@ def start_sparse_data_dist(
"""
Waits for the batch to finish getting copied to the GPU, then starts the input dist. This is the event-based version.
"""
super().start_sparse_data_dist(batch, context)
context.event = torch.get_device_module(self._device).Event()
context.event.record()
if batch is None:
return

# Temporarily set context for next iter to populate cache
original_contexts = [p.get_context() for p in self._pipelined_preprocs]
for preproc_mod in self._pipelined_preprocs:
preproc_mod.set_context(context)

with record_function(f"## start_sparse_data_dist {context.index} ##"):
with self._stream_context(self._data_dist_stream):
_wait_for_events(batch, context, self._data_dist_stream)
_start_data_dist(self._pipelined_modules, batch, context)
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)

# Restore context for model forward
for module, context in zip(self._pipelined_preprocs, original_contexts):
module.set_context(context)

def start_embedding_lookup(
self,
@@ -897,17 +952,20 @@ def start_embedding_lookup(
if batch is None:
return
with record_function(f"## start_embedding_lookup {context.index} ##"):
with self._stream_context(
self._embedding_even_stream
if cast(int, context.index) % 2 == 0
else self._embedding_odd_stream
):
_wait_for_event(batch, self._device, context.event)
_start_embedding_lookup(
self._pipelined_modules, batch, context, self._device
_wait_for_events(
batch, context, torch.get_device_module(self._device).current_stream()
)
for i, module in enumerate(self._pipelined_modules):
stream = (
self._embedding_even_streams[i]
if cast(int, context.index) % 2 == 0
else self._embedding_odd_streams[i]
)
context.event = torch.get_device_module(self._device).Event()
context.event.record()
with self._stream_context(stream):
_start_embedding_lookup(module, batch, context, stream)
event = torch.get_device_module(self._device).Event()
event.record()
context.events.append(event)


class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
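Note on the train_pipelines.py changes above: the single backward-sync stream is replaced by per-module odd/even embedding streams, and `embedding_backward` replays the embedding half of the autograd graph on those side streams so its gradient work can overlap with the MLP optimizer step. The following is a minimal, self-contained sketch of that pattern, not TorchRec code; all names (`backward_with_overlap`, `lookup`, `detached`) are illustrative, and it assumes a CUDA device is available.

```python
# Sketch: run the dense backward on the default stream, then replay the
# embedding backward on a side stream so it can overlap with later work.
import torch


def backward_with_overlap(
    loss: torch.Tensor,
    emb_tensors: list,           # embedding outputs still attached to the embedding graph
    detached_emb_tensors: list,  # detached leaves the dense model consumed
    side_stream: torch.cuda.Stream,
) -> None:
    # Dense backward on the default stream; it fills .grad on the detached leaves.
    loss.backward()
    default_stream = torch.cuda.current_stream()
    grads = [t.grad for t in detached_emb_tensors]
    with torch.cuda.stream(side_stream):
        # The side stream must observe the gradients produced on the default stream.
        side_stream.wait_stream(default_stream)
        # Continue backward through the embedding graph on the side stream so its
        # kernels/comms can overlap with work queued later on the default stream.
        torch.autograd.backward(emb_tensors, grads)


if torch.cuda.is_available():
    emb = torch.nn.Embedding(100, 16, device="cuda")
    dense = torch.nn.Linear(16, 1, device="cuda")
    side = torch.cuda.Stream()

    ids = torch.randint(0, 100, (8,), device="cuda")
    lookup = emb(ids)                            # "embedding" half of the graph
    detached = lookup.detach().requires_grad_()  # splice point handed to the dense half
    loss = dense(detached).sum()

    backward_with_overlap(loss, [lookup], [detached], side)
    torch.cuda.synchronize()
    assert emb.weight.grad is not None
```

The detached copies are the leaves the dense backward writes `.grad` into; `wait_stream` guarantees the side stream sees those gradients before the embedding backward is replayed.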
89 changes: 56 additions & 33 deletions torchrec/distributed/train_pipeline/utils.py
@@ -43,9 +43,9 @@
)
from torchrec.distributed.model_parallel import DistributedModelParallel, ShardedModule

from torchrec.distributed.types import Awaitable
from torchrec.distributed.types import Awaitable, LazyNoWait

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.streamable import Multistreamable, Pipelineable

logger: logging.Logger = logging.getLogger(__name__)
@@ -95,10 +95,8 @@ class TrainPipelineContext:
fused_splits_awaitables: List[Tuple[List[str], FusedKJTListSplitsAwaitable]] = (
field(default_factory=list)
)
event: Optional[torch.Event] = None

events: List[torch.Event] = field(default_factory=list)
preproc_fwd_results: Dict[str, Any] = field(default_factory=dict)

index: Optional[int] = None
version: int = (
0 # 1 is current version, 0 is deprecated but supported for backward compatibility
@@ -116,6 +114,8 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
@dataclass
class EmbeddingTrainPipelineContext(TrainPipelineContext):
embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict)
embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)


@dataclass
@@ -369,7 +369,35 @@ def __call__(self, *input, **kwargs) -> Awaitable:
)
cur_stream = torch.get_device_module(self._device).current_stream()
ctx.record_stream(cur_stream)
return self._context.embedding_a2a_requests.pop(self._name)
awaitable = self._context.embedding_a2a_requests.pop(self._name)
embeddings = awaitable.wait() # trigger awaitable manually for type checking
tensors = []
detached_tensors = []
if isinstance(embeddings, Dict):
for jt in embeddings.values():
assert isinstance(jt, JaggedTensor)
tensor = jt.values()
detached_tensor = tensor.detach().requires_grad_()
detached_tensor.retain_grad()
jt._values = detached_tensor
tensors.append(tensor)
detached_tensors.append(detached_tensor)
# pyre-ignore [16]
self._context.embedding_tensors.append(tensors)
# pyre-ignore [16]
self._context.detached_embedding_tensors.append(detached_tensors)
else:
assert isinstance(embeddings, KeyedTensor)
tensor = embeddings.values()
detached_tensor = tensor.detach().requires_grad_()
detached_tensor.retain_grad()
embeddings._values = detached_tensor
tensors.append(tensor)
detached_tensors.append(detached_tensor)
self._context.embedding_tensors.append(tensors)
self._context.detached_embedding_tensors.append(detached_tensors)

return LazyNoWait(embeddings)


class PrefetchPipelinedForward(BaseForward):
@@ -513,22 +541,23 @@ def _wait_for_batch(batch: In, stream: Optional[torch.Stream]) -> None:
batch.record_stream(cur_stream)


def _wait_for_event(
def _wait_for_events(
batch: In,
device: torch.device,
event: Optional[torch.Event],
context: TrainPipelineContext,
stream: Optional[torch.Stream],
) -> None:
"""
Wait for event
Wait for any outstanding events for a given context
"""
if event is not None:
event.wait()
cur_stream = torch.get_device_module(device).current_stream()

assert isinstance(
batch, (torch.Tensor, Multistreamable)
), f"{type(batch)} must implement Multistreamable interface"
batch.record_stream(cur_stream)
for event in context.events:
event.wait()
context.events.clear()
if stream:
assert isinstance(
batch, (torch.Tensor, Multistreamable)
), f"{type(batch)} must implement Multistreamable interface"
batch.record_stream(stream)


def _start_data_dist(
@@ -569,25 +598,19 @@ def _start_data_dist(


def _start_embedding_lookup(
pipelined_modules: List[ShardedModule],
module: ShardedModule,
batch: In, # not used in this function
context: EmbeddingTrainPipelineContext,
device: torch.device,
stream: Optional[torch.Stream],
) -> None:
cur_stream = torch.get_device_module(device).current_stream()
kjts_per_module = []
for module in pipelined_modules:
kjts = context.input_dist_tensors_requests[module.forward.name].wait()
kjts.record_stream(cur_stream)
kjts_per_module.append(kjts)

for module, kjts in zip(pipelined_modules, kjts_per_module):
module_name = module.forward.name
module_context = context.module_contexts[module.forward.name]
module_context.record_stream(cur_stream)
a2a_awaitable = module.compute_and_output_dist(module_context, kjts)
# pyre-ignore[6]
context.embedding_a2a_requests[module_name] = a2a_awaitable
kjt = context.input_dist_tensors_requests[module.forward.name].wait()
module_context = context.module_contexts[module.forward.name]
if stream:
kjt.record_stream(stream)
module_context.record_stream(stream)
a2a_awaitable = module.compute_and_output_dist(module_context, kjt)
# pyre-ignore[6]
context.embedding_a2a_requests[module.forward.name] = a2a_awaitable


def _fuse_input_dist_splits(context: TrainPipelineContext) -> None:
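Note on the utils.py changes: the single `context.event` becomes a list of events, and `_wait_for_events` drains that list on the consumer's stream before the batch is touched. The following is a minimal sketch of that event-list handoff between pipeline stages, not TorchRec code; the names (`StageContext`, `copy_stage`, `compute_stage`) are illustrative, and it assumes a CUDA device is available.

```python
# Sketch: each stage records a CUDA event after its async work and appends it
# to a shared context; the next stage waits on all outstanding events on its
# own stream before using the batch, then clears the list.
import torch
from dataclasses import dataclass, field
from typing import List


@dataclass
class StageContext:
    events: List[torch.cuda.Event] = field(default_factory=list)


def copy_stage(batch: torch.Tensor, ctx: StageContext, copy_stream: torch.cuda.Stream) -> torch.Tensor:
    with torch.cuda.stream(copy_stream):
        out = batch.to("cuda", non_blocking=True)  # async H2D copy on the side stream
        event = torch.cuda.Event()
        event.record()            # records on copy_stream (the current stream here)
        ctx.events.append(event)
    return out


def compute_stage(batch: torch.Tensor, ctx: StageContext, stream: torch.cuda.Stream) -> torch.Tensor:
    with torch.cuda.stream(stream):
        for event in ctx.events:
            event.wait()          # make `stream` wait for the recorded work
        ctx.events.clear()
        batch.record_stream(stream)  # keep the allocation alive for work queued on `stream`
        return batch * 2


if torch.cuda.is_available():
    ctx = StageContext()
    copy_stream, compute_stream = torch.cuda.Stream(), torch.cuda.Stream()
    cpu_batch = torch.randn(4, 4, pin_memory=True)
    gpu_batch = copy_stage(cpu_batch, ctx, copy_stream)
    result = compute_stage(gpu_batch, ctx, compute_stream)
    torch.cuda.synchronize()
```

This mirrors what `copy_batch_to_gpu`, `start_sparse_data_dist`, and `start_embedding_lookup` now do with `context.events`, each stage appending its own event instead of overwriting a single one.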