8 | 8 | # pyre-strict
9 | 9 |
10 | 10 | import abc
| 11 | +import contextlib
11 | 12 | import logging
12 | 13 | from collections import deque
13 | 14 | from dataclasses import dataclass
14 | 15 | from typing import (
15 | 16 |     Any,
16 | 17 |     Callable,
17 | 18 |     cast,
| 19 | +    ContextManager,
18 | 20 |     Deque,
19 | 21 |     Dict,
20 | 22 |     Generic,

28 | 30 |
29 | 31 | import torch
30 | 32 | from torch.autograd.profiler import record_function
| 33 | +from torchrec.distributed.comm_ops import set_use_sync_collectives
31 | 34 | from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
32 | 35 | from torchrec.distributed.model_parallel import ShardedModule
33 | 36 | from torchrec.distributed.train_pipeline.utils import (
@@ -1506,3 +1509,63 @@ def progress(
1506 | 1509 | return self.progress(dataloader_iter)
1507 | 1510 |
1508 | 1511 | return out
| 1512 | +
| 1513 | +
| 1514 | +class TrainPipelineSparseDistCompAutograd(TrainPipelineSparseDist[In, Out]):
| 1515 | +    """
| 1516 | +    This pipeline clones TrainPipelineSparseDist, but executes the progress
| 1517 | +    method within a compiled autograd context.
| 1518 | +    """
| 1519 | +
| 1520 | +    def __init__(
| 1521 | +        self,
| 1522 | +        model: torch.nn.Module,
| 1523 | +        optimizer: torch.optim.Optimizer,
| 1524 | +        device: torch.device,
| 1525 | +        execute_all_batches: bool = True,
| 1526 | +        apply_jit: bool = False,
| 1527 | +        context_type: Type[TrainPipelineContext] = TrainPipelineContext,
| 1528 | +        pipeline_preproc: bool = False,
| 1529 | +        custom_model_fwd: Optional[
| 1530 | +            Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
| 1531 | +        ] = None,
| 1532 | +    ) -> None:
| 1533 | +        set_use_sync_collectives(True)
| 1534 | +        super().__init__(
| 1535 | +            model,
| 1536 | +            optimizer,
| 1537 | +            device,
| 1538 | +            execute_all_batches,
| 1539 | +            apply_jit,
| 1540 | +            context_type,
| 1541 | +            pipeline_preproc,
| 1542 | +            custom_model_fwd,
| 1543 | +        )
| 1544 | +
| 1545 | +    @staticmethod
| 1546 | +    def get_compiled_autograd_ctx(
| 1547 | +        model: torch.nn.Module,
| 1548 | +    ) -> ContextManager:
| 1549 | +        compiled_autograd = (
| 1550 | +            hasattr(model, "_compiled_autograd") and model._compiled_autograd
| 1551 | +        )
| 1552 | +
| 1553 | +        model._compiled_autograd_options = {
| 1554 | +            "backend": "inductor",
| 1555 | +            "dynamic": True,
| 1556 | +            "fullgraph": True,
| 1557 | +        }
| 1558 | +        torch._dynamo.config.optimize_ddp = "python_reducer"
| 1559 | +        return (
| 1560 | +            torch._dynamo.compiled_autograd.enable(
| 1561 | +                torch.compile(**model._compiled_autograd_options)
| 1562 | +            )
| 1563 | +            if compiled_autograd
| 1564 | +            else contextlib.nullcontext()
| 1565 | +        )
| 1566 | +
| 1567 | +    def progress(self, dataloader_iter: Iterator[In]) -> Out:
| 1568 | +        self._model._compiled_autograd = True
| 1569 | +
| 1570 | +        with self.get_compiled_autograd_ctx(self._model):
| 1571 | +            return super().progress(dataloader_iter)
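
Below is a minimal usage sketch (not part of this commit) of how the new pipeline could be driven. It behaves like TrainPipelineSparseDist, except that construction switches comm ops to synchronous collectives and progress() runs inside the compiled autograd context built by get_compiled_autograd_ctx(). The import path, build_sharded_model, and get_dataloader are assumptions/placeholders, not names from this diff.

# Usage sketch only (not part of this diff). `build_sharded_model` and
# `get_dataloader` are hypothetical placeholders, and the import path is an
# assumption based on the package this diff edits.
import torch

from torchrec.distributed.train_pipeline import TrainPipelineSparseDistCompAutograd

device = torch.device("cuda", torch.cuda.current_device())
model = build_sharded_model(device)  # hypothetical: a sharded (DistributedModelParallel) model
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)

# __init__ calls set_use_sync_collectives(True) before deferring to
# TrainPipelineSparseDist.__init__, so collectives run synchronously as
# compiled autograd expects.
pipeline = TrainPipelineSparseDistCompAutograd(
    model=model,
    optimizer=optimizer,
    device=device,
)

dataloader_iter = iter(get_dataloader())  # hypothetical dataloader factory
while True:
    try:
        # progress() tags the model with `_compiled_autograd = True`, then runs the
        # parent pipeline's progress() inside torch._dynamo.compiled_autograd.enable(...)
        # with an inductor-backed torch.compile as the compiler function.
        out = pipeline.progress(dataloader_iter)
    except StopIteration:
        break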