19 | 19 | from hypothesis import given, settings, strategies as st, Verbosity
20 | 20 | from torch import nn, optim
21 | 21 | from torch._dynamo.testing import reduce_to_scalar_loss
| 22 | +from torch._dynamo.utils import counters
22 | 23 | from torchrec.distributed import DistributedModelParallel
23 | 24 | from torchrec.distributed.embedding_types import EmbeddingComputeKernel
24 | 25 | from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

53 | 54 |     TrainPipelinePT2,
54 | 55 |     TrainPipelineSemiSync,
55 | 56 |     TrainPipelineSparseDist,
| 57 | +    TrainPipelineSparseDistCompAutograd,
56 | 58 | )
57 | 59 | from torchrec.distributed.train_pipeline.utils import (
58 | 60 |     DataLoadingThread,
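The counters object imported above comes from torch._dynamo.utils, where (in current PyTorch) it is a defaultdict mapping a category name to a collections.Counter, so unknown keys read as zero. A minimal sketch of how the compiled autograd test at the end of this diff relies on it; the training-step call is a placeholder, not part of this diff:

    from torch._dynamo.utils import counters

    counters["compiled_autograd"].clear()       # start each test from a clean slate
    # run_one_train_iteration()                 # placeholder for a forward/backward pass
    num_captures = counters["compiled_autograd"]["captures"]  # graphs captured so far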
@@ -393,7 +395,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
393 | 395 |             sharded_sparse_arch_pipeline.parameters(), lr=0.1
394 | 396 |         )
395 | 397 |
396 | | -        pipeline = TrainPipelineSparseDist(
| 398 | +        pipeline = self.pipeline_class(
397 | 399 |             sharded_sparse_arch_pipeline,
398 | 400 |             optimizer_pipeline,
399 | 401 |             self.device,
@@ -441,7 +443,7 @@ def _setup_pipeline(
441 | 443 |             dict(in_backward_optimizer_filter(distributed_model.named_parameters())),
442 | 444 |             lambda params: optim.SGD(params, lr=0.1),
443 | 445 |         )
444 | | -        return TrainPipelineSparseDist(
| 446 | +        return self.pipeline_class(
445 | 447 |             model=distributed_model,
446 | 448 |             optimizer=optimizer_distributed,
447 | 449 |             device=self.device,
@@ -508,7 +510,7 @@ def test_equal_to_non_pipelined(
508 | 510 |             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
509 | 511 |         )
510 | 512 |
511 | | -        pipeline = TrainPipelineSparseDist(
| 513 | +        pipeline = self.pipeline_class(
512 | 514 |             model=sharded_model_pipelined,
513 | 515 |             optimizer=optim_pipelined,
514 | 516 |             device=self.device,
@@ -621,7 +623,7 @@ def test_model_detach_during_train(self) -> None:
621 | 623 |             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
622 | 624 |         )
623 | 625 |
624 | | -        pipeline = TrainPipelineSparseDist(
| 626 | +        pipeline = self.pipeline_class(
625 | 627 |             model=sharded_model_pipelined,
626 | 628 |             optimizer=optim_pipelined,
627 | 629 |             device=self.device,
@@ -719,7 +721,7 @@ def test_model_detach_after_train(self) -> None:
719 | 721 |             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
720 | 722 |         )
721 | 723 |
722 | | -        pipeline = TrainPipelineSparseDist(
| 724 | +        pipeline = self.pipeline_class(
723 | 725 |             model=sharded_model_pipelined,
724 | 726 |             optimizer=optim_pipelined,
725 | 727 |             device=self.device,
@@ -862,7 +864,7 @@ def _check_output_equal(
862 | 864 |             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
863 | 865 |         )
864 | 866 |
865 | | -        pipeline = TrainPipelineSparseDist(
| 867 | +        pipeline = self.pipeline_class(
866 | 868 |             model=sharded_model_pipelined,
867 | 869 |             optimizer=optim_pipelined,
868 | 870 |             device=self.device,
@@ -1116,7 +1118,7 @@ def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None:
1116 | 1118 |             model, self.sharding_type, self.kernel_type, self.fused_params
1117 | 1119 |         )
1118 | 1120 |
1119 | | -        pipeline = TrainPipelineSparseDist(
| 1121 | +        pipeline = self.pipeline_class(
1120 | 1122 |             model=sharded_model_pipelined,
1121 | 1123 |             optimizer=optim_pipelined,
1122 | 1124 |             device=self.device,
@@ -1171,7 +1173,7 @@ def test_pipeline_invalid_preproc_trainable_params_recursive(
1171 | 1173 |             model, self.sharding_type, self.kernel_type, self.fused_params
1172 | 1174 |         )
1173 | 1175 |
1174 | | -        pipeline = TrainPipelineSparseDist(
| 1176 | +        pipeline = self.pipeline_class(
1175 | 1177 |             model=sharded_model_pipelined,
1176 | 1178 |             optimizer=optim_pipelined,
1177 | 1179 |             device=self.device,
@@ -1217,7 +1219,7 @@ def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None:
1217 | 1219 |             model, self.sharding_type, self.kernel_type, self.fused_params
1218 | 1220 |         )
1219 | 1221 |
1220 | | -        pipeline = TrainPipelineSparseDist(
| 1222 | +        pipeline = self.pipeline_class(
1221 | 1223 |             model=sharded_model_pipelined,
1222 | 1224 |             optimizer=optim_pipelined,
1223 | 1225 |             device=self.device,
@@ -1280,7 +1282,7 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None:
1280 | 1282 |             model, self.sharding_type, self.kernel_type, self.fused_params
1281 | 1283 |         )
1282 | 1284 |
1283 | | -        pipeline = TrainPipelineSparseDist(
| 1285 | +        pipeline = self.pipeline_class(
1284 | 1286 |             model=sharded_model_pipelined,
1285 | 1287 |             optimizer=optim_pipelined,
1286 | 1288 |             device=self.device,
@@ -2100,3 +2102,24 @@ def gpu_preproc(x: StageOut) -> StageOut:
2100 | 2102 |         self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
2101 | 2103 |         for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
2102 | 2104 |             torch.testing.assert_close(out, ref_out)
| 2105 | +
| 2106 | +
| 2107 | +class TrainPipelineSparseDistCompAutogradTest(TrainPipelineSparseDistTest):
| 2108 | +    def setUp(self) -> None:
| 2109 | +        super().setUp()
| 2110 | +        self.pipeline_class = TrainPipelineSparseDistCompAutograd
| 2111 | +        torch._dynamo.reset()
| 2112 | +        counters["compiled_autograd"].clear()
| 2113 | +        # Compiled Autograd does not work with anomaly mode
| 2114 | +        torch.autograd.set_detect_anomaly(False)
| 2115 | +
| 2116 | +    def tearDown(self) -> None:
| 2117 | +        # Every test performs exactly two captures: one for the forward pass and one for the backward pass
| 2118 | +        self.assertEqual(counters["compiled_autograd"]["captures"], 2)
| 2119 | +        return super().tearDown()
| 2120 | +
| 2121 | +    @unittest.skip("Dynamo only supports FSDP with use_orig_params=True")
| 2122 | +    # pyre-ignore[56]
| 2123 | +    @given(execute_all_batches=st.booleans())
| 2124 | +    def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None:
| 2125 | +        super().test_pipelining_fsdp_pre_trace()
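Every occurrence of TrainPipelineSparseDist( in the tests above becomes self.pipeline_class(, so this new subclass can rerun the entire suite against TrainPipelineSparseDistCompAutograd by overriding a single attribute. The base class's own setUp is not shown in this excerpt; a minimal sketch of the assumed hook, with an illustrative base-class name that is not taken from the diff:

    class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):  # illustrative bases
        def setUp(self) -> None:
            super().setUp()
            # Assumed default: subclasses reassign this attribute to test another pipeline.
            self.pipeline_class = TrainPipelineSparseDist

With that default in place, TrainPipelineSparseDistCompAutogradTest only has to swap the attribute, reset Dynamo, clear the compiled_autograd counters, and disable anomaly detection (which compiled autograd does not support), and every inherited test then exercises the compiled autograd pipeline.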