@@ -28,6 +28,7 @@
 
 import torch
 from torch.autograd.profiler import record_function
+from torchrec.distributed.comm_ops import set_use_sync_collectives
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.utils import (
@@ -208,17 +209,24 @@ def __init__(
         self._cur_batch: Optional[In] = None
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        if self._iter == 0:
+            # Turn on sync collectives for the PT2 pipeline so that compiled
+            # and graph-break ranks follow the same logic.
+            set_use_sync_collectives(True)
+
         cc = self._compile_configs
 
         with record_function("## load_batch ##"):
             cur_batch = next(dataloader_iter)
 
-        if self._input_transformer:
-            cur_batch = self._input_transformer(cur_batch)
-
         with record_function("## copy_batch_to_gpu ##"):
             self._cur_batch = _to_device(cur_batch, self._device, non_blocking=False)
 
+        # The input transformer also attaches PT2 hints for the compiler; these must be
+        # applied to the exact object passed to the compiled model, so do not move this before _to_device.
+        if self._input_transformer:
+            self._cur_batch = self._input_transformer(self._cur_batch)
+
         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
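The first added block flips a process-wide flag once, on the first `progress()` call, and leaves it on for the lifetime of the pipeline. Below is a minimal sketch of that toggle factored as a scoped helper; only `set_use_sync_collectives` comes from the import added in this diff, while the wrapper, its name, and the assumption that passing `False` restores asynchronous collectives are illustrative.

```python
from torchrec.distributed.comm_ops import set_use_sync_collectives


def run_step_with_sync_collectives(step_fn):
    """Force synchronous collectives while step_fn runs (illustrative sketch only)."""
    set_use_sync_collectives(True)
    try:
        return step_fn()
    finally:
        # Assumption: False restores the default (asynchronous) behaviour.
        set_use_sync_collectives(False)
```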
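The relocated `_input_transformer` call is where PT2 shape hints can be attached to the batch object that the compiled model will actually see, which is why it now runs after `_to_device`. A hedged sketch of such a transformer follows: `torch._dynamo.mark_dynamic` is a standard PT2 API, but the `float_features` attribute, the `mark_batch_dynamic` name, and the commented-out pipeline constructor are assumptions for illustration.

```python
import torch
import torch._dynamo


def mark_batch_dynamic(batch):
    # Hypothetical transformer: mark the batch dimension of the on-device
    # features as dynamic so torch.compile does not specialize on batch size.
    # It must run after _to_device, on the exact object handed to the
    # compiled model; otherwise the hints attach to the pre-copy object.
    torch._dynamo.mark_dynamic(batch.float_features, 0)  # `float_features` is an assumed attribute
    return batch


# Assumed usage with a PT2 train pipeline that accepts an input_transformer:
# pipeline = TrainPipelinePT2(
#     model, optimizer, device, input_transformer=mark_batch_dynamic
# )
```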