
Commit f7e444d

flaviotruzzi authored and facebook-github-bot committed

Add CompiledAutograd pipeline (#2310)

Summary:
Pull Request resolved: #2310

Add a new pipeline for CompiledAutograd development.

Reviewed By: dstaay-fb, xmfan, yf225

Differential Revision: D61403499

fbshipit-source-id: 7bf0720e0c1078815315278fffd79c2d7470882f
1 parent 9418355 commit f7e444d

File tree: 5 files changed, +128 −11 lines

torchrec/distributed/embedding_lookup.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -355,6 +355,9 @@ def purge(self) -> None:
 
 
 class CommOpGradientScaling(torch.autograd.Function):
+    # user override: inline autograd.Function is safe to trace since only tensor mutations / no global state
+    _compiled_autograd_should_lift = False
+
     @staticmethod
     # pyre-ignore
     def forward(
```
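
For context, the flag added above is compiled autograd's opt-out for lifting custom autograd Functions: with `_compiled_autograd_should_lift = False`, the backward is traced inline, which the in-code comment notes is safe here because it only does tensor math and touches no global state. A minimal standalone sketch of the same pattern (the `ScaleGrad` class below is illustrative, not part of this commit):

```python
import torch


class ScaleGrad(torch.autograd.Function):
    # Opt out of lifting: compiled autograd may trace this backward inline.
    # Safe here because backward only does tensor math and touches no global state.
    _compiled_autograd_should_lift = False

    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        return input_tensor

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Scale the incoming gradient; None for the non-tensor `scale` argument.
        return grad_output * ctx.scale, None


x = torch.randn(4, requires_grad=True)
ScaleGrad.apply(x, 0.5).sum().backward()
print(x.grad)  # each entry is 0.5
```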

torchrec/distributed/train_pipeline/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,6 +16,7 @@
     TrainPipelineBase,  # noqa
     TrainPipelinePT2,  # noqa
     TrainPipelineSparseDist,  # noqa
+    TrainPipelineSparseDistCompAutograd,  # noqa
 )
 from torchrec.distributed.train_pipeline.utils import (  # noqa
     _override_input_dist_forwards,  # noqa
```

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 33 additions & 10 deletions
```diff
@@ -19,6 +19,7 @@
 from hypothesis import given, settings, strategies as st, Verbosity
 from torch import nn, optim
 from torch._dynamo.testing import reduce_to_scalar_loss
+from torch._dynamo.utils import counters
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -53,6 +54,7 @@
     TrainPipelinePT2,
     TrainPipelineSemiSync,
     TrainPipelineSparseDist,
+    TrainPipelineSparseDistCompAutograd,
 )
 from torchrec.distributed.train_pipeline.utils import (
     DataLoadingThread,
@@ -393,7 +395,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
             sharded_sparse_arch_pipeline.parameters(), lr=0.1
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             sharded_sparse_arch_pipeline,
             optimizer_pipeline,
             self.device,
@@ -441,7 +443,7 @@ def _setup_pipeline(
             dict(in_backward_optimizer_filter(distributed_model.named_parameters())),
             lambda params: optim.SGD(params, lr=0.1),
         )
-        return TrainPipelineSparseDist(
+        return self.pipeline_class(
             model=distributed_model,
             optimizer=optimizer_distributed,
             device=self.device,
@@ -508,7 +510,7 @@ def test_equal_to_non_pipelined(
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -621,7 +623,7 @@ def test_model_detach_during_train(self) -> None:
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -719,7 +721,7 @@ def test_model_detach_after_train(self) -> None:
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
        )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -862,7 +864,7 @@ def _check_output_equal(
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1116,7 +1118,7 @@ def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None:
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1171,7 +1173,7 @@ def test_pipeline_invalid_preproc_trainable_params_recursive(
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1217,7 +1219,7 @@ def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None:
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1280,7 +1282,7 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None:
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -2100,3 +2102,24 @@ def gpu_preproc(x: StageOut) -> StageOut:
         self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
         for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
             torch.testing.assert_close(out, ref_out)
+
+
+class TrainPipelineSparseDistCompAutogradTest(TrainPipelineSparseDistTest):
+    def setUp(self) -> None:
+        super().setUp()
+        self.pipeline_class = TrainPipelineSparseDistCompAutograd
+        torch._dynamo.reset()
+        counters["compiled_autograd"].clear()
+        # Compiled Autograd doesn't work with Anomaly Mode
+        torch.autograd.set_detect_anomaly(False)
+
+    def tearDown(self) -> None:
+        # Every test has two captures, one for forward and one for backward
+        self.assertEqual(counters["compiled_autograd"]["captures"], 2)
+        return super().tearDown()
+
+    @unittest.skip("Dynamo only supports FSDP with use_orig_params=True")
+    # pyre-ignore[56]
+    @given(execute_all_batches=st.booleans())
+    def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None:
+        super().test_pipelining_fsdp_pre_trace()
```
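
The test strategy here is plain subclass parameterization: the existing TrainPipelineSparseDist tests now build whatever `self.pipeline_class` points to (see the base-class change below), so the new subclass only swaps that attribute, resets Dynamo state, and asserts on the compiled-autograd capture counter in `tearDown`. A reduced sketch of the pattern, assuming torchrec (with this commit) is importable; the test class names below are illustrative, not the commit's code:

```python
import unittest

import torch
from torch._dynamo.utils import counters
from torchrec.distributed.train_pipeline import (
    TrainPipelineSparseDist,
    TrainPipelineSparseDistCompAutograd,
)


class PipelineTestSketch(unittest.TestCase):
    def setUp(self) -> None:
        # Base default; every test constructs self.pipeline_class(...).
        self.pipeline_class = TrainPipelineSparseDist

    def test_pipeline_class(self) -> None:
        self.assertTrue(issubclass(self.pipeline_class, TrainPipelineSparseDist))


class CompAutogradPipelineTestSketch(PipelineTestSketch):
    # Inherits every test above; only the pipeline class and Dynamo state change.
    def setUp(self) -> None:
        super().setUp()
        self.pipeline_class = TrainPipelineSparseDistCompAutograd
        torch._dynamo.reset()
        counters["compiled_autograd"].clear()  # make capture counts per-test
        torch.autograd.set_detect_anomaly(False)  # incompatible with compiled autograd


if __name__ == "__main__":
    unittest.main()
```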

torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -23,6 +23,7 @@
     TestEBCSharder,
     TestSparseNN,
 )
+from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
 from torchrec.distributed.types import ModuleSharder, ShardingEnv
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
 from torchrec.test_utils import get_free_port, init_distributed_single_host
@@ -59,6 +60,7 @@ def setUp(self) -> None:
         ]
 
         self.device = torch.device("cuda:0")
+        self.pipeline_class = TrainPipelineSparseDist
 
     def tearDown(self) -> None:
         super().tearDown()
```

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 89 additions & 1 deletion
```diff
@@ -8,13 +8,16 @@
 # pyre-strict
 
 import abc
+import contextlib
 import logging
 from collections import deque
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
     cast,
+    ContextManager,
     Deque,
     Dict,
     Generic,
@@ -27,6 +30,7 @@
 )
 
 import torch
+import torchrec.distributed.comm_ops
 from torch.autograd.profiler import record_function
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
@@ -59,7 +63,6 @@
 from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.pt2.utils import default_pipeline_input_transformer
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
-from torchrec.streamable import Multistreamable
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -1506,3 +1509,88 @@ def progress(
             return self.progress(dataloader_iter)
 
         return out
+
+
+class TrainPipelineSparseDistCompAutograd(TrainPipelineSparseDist[In, Out]):
+    """
+    This pipeline clones TrainPipelineSparseDist, but executes the progress
+    method within a compiled autograd context.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        execute_all_batches: bool = True,
+        apply_jit: bool = False,
+        context_type: Type[TrainPipelineContext] = TrainPipelineContext,
+        pipeline_preproc: bool = False,
+        custom_model_fwd: Optional[
+            Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
+        ] = None,
+    ) -> None:
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            execute_all_batches,
+            apply_jit,
+            context_type,
+            pipeline_preproc,
+            custom_model_fwd,
+        )
+
+        # Check this attribute on the model to inject configuration other than
+        # the default one.
+        self.compiled_autograd_options: Dict[str, Union[str, bool]] = getattr(
+            model,
+            "_compiled_autograd_options",
+            {
+                "backend": "inductor",
+                "dynamic": True,
+                "fullgraph": True,
+            },
+        )
+
+        torch._dynamo.config.optimize_ddp = "python_reducer"
+        torch._dynamo.config.inline_inbuilt_nn_modules = True
+        torch._dynamo.config.skip_fsdp_hooks = False
+        torch._functorch.config.recompute_views = True
+        torch._functorch.config.cse = False
+        torch._inductor.config.reorder_for_compute_comm_overlap = True
+        torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+            "sink_waits",
+            "raise_comms",
+            "reorder_compute_for_overlap",
+        ]
+        self.initialized = False
+
+    def get_compiled_autograd_ctx(
+        self,
+    ) -> ContextManager:
+        # This allows for pipelining: it avoids doing a sum on None
+        # when the pipeline is empty.
+        if not self.initialized:
+            self.initialized = True
+            return contextlib.nullcontext()
+
+        return torch._dynamo.compiled_autograd.enable(
+            # pyre-ignore
+            torch.compile(**self.compiled_autograd_options)
+        )
+
+    @contextmanager
+    def sync_collectives_ctx(self) -> Iterator[None]:
+        try:
+            if is_torchdynamo_compiling():
+                torchrec.distributed.comm_ops.set_use_sync_collectives(True)
+            yield
+        finally:
+            torchrec.distributed.comm_ops.set_use_sync_collectives(False)
+
+    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+
+        with self.get_compiled_autograd_ctx(), self.sync_collectives_ctx():
+            return super().progress(dataloader_iter)
```
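
A hedged usage sketch of the new class: it is driven exactly like TrainPipelineSparseDist, except that from the second `progress` call onward the step runs under `torch._dynamo.compiled_autograd.enable(torch.compile(...))`, with synchronous collectives while Dynamo is tracing. The `sharded_model`, `optimizer`, and `dataloader` below are assumed to be set up the same way as for TrainPipelineSparseDist; this is a schematic, not code from the commit:

```python
import torch
from torchrec.distributed.train_pipeline import TrainPipelineSparseDistCompAutograd

# Assumed to exist already: a DistributedModelParallel-wrapped model, its optimizer,
# and a dataloader of pipelineable batches (same setup as TrainPipelineSparseDist).
pipeline = TrainPipelineSparseDistCompAutograd(
    model=sharded_model,
    optimizer=optimizer,
    device=torch.device("cuda:0"),
    execute_all_batches=True,
)

dataloader_iter = iter(dataloader)
while True:
    try:
        # First call runs eagerly (contextlib.nullcontext) to prime the pipeline;
        # subsequent calls run inside the compiled autograd context.
        losses = pipeline.progress(dataloader_iter)
    except StopIteration:
        break
```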

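As a follow-up note on the constructor above: because the compile options are read off the model via `getattr`, a caller can override the compiled-autograd settings without subclassing the pipeline. A small sketch using the attribute name from the diff; the particular option values are only an example (the diff's default is `backend="inductor"`, `dynamic=True`, `fullgraph=True`), and the `nn.Linear` stand-in is hypothetical:

```python
from torch import nn

# Stand-in for the model that will later be handed to TrainPipelineSparseDistCompAutograd.
model: nn.Module = nn.Linear(8, 8)

# Picked up by TrainPipelineSparseDistCompAutograd.__init__ via
# getattr(model, "_compiled_autograd_options", <defaults>).
model._compiled_autograd_options = {
    "backend": "inductor",
    "dynamic": True,
    "fullgraph": False,  # e.g. tolerate graph breaks while debugging
}
```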