Skip to content

Commit 713137c

Browse files
babakxfacebook-github-bot
authored andcommitted
Implements a simple pipeline with PT2 compilation (#2108)
Summary: Pull Request resolved: #2108 This diff implements a simple train pipeline that uses PT2 to compile the model. Reviewed By: IvanKobzarev Differential Revision: D58110124
1 parent d7cee41 commit 713137c

File tree

3 files changed

+257
-2
lines changed

3 files changed

+257
-2
lines changed

torchrec/distributed/train_pipeline/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
StagedTrainPipeline, # noqa
1515
TrainPipeline, # noqa
1616
TrainPipelineBase, # noqa
17+
TrainPipelinePT2, # noqa
1718
TrainPipelineSparseDist, # noqa
1819
)
1920
from torchrec.distributed.train_pipeline.utils import ( # noqa

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,18 @@
1111
import unittest
1212
from dataclasses import dataclass
1313
from functools import partial
14-
from typing import cast, List, Optional, Tuple, Type
14+
from typing import cast, Dict, List, Optional, Tuple, Type
1515
from unittest.mock import MagicMock
1616

1717
import torch
1818
from hypothesis import given, settings, strategies as st, Verbosity
1919
from torch import nn, optim
20+
from torch._dynamo.testing import reduce_to_scalar_loss
2021
from torchrec.distributed import DistributedModelParallel
21-
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
22+
from torchrec.distributed.embedding_types import (
23+
EmbeddingComputeKernel,
24+
EmbeddingTableConfig,
25+
)
2226
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
2327
from torchrec.distributed.fp_embeddingbag import (
2428
FeatureProcessedEmbeddingBagCollectionSharder,
@@ -45,6 +49,7 @@
4549
PrefetchTrainPipelineSparseDist,
4650
StagedTrainPipeline,
4751
TrainPipelineBase,
52+
TrainPipelinePT2,
4853
TrainPipelineSemiSync,
4954
TrainPipelineSparseDist,
5055
)
@@ -63,10 +68,15 @@
6368
ShardingPlan,
6469
ShardingType,
6570
)
71+
from torchrec.fb.ads.modules.variable_length_embedding_arch import (
72+
VariableLengthEmbeddingArch,
73+
)
6674
from torchrec.modules.embedding_configs import DataType
6775

6876
from torchrec.optim.keyed import KeyedOptimizerWrapper
6977
from torchrec.optim.optimizers import in_backward_optimizer_filter
78+
from torchrec.pt2.utils import kjt_for_pt2_tracing
79+
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
7080
from torchrec.streamable import Pipelineable
7181

7282

@@ -93,6 +103,7 @@ def __init__(self) -> None:
93103
super().__init__()
94104
self.model = nn.Linear(10, 1)
95105
self.loss_fn = nn.BCEWithLogitsLoss()
106+
self._dummy_setting: str = "dummy"
96107

97108
def forward(
98109
self, model_input: ModelInputSimple
@@ -156,6 +167,153 @@ def test_equal_to_non_pipelined(self) -> None:
156167
self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))
157168

158169

