Commit 3d46954

Add CompiledAutograd pipeline
Summary: Add new pipeline for CompiledAutograd development. Differential Revision: D61403499
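For orientation, compiled autograd routes the backward pass through dynamo so it can be compiled like the forward graph; the new pipeline enters this context around every training step. Below is a minimal, standalone sketch of that pattern using the same compile options the class sets; the toy model, input, and loss are illustrative stand-ins, not part of this commit:

import torch

# Toy stand-ins; the commit itself targets sharded TorchRec models.
model = torch.nn.Linear(8, 1)
batch = torch.randn(4, 8)

# Same options as get_compiled_autograd_ctx in the diff below.
compiler_fn = torch.compile(backend="inductor", dynamic=True, fullgraph=True)

with torch._dynamo.compiled_autograd.enable(compiler_fn):
    loss = model(batch).sum()
    loss.backward()  # the backward graph is captured and compiled by inductor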

2 files changed, 64 insertions(+), 0 deletions(-)

torchrec/distributed/train_pipeline/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
     TrainPipelineBase,  # noqa
     TrainPipelinePT2,  # noqa
     TrainPipelineSparseDist,  # noqa
+    TrainPipelineSparseDistCompAutograd,  # noqa
 )
 from torchrec.distributed.train_pipeline.utils import (  # noqa
     _override_input_dist_forwards,  # noqa
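With this re-export in place, the new pipeline is importable from the package root alongside the existing pipelines:

from torchrec.distributed.train_pipeline import TrainPipelineSparseDistCompAutograd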

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 63 additions & 0 deletions
@@ -8,13 +8,15 @@
 # pyre-strict

 import abc
+import contextlib
 import logging
 from collections import deque
 from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
     cast,
+    ContextManager,
     Deque,
     Dict,
     Generic,
@@ -28,6 +30,7 @@

 import torch
 from torch.autograd.profiler import record_function
+from torchrec.distributed.comm_ops import set_use_sync_collectives
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.utils import (
@@ -1506,3 +1509,63 @@ def progress(
             return self.progress(dataloader_iter)

         return out
+
+
+class TrainPipelineSparseDistCompAutograd(TrainPipelineSparseDist[In, Out]):
+    """
+    This pipeline clones TrainPipelineSparseDist, but executes the progress
+    method within the compiled autograd context.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        execute_all_batches: bool = True,
+        apply_jit: bool = False,
+        context_type: Type[TrainPipelineContext] = TrainPipelineContext,
+        pipeline_preproc: bool = False,
+        custom_model_fwd: Optional[
+            Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
+        ] = None,
+    ) -> None:
+        set_use_sync_collectives(True)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            execute_all_batches,
+            apply_jit,
+            context_type,
+            pipeline_preproc,
+            custom_model_fwd,
+        )
+
+    @staticmethod
+    def get_compiled_autograd_ctx(
+        model: torch.nn.Module,
+    ) -> ContextManager:
+        compiled_autograd = (
+            hasattr(model, "_compiled_autograd") and model._compiled_autograd
+        )
+
+        model._compiled_autograd_options = {
+            "backend": "inductor",
+            "dynamic": True,
+            "fullgraph": True,
+        }
+        torch._dynamo.config.optimize_ddp = "python_reducer"
+        return (
+            torch._dynamo.compiled_autograd.enable(
+                torch.compile(**model._compiled_autograd_options)
+            )
+            if compiled_autograd
+            else contextlib.nullcontext()
+        )
+
+    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+        self._model._compiled_autograd = True
+
+        with self.get_compiled_autograd_ctx(self._model):
+            return super().progress(dataloader_iter)
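A hedged usage sketch for the new class follows. The sharded model, optimizer, and dataloader iterator are assumed stand-ins (the pipeline takes the same constructor arguments as TrainPipelineSparseDist, so any setup that works there should apply):

import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDistCompAutograd

# Assumed setup, not shown here: `sharded_model` is a DistributedModelParallel-
# wrapped TorchRec model on an initialized process group, and `dataloader_iter`
# yields pipelineable training batches.
device = torch.device("cuda", torch.cuda.current_device())
optimizer = torch.optim.SGD(sharded_model.parameters(), lr=0.1)

pipeline = TrainPipelineSparseDistCompAutograd(
    model=sharded_model,
    optimizer=optimizer,
    device=device,
)

# progress() overlaps input dist and compute as in TrainPipelineSparseDist,
# but each step runs under the compiled autograd context enabled above.
while True:
    try:
        out = pipeline.progress(dataloader_iter)
    except StopIteration:
        break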
