diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index 8c3bb4e5b..5f2569ed2 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -28,6 +28,7 @@
 import torch
 from torch.autograd.profiler import record_function
 
+from torchrec.distributed.comm_ops import set_use_sync_collectives
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.utils import (
@@ -208,17 +209,24 @@ def __init__(
         self._cur_batch: Optional[In] = None
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        if self._iter == 0:
+            # Turn on sync collectives for the PT2 pipeline so that compiled and
+            # graph_break ranks follow the same logic.
+            set_use_sync_collectives(True)
+
         cc = self._compile_configs
 
         with record_function("## load_batch ##"):
             cur_batch = next(dataloader_iter)
 
-        if self._input_transformer:
-            cur_batch = self._input_transformer(cur_batch)
-
         with record_function("## copy_batch_to_gpu ##"):
             self._cur_batch = _to_device(cur_batch, self._device, non_blocking=False)
 
+        # The input transformer also attaches PT2 hints for the compiler, which must
+        # happen on the exact object passed to model.compile; do not move it before _to_device.
+        if self._input_transformer:
+            self._cur_batch = self._input_transformer(self._cur_batch)
+
         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
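
Not part of the patch: a minimal, self-contained sketch of the ordering constraint behind the relocated input-transformer call. DummyBatch, to_device, mark_dynamic_hints, and prepare_batch are hypothetical stand-ins, not torchrec APIs; the point is only that compiler hints are attached after the batch has been moved to the device, so they land on the exact object the compiled model will receive.

# Illustrative sketch only; all names below are hypothetical stand-ins, not torchrec API.
from dataclasses import dataclass
from typing import Callable, Optional

import torch


@dataclass
class DummyBatch:
    # Hypothetical stand-in for the pipeline's batch type.
    values: torch.Tensor


def to_device(batch: DummyBatch, device: torch.device) -> DummyBatch:
    # Analogous to the pipeline's _to_device: returns a *new* batch object on `device`.
    return DummyBatch(values=batch.values.to(device, non_blocking=False))


def mark_dynamic_hints(batch: DummyBatch) -> DummyBatch:
    # Hypothetical input transformer: attaches PT2 dynamic-shape hints to the
    # tensors of the batch it receives.
    torch._dynamo.mark_dynamic(batch.values, 0)
    return batch


def prepare_batch(
    batch: DummyBatch,
    device: torch.device,
    input_transformer: Optional[Callable[[DummyBatch], DummyBatch]] = None,
) -> DummyBatch:
    # Move first, then transform: if the transformer ran before to_device, the
    # hints would be attached to the host-side copy rather than the object that
    # is actually handed to the compiled model.
    batch = to_device(batch, device)
    if input_transformer is not None:
        batch = input_transformer(batch)
    return batch


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = prepare_batch(DummyBatch(torch.randn(8, 4)), device, mark_dynamic_hints)
    compiled_fn = torch.compile(lambda t: t.sum())
    print(compiled_fn(batch.values))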