170+
class TrainPipelinePT2Test(unittest.TestCase):
171+
def setUp(self) -> None:
172+
self.device = torch.device("cuda:0")
173+
torch.backends.cudnn.allow_tf32 = False
174+
torch.backends.cuda.matmul.allow_tf32 = False
175+
176+
def gen_etc_list(self, is_weighted: bool = False) -> List[EmbeddingTableConfig]:
177+
weighted_prefix = "weighted_" if is_weighted else ""
178+
179+
return [
180+
EmbeddingTableConfig(
181+
num_embeddings=256,
182+
embedding_dim=12,
183+
name=weighted_prefix + "table_0",
184+
feature_names=[weighted_prefix + "f0"],
185+
),
186+
EmbeddingTableConfig(
187+
num_embeddings=256,
188+
embedding_dim=12,
189+
name=weighted_prefix + "table_1",
190+
feature_names=[weighted_prefix + "f1"],
191+
),
192+
]
193+
194+
def gen_model(
195+
self, device: torch.device, etc_list: List[EmbeddingTableConfig]
196+
) -> nn.Module:
197+
class M_vle(torch.nn.Module):
198+
def __init__(self, vle: VariableLengthEmbeddingArch) -> None:
199+
super().__init__()
200+
self.model = vle
201+
202+
def forward(self, x: KeyedJaggedTensor) -> List[JaggedTensor]:
203+
d: Dict[str, torch.Tensor] = self.model(x)
204+
return list(d.values())
205+
206+
return M_vle(
207+
VariableLengthEmbeddingArch(
208+
device=device,
209+
tables=etc_list,
210+
)
211+
)
212+
213+
# pyre-fixme[56]: Pyre was not able to infer the type of argument
214+
@unittest.skipIf(
215+
not torch.cuda.is_available(),
216+
"Not enough GPUs, this test requires at least one GPU",
217+
)
218+
def test_equal_to_non_pipelined(self) -> None:
219+
model_cpu = TestModule()
220+
model_gpu = TestModule().to(self.device)
221+
model_gpu.load_state_dict(model_cpu.state_dict())
222+
optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
223+
optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
224+
data = [
225+
ModelInputSimple(
226+
float_features=torch.rand((10,)),
227+
label=torch.randint(2, (1,), dtype=torch.float32),
228+
)
229+
for b in range(5)
230+
]
231+
dataloader = iter(data)
232+
pipeline = TrainPipelinePT2(model_gpu, optimizer_gpu, self.device)
233+
234+
for batch in data[:-1]:
235+
optimizer_cpu.zero_grad()
236+
loss, pred = model_cpu(batch)
237+
loss.backward()
238+
optimizer_cpu.step()
239+
240+
pred_gpu = pipeline.progress(dataloader)
241+
242+
self.assertEqual(pred_gpu.device, self.device)
243+
self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))
244+
245+
# pyre-fixme[56]: Pyre was not able to infer the type of argument
246+
@unittest.skipIf(
247+
not torch.cuda.is_available(),
248+
"Not enough GPUs, this test requires at least one GPU",
249+
)
250+
def test_pre_compile_fn(self) -> None:
251+
model_cpu = TestModule()
252+
model_gpu = TestModule().to(self.device)
253+
model_gpu.load_state_dict(model_cpu.state_dict())
254+
optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
255+
data = [
256+
ModelInputSimple(
257+
float_features=torch.rand((10,)),
258+
label=torch.randint(2, (1,), dtype=torch.float32),
259+
)
260+
for b in range(5)
261+
]
262+
263+
def pre_compile_fn(model: nn.Module) -> None:
264+
model._dummy_setting = "dummy modified"
265+
266+
dataloader = iter(data)
267+
pipeline = TrainPipelinePT2(
268+
model_gpu, optimizer_gpu, self.device, pre_compile_fn=pre_compile_fn
269+
)
270+
self.assertEqual(model_gpu._dummy_setting, "dummy")
271+
for _ in range(len(data)):
272+
pipeline.progress(dataloader)
273+
self.assertEqual(model_gpu._dummy_setting, "dummy modified")
274+
275+
# pyre-fixme[56]: Pyre was not able to infer the type of argument
276+
@unittest.skipIf(
277+
not torch.cuda.is_available(),
278+
"Not enough GPUs, this test requires at least one GPU",
279+
)
280+
def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
281+
cpu = torch.device("cpu:0")
282+
etc_list = self.gen_etc_list()
283+
etc_list_weighted = self.gen_etc_list(is_weighted=True)
284+
285+
model_cpu = self.gen_model(cpu, etc_list)
286+
model_gpu = self.gen_model(self.device, etc_list).to(self.device)
287+
288+
_, local_model_inputs = ModelInput.generate(
289+
batch_size=10,
290+
world_size=4,
291+
num_float_features=8,
292+
tables=etc_list,
293+
weighted_tables=etc_list_weighted,
294+
variable_batch_size=False,
295+
)
296+
297+
model_gpu.load_state_dict(model_cpu.state_dict())
298+
optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
299+
optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
300+
301+
data = [i.idlist_features for i in local_model_inputs]
302+
dataloader = iter(data)
303+
pipeline = TrainPipelinePT2(
304+
model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
305+
)
306+
307+
for batch in data[:-1]:
308+
optimizer_cpu.zero_grad()
309+
loss, pred = model_cpu(batch)
310+
loss = reduce_to_scalar_loss(loss)
311+
pred_gpu = pipeline.progress(dataloader)
312+
313+
self.assertEqual(pred_gpu.device, self.device)
314+
torch.testing.assert_close(pred_gpu.cpu(), pred)
315+
316+
159317
class TrainPipelineSparseDistTest(TrainPipelineSparseDistTestBase):
160318
# pyre-fixme[56]: Pyre was not able to infer the type of argument
161319
@unittest.skipIf(

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import abc
1111
import logging
1212
from collections import deque
13+
from dataclasses import dataclass
1314
from typing import (
1415
Any,
1516
Callable,
@@ -70,6 +71,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
7071
pass
7172

7273

74+
@dataclass
75+
class TorchCompileConfig:
76+
"""
77+
Configs for torch.compile
78+
79+
fullgraph: bool = False, whether to compile the whole graph or not
80+
dynamic: bool = False, whether to use dynamic shapes or not
81+
backend: str = "inductor", which compiler to use (either inductor or aot)
82+
compile_on_iter: int = 3, compile the model on which iteration
83+
this is useful when we want to profile the first few iterations of training
84+
and then start using compiled model from iteration #3 onwards
85+
"""
86+
87+
fullgraph: bool = False
88+
dynamic: bool = False
89+
backend: str = "inductor"
90+
compile_on_iter: int = 3
91+
92+
7393
class TrainPipelineBase(TrainPipeline[In, Out]):
7494
"""
7595
This class runs training iterations using a pipeline of two stages, each as a CUDA
@@ -138,6 +158,82 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
138158
return output
139159

140160

161+
class TrainPipelinePT2(TrainPipelineBase[In, Out]):
162+
"""
163+
This pipeline uses PT2 compiler to compile the model and run it in a single stream (default)
164+
Args:
165+
model (torch.nn.Module): model to pipeline.
166+
optimizer (torch.optim.Optimizer): optimizer to use.
167+
device (torch.device): device where the model is run
168+
compile_configs (TorchCompileConfig): configs for compling the model
169+
pre_compile_fn (Callable[[torch.nn.Module], [None]]): Optional callable to execute before compiling the model
170+
post_compile_fn (Callable[[torch.nn.Module], [None]]): Optional callable to execute after compiling the model
171+
input_transformer (Callable[[In], In]): transforms the input before passing it to the model.
172+
This is useful when we want to transform KJT parameters for PT2 tracing
173+
"""
174+
175+
def __init__(
176+
self,
177+
model: torch.nn.Module,
178+
optimizer: torch.optim.Optimizer,
179+
device: torch.device,
180+
compile_configs: Optional[TorchCompileConfig] = None,
181+
pre_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None,
182+
post_compile_fn: Optional[Callable[[torch.nn.Module], None]] = None,
183+
input_transformer: Optional[Callable[[In], In]] = None,
184+
) -> None:
185+
self._model = model
186+
self._optimizer = optimizer
187+
self._device = device
188+
self._compile_configs: TorchCompileConfig = (
189+
compile_configs or TorchCompileConfig()
190+
)
191+
self._pre_compile_fn = pre_compile_fn
192+
self._post_compile_fn = post_compile_fn
193+
self._input_transformer = input_transformer
194+
self._iter = 0
195+
self._cur_batch: Optional[In] = None
196+
197+
def progress(self, dataloader_iter: Iterator[In]) -> Out:
198+
cc = self._compile_configs
199+
200+
with record_function("## load_batch ##"):
201+
cur_batch = next(dataloader_iter)
202+
203+
if self._input_transformer:
204+
cur_batch = self._input_transformer(cur_batch)
205+
206+
with record_function("## copy_batch_to_gpu ##"):
207+
self._cur_batch = _to_device(cur_batch, self._device, non_blocking=False)
208+
209+
if self._model.training:
210+
with record_function("## zero_grad ##"):
211+
self._optimizer.zero_grad()
212+
213+
with record_function("## forward ##"):
214+
if self._iter == cc.compile_on_iter:
215+
logger.info("Compiling model...")
216+
if self._pre_compile_fn:
217+
self._pre_compile_fn(self._model)
218+
self._model.compile(
219+
fullgraph=cc.fullgraph, dynamic=cc.dynamic, backend=cc.backend
220+
)
221+
if self._post_compile_fn:
222+
self._post_compile_fn(self._model)
223+
224+
(losses, output) = self._model(self._cur_batch)
225+
self._iter += 1
226+
227+
if self._model.training:
228+
with record_function("## backward ##"):
229+
torch.sum(losses).backward()
230+
231+
with record_function("## optimizer ##"):
232+
self._optimizer.step()
233+
234+
return output
235+
236+
141237
class TrainPipelineSparseDist(TrainPipeline[In, Out]):
142238
"""
143239
This pipeline overlaps device transfer, and `ShardedModule.input_dist()` with

0 commit comments

Comments
 (0)