11
11
import unittest
12
12
from dataclasses import dataclass
13
13
from functools import partial
14
- from typing import cast , List , Optional , Tuple , Type
14
+ from typing import cast , Dict , List , Optional , Tuple , Type
15
15
from unittest .mock import MagicMock
16
16
17
17
import torch
18
18
from hypothesis import given , settings , strategies as st , Verbosity
19
19
from torch import nn , optim
20
+ from torch ._dynamo .testing import reduce_to_scalar_loss
20
21
from torchrec .distributed import DistributedModelParallel
21
- from torchrec .distributed .embedding_types import EmbeddingComputeKernel
22
+ from torchrec .distributed .embedding_types import (
23
+ EmbeddingComputeKernel ,
24
+ EmbeddingTableConfig ,
25
+ )
22
26
from torchrec .distributed .embeddingbag import EmbeddingBagCollectionSharder
23
27
from torchrec .distributed .fp_embeddingbag import (
24
28
FeatureProcessedEmbeddingBagCollectionSharder ,
45
49
PrefetchTrainPipelineSparseDist ,
46
50
StagedTrainPipeline ,
47
51
TrainPipelineBase ,
52
+ TrainPipelinePT2 ,
48
53
TrainPipelineSemiSync ,
49
54
TrainPipelineSparseDist ,
50
55
)
63
68
ShardingPlan ,
64
69
ShardingType ,
65
70
)
71
+ from torchrec .fb .ads .modules .variable_length_embedding_arch import (
72
+ VariableLengthEmbeddingArch ,
73
+ )
66
74
from torchrec .modules .embedding_configs import DataType
67
75
68
76
from torchrec .optim .keyed import KeyedOptimizerWrapper
69
77
from torchrec .optim .optimizers import in_backward_optimizer_filter
78
+ from torchrec .pt2 .utils import kjt_for_pt2_tracing
79
+ from torchrec .sparse .jagged_tensor import JaggedTensor , KeyedJaggedTensor
70
80
from torchrec .streamable import Pipelineable
71
81
72
82
@@ -93,6 +103,7 @@ def __init__(self) -> None:
93
103
super ().__init__ ()
94
104
self .model = nn .Linear (10 , 1 )
95
105
self .loss_fn = nn .BCEWithLogitsLoss ()
106
+ self ._dummy_setting : str = "dummy"
96
107
97
108
def forward (
98
109
self , model_input : ModelInputSimple
@@ -156,6 +167,153 @@ def test_equal_to_non_pipelined(self) -> None:
156
167
self .assertTrue (torch .isclose (pred_gpu .cpu (), pred ))
157
168
158
169
170
class TrainPipelinePT2Test(unittest.TestCase):
    """Tests for ``TrainPipelinePT2``.

    Each test trains a reference model eagerly on CPU and the same model
    through the PT2 (torch.compile) pipeline on GPU, then asserts the two
    produce matching predictions. All tests require at least one CUDA device.
    """

    def setUp(self) -> None:
        self.device = torch.device("cuda:0")
        # Disable TF32 so GPU matmul results are close enough to the CPU
        # reference for the isclose/assert_close parity checks below.
        torch.backends.cudnn.allow_tf32 = False
        torch.backends.cuda.matmul.allow_tf32 = False

    def gen_etc_list(self, is_weighted: bool = False) -> List[EmbeddingTableConfig]:
        """Return two small embedding-table configs (256 rows, dim 12).

        Args:
            is_weighted: when True, table/feature names get a ``weighted_``
                prefix so they can be used as weighted tables.
        """
        weighted_prefix = "weighted_" if is_weighted else ""

        return [
            EmbeddingTableConfig(
                num_embeddings=256,
                embedding_dim=12,
                name=f"{weighted_prefix}table_{i}",
                feature_names=[f"{weighted_prefix}f{i}"],
            )
            for i in range(2)
        ]

    def gen_model(
        self, device: torch.device, etc_list: List[EmbeddingTableConfig]
    ) -> nn.Module:
        """Wrap a VariableLengthEmbeddingArch so its dict output becomes a list.

        Returning ``list(d.values())`` gives the tests a stable, positional
        output structure to unpack and compare.
        """

        class M_vle(torch.nn.Module):
            def __init__(self, vle: VariableLengthEmbeddingArch) -> None:
                super().__init__()
                self.model = vle

            def forward(self, x: KeyedJaggedTensor) -> List[JaggedTensor]:
                d: Dict[str, torch.Tensor] = self.model(x)
                return list(d.values())

        return M_vle(
            VariableLengthEmbeddingArch(
                device=device,
                tables=etc_list,
            )
        )

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @unittest.skipIf(
        not torch.cuda.is_available(),
        "Not enough GPUs, this test requires at least one GPU",
    )
    def test_equal_to_non_pipelined(self) -> None:
        """PT2 pipeline on GPU must match eager CPU training step-for-step."""
        model_cpu = TestModule()
        model_gpu = TestModule().to(self.device)
        model_gpu.load_state_dict(model_cpu.state_dict())
        optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
        data = [
            ModelInputSimple(
                float_features=torch.rand((10,)),
                label=torch.randint(2, (1,), dtype=torch.float32),
            )
            for b in range(5)
        ]
        dataloader = iter(data)
        pipeline = TrainPipelinePT2(model_gpu, optimizer_gpu, self.device)

        # The pipeline consumes one batch per progress() call; compare against
        # the eagerly-trained CPU model on every batch but the last.
        for batch in data[:-1]:
            optimizer_cpu.zero_grad()
            loss, pred = model_cpu(batch)
            loss.backward()
            optimizer_cpu.step()

            pred_gpu = pipeline.progress(dataloader)

            self.assertEqual(pred_gpu.device, self.device)
            self.assertTrue(torch.isclose(pred_gpu.cpu(), pred))

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @unittest.skipIf(
        not torch.cuda.is_available(),
        "Not enough GPUs, this test requires at least one GPU",
    )
    def test_pre_compile_fn(self) -> None:
        """pre_compile_fn must run exactly once, before the first compile."""
        model_cpu = TestModule()
        model_gpu = TestModule().to(self.device)
        model_gpu.load_state_dict(model_cpu.state_dict())
        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)
        data = [
            ModelInputSimple(
                float_features=torch.rand((10,)),
                label=torch.randint(2, (1,), dtype=torch.float32),
            )
            for b in range(5)
        ]

        def pre_compile_fn(model: nn.Module) -> None:
            # Mutate the sentinel so the hook's execution is observable.
            model._dummy_setting = "dummy modified"

        dataloader = iter(data)
        pipeline = TrainPipelinePT2(
            model_gpu, optimizer_gpu, self.device, pre_compile_fn=pre_compile_fn
        )
        # Hook must not have fired at construction time...
        self.assertEqual(model_gpu._dummy_setting, "dummy")
        for _ in range(len(data)):
            pipeline.progress(dataloader)
        # ...but must have fired by the time training has run.
        self.assertEqual(model_gpu._dummy_setting, "dummy modified")

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @unittest.skipIf(
        not torch.cuda.is_available(),
        "Not enough GPUs, this test requires at least one GPU",
    )
    def test_equal_to_non_pipelined_with_input_transformer(self) -> None:
        """Parity check with a KJT input transformer (kjt_for_pt2_tracing)."""
        cpu = torch.device("cpu:0")
        etc_list = self.gen_etc_list()
        etc_list_weighted = self.gen_etc_list(is_weighted=True)

        model_cpu = self.gen_model(cpu, etc_list)
        model_gpu = self.gen_model(self.device, etc_list).to(self.device)

        _, local_model_inputs = ModelInput.generate(
            batch_size=10,
            world_size=4,
            num_float_features=8,
            tables=etc_list,
            weighted_tables=etc_list_weighted,
            variable_batch_size=False,
        )

        model_gpu.load_state_dict(model_cpu.state_dict())
        optimizer_cpu = optim.SGD(model_cpu.model.parameters(), lr=0.01)
        optimizer_gpu = optim.SGD(model_gpu.model.parameters(), lr=0.01)

        data = [i.idlist_features for i in local_model_inputs]
        dataloader = iter(data)
        pipeline = TrainPipelinePT2(
            model_gpu, optimizer_gpu, self.device, input_transformer=kjt_for_pt2_tracing
        )

        for batch in data[:-1]:
            optimizer_cpu.zero_grad()
            loss, pred = model_cpu(batch)
            # Collapse the (non-scalar) model output into a scalar loss so a
            # backward pass is well-defined, then actually train the CPU
            # reference. Without backward()/step() here the pipelined GPU
            # model (which trains inside progress()) would diverge from the
            # frozen CPU model after the first batch, defeating the parity
            # assertion — mirror test_equal_to_non_pipelined instead.
            loss = reduce_to_scalar_loss(loss)
            loss.backward()
            optimizer_cpu.step()

            pred_gpu = pipeline.progress(dataloader)

            self.assertEqual(pred_gpu.device, self.device)
            torch.testing.assert_close(pred_gpu.cpu(), pred)
159
317
class TrainPipelineSparseDistTest (TrainPipelineSparseDistTestBase ):
160
318
# pyre-fixme[56]: Pyre was not able to infer the type of argument
161
319
@unittest .skipIf (
# NOTE(review): "0 commit comments" was trailing residue from the web/diff extraction, not source code.