
Commit a144233

Ivan Kobzarev authored and facebook-github-bot committed
Pt2Pipeline input_transform just before model.compile (#2141)
Summary:
Pull Request resolved: #2141

input_transformer is used to apply pt2 hints such as torch._dynamo.mark_dynamic / torch._dynamo.mark_unbacked, which must be applied to the exact Tensor objects that are passed to model.compile. Previously it was called before _to_device(), which produces new Tensor objects, so the hints never reached the tensors the compiled model actually sees.

Reviewed By: TroyGarden

Differential Revision: D58825466
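For illustration, here is a minimal sketch (not part of the PR) of why the call order matters: torch.Tensor.to() returns a new Tensor object, so a hint like torch._dynamo.mark_dynamic applied to the host-side batch does not survive the device copy. The model, shapes, and CUDA device below are made up for the example.

```python
import torch
import torch._dynamo

# Assumes a CUDA device is available; the model and shapes are illustrative only.
model = torch.compile(torch.nn.Linear(16, 4).to("cuda"))

cpu_batch = torch.randn(8, 16)

# Old ordering: the hint is attached to the CPU tensor, but .to() returns a
# brand-new Tensor object, so the compiled model never sees the dynamic-dim hint.
torch._dynamo.mark_dynamic(cpu_batch, 0)
lost_hint_batch = cpu_batch.to("cuda")

# New ordering (what this commit enforces): copy to the device first, then apply
# the hint to the exact object that is passed into the compiled model.
gpu_batch = cpu_batch.to("cuda")
torch._dynamo.mark_dynamic(gpu_batch, 0)  # dim 0 (batch size) treated as dynamic
out = model(gpu_batch)
```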
1 parent ac33f23 commit a144233

File tree

1 file changed: +5 / -3 lines

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 5 additions & 3 deletions
@@ -203,12 +203,14 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
         with record_function("## load_batch ##"):
             cur_batch = next(dataloader_iter)
 
-        if self._input_transformer:
-            cur_batch = self._input_transformer(cur_batch)
-
         with record_function("## copy_batch_to_gpu ##"):
             self._cur_batch = _to_device(cur_batch, self._device, non_blocking=False)
 
+        # Input transformer here is used also for pt2 hints to compiler, that should happen on exact object passed to model.compile.
+        # Do not move it before _to_device
+        if self._input_transformer:
+            self._cur_batch = self._input_transformer(self._cur_batch)
+
         if self._model.training:
             with record_function("## zero_grad ##"):
                 self._optimizer.zero_grad()
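For context, an input_transformer that applies such hints might look like the sketch below. The callable name and the batch structure are hypothetical, not from torchrec; only the fact that the pipeline invokes the transformer on the device-resident batch (per the diff above) is taken from the commit.

```python
import torch
import torch._dynamo


def pt2_hint_transformer(batch):
    """Illustrative input_transformer: mark the leading dim of each tensor as dynamic.

    Because the pipeline now calls this after _to_device, the hints land on the
    exact Tensor objects the compiled model will receive.
    """
    tensors = batch.values() if isinstance(batch, dict) else [batch]
    for t in tensors:
        if isinstance(t, torch.Tensor):
            torch._dynamo.mark_dynamic(t, 0)  # treat dim 0 (batch size) as dynamic
    return batch
```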

0 commit comments
