Commit cf62594

Ivan Kobzarev authored and facebook-github-bot committed
Pt2Pipeline input_transform just before model.compile (#2141)
Summary:
Pull Request resolved: #2141

input_transformer is used to apply PT2 hints such as torch._dynamo.mark_dynamic / mark_unbacked, which must happen on the exact Tensor objects passed to model.compile. Previously it was called before _to_device(), which produces new Tensor objects, so the hints never reached the tensors the model was actually compiled with.

Reviewed By: TroyGarden

Differential Revision: D58825466

fbshipit-source-id: cd05595d03b5f7b05eabd3fbb5e2c139dcc962b6
1 parent 267c18a commit cf62594
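
To make the ordering constraint concrete, here is a minimal sketch of the failure mode (not from the commit; it assumes PyTorch 2.x, and `_dynamo_dynamic_indices` is an internal attribute, shown only to illustrate that hints attach to one specific Tensor object):

import torch
from torch._dynamo import mark_dynamic

# Dynamo shape hints are recorded on the exact Tensor object they are
# called with; a later device copy does not carry them over.
cpu_batch = torch.randn(8, 16)
mark_dynamic(cpu_batch, 0)  # hint: dim 0 (batch size) is dynamic

# _to_device() ultimately calls Tensor.to(), which returns a new Tensor
# whenever a copy is required; the hint stays behind on the old object.
new_batch = (
    cpu_batch.to("cuda")
    if torch.cuda.is_available()
    else cpu_batch.to(torch.float64)  # any copying .to() shows the same effect
)

print(hasattr(cpu_batch, "_dynamo_dynamic_indices"))  # True
print(hasattr(new_batch, "_dynamo_dynamic_indices"))  # False: the hint was lost

This is why the commit moves the transform to run on self._cur_batch after the device copy: the hints then land on the tensors that model.compile actually sees.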

1 file changed (+11 −3)

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 11 additions & 3 deletions
@@ -28,6 +28,7 @@
 
 import torch
 from torch.autograd.profiler import record_function
+from torchrec.distributed.comm_ops import set_use_sync_collectives
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.utils import (
@@ -208,17 +209,24 @@ def __init__(
         self._cur_batch: Optional[In] = None
 
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        if self._iter == 0:
+            # Turn on sync collectives for the PT2 pipeline,
+            # to keep the logic similar between compiled and graph-break ranks.
+            set_use_sync_collectives(True)
+
         cc = self._compile_configs
 
         with record_function("## load_batch ##"):
             cur_batch = next(dataloader_iter)
 
-        if self._input_transformer:
-            cur_batch = self._input_transformer(cur_batch)
-
         with record_function("## copy_batch_to_gpu ##"):
             self._cur_batch = _to_device(cur_batch, self._device, non_blocking=False)
 
+        # The input transformer is also used to apply PT2 hints for the compiler,
+        # which must happen on the exact objects passed to model.compile.
+        # Do not move it before _to_device.
+        if self._input_transformer:
+            self._cur_batch = self._input_transformer(self._cur_batch)
+
         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
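
With the new ordering, an input transformer can attach hints to the device tensors the pipeline actually feeds to the compiled model. A hypothetical transformer is sketched below; the batch layout and the dense_features attribute are illustrative assumptions, not part of this commit:

import torch
from torch._dynamo import mark_dynamic

# Hypothetical input_transformer: mark the batch dimension of each dense
# feature tensor as dynamic so torch.compile does not specialize on batch
# size. Because the pipeline now calls this after _to_device(), the hints
# land on the exact tensors passed to the compiled model.
def input_transformer(batch):
    for tensor in batch.dense_features:  # illustrative batch attribute
        mark_dynamic(tensor, 0)
    return batch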
