@@ -19,6 +19,7 @@
 import torchrec
 import torchrec.pt2.checks
 from hypothesis import given, settings, strategies as st, Verbosity
+from torch._dynamo.testing import reduce_to_scalar_loss
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig
@@ -56,6 +57,7 @@
 from torchrec.pt2.utils import kjt_for_pt2_tracing
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

+
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@@ -139,17 +141,22 @@ def sharding_types(self, compute_device_type: str) -> List[str]:


 def _gen_model(test_model_type: _ModelType, mi: TestModelInfo) -> torch.nn.Module:
+    emb_dim: int = max(t.embedding_dim for t in mi.tables)
     if test_model_type == _ModelType.EBC:

         class M_ebc(torch.nn.Module):
             def __init__(self, ebc: EmbeddingBagCollection) -> None:
                 super().__init__()
                 self._ebc = ebc
+                self._linear = torch.nn.Linear(
+                    mi.num_float_features, emb_dim, device=mi.dense_device
+                )

-            def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
+            def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
                 kt: KeyedTensor = self._ebc(x)
                 v = kt.values()
-                return torch.sigmoid(torch.mean(v, dim=1))
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_ebc(
             EmbeddingBagCollection(
@@ -164,10 +171,15 @@ class M_fpebc(torch.nn.Module):
             def __init__(self, fpebc: FeatureProcessedEmbeddingBagCollection) -> None:
                 super().__init__()
                 self._fpebc = fpebc
+                self._linear = torch.nn.Linear(
+                    mi.num_float_features, emb_dim, device=mi.dense_device
+                )

-            def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
+            def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
                 kt: KeyedTensor = self._fpebc(x)
-                return kt.values()
+                v = kt.values()
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_fpebc(
             FeatureProcessedEmbeddingBagCollection(
@@ -187,9 +199,13 @@ def __init__(self, ec: EmbeddingCollection) -> None:
                 super().__init__()
                 self._ec = ec

-            def forward(self, x: KeyedJaggedTensor) -> List[JaggedTensor]:
+            def forward(
+                self, x: KeyedJaggedTensor, y: torch.Tensor
+            ) -> List[JaggedTensor]:
                 d: Dict[str, JaggedTensor] = self._ec(x)
-                return list(d.values())
+                v = torch.stack(d.values(), dim=0).sum(dim=0)
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_ec(
             EmbeddingCollection(
@@ -307,6 +323,7 @@ def _test_compile_rank_fn(
         # pyre-ignore
         sharders=sharders,
         device=device,
+        init_data_parallel=False,
     )

     if input_type == _InputType.VARIABLE_BATCH:
@@ -336,19 +353,27 @@ def _test_compile_rank_fn(
     local_model_input = local_model_inputs[0].to(device)

     kjt = local_model_input.idlist_features
+    ff = local_model_input.float_features
+    ff.requires_grad = True
     kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)

+    compile_input_ff = ff.clone().detach()
+
     torchrec.distributed.comm_ops.set_use_sync_collectives(True)
     torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)

     dmp.train(True)

-    eager_out = dmp(kjt_ft)
+    eager_out = dmp(kjt_ft, ff)
+
+    eager_loss = reduce_to_scalar_loss(eager_out)
+    eager_loss.backward()

     if torch_compile_backend is None:
         return

     ##### COMPILE #####
+    run_compile_backward: bool = torch_compile_backend in ["aot_eager", "inductor"]
     with dynamo_skipfiles_allow("torchrec"):
         torch._dynamo.config.capture_scalar_outputs = True
         torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -357,8 +382,14 @@ def _test_compile_rank_fn(
             backend=torch_compile_backend,
             fullgraph=True,
         )
-        compile_out = opt_fn(kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb))
-        torch.testing.assert_close(eager_out, compile_out)
+        compile_out = opt_fn(
+            kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb), compile_input_ff
+        )
+        torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
+        if run_compile_backward:
+            loss = reduce_to_scalar_loss(compile_out)
+            loss.backward()
+
         ##### COMPILE END #####

         ##### NUMERIC CHECK #####
@@ -368,9 +399,20 @@ def _test_compile_rank_fn(
             local_model_input = local_model_inputs[1 + i].to(device)
             kjt = local_model_input.idlist_features
             kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)
-            eager_out_i = dmp(kjt_ft)
-            compile_out_i = opt_fn(kjt_ft)
-            torch.testing.assert_close(eager_out_i, compile_out_i)
+            ff = local_model_input.float_features
+            ff.requires_grad = True
+            eager_out_i = dmp(kjt_ft, ff)
+            eager_loss_i = reduce_to_scalar_loss(eager_out_i)
+            eager_loss_i.backward()
+
+            compile_input_ff = ff.detach().clone()
+            compile_out_i = opt_fn(kjt_ft, ff)
+            torch.testing.assert_close(
+                eager_out_i, compile_out_i, atol=1e-3, rtol=1e-3
+            )
+            if run_compile_backward:
+                loss_i = torch._dynamo.testing.reduce_to_scalar_loss(compile_out_i)
+                loss_i.backward()
         ##### NUMERIC CHECK END #####


@@ -396,14 +438,14 @@ def disable_cuda_tf32(self) -> bool:
                     ShardingType.TABLE_WISE.value,
                     _InputType.SINGLE_BATCH,
                     _ConvertToVariableBatch.TRUE,
-                    "eager",
+                    "inductor",
                 ),
                 (
                     _ModelType.EBC,
                     ShardingType.COLUMN_WISE.value,
                     _InputType.SINGLE_BATCH,
                     _ConvertToVariableBatch.TRUE,
-                    "eager",
+                    "inductor",
                 ),
                 (
                     _ModelType.EBC,
@@ -412,6 +454,13 @@ def disable_cuda_tf32(self) -> bool:
                     _ConvertToVariableBatch.FALSE,
                     "eager",
                 ),
+                (
+                    _ModelType.EBC,
+                    ShardingType.COLUMN_WISE.value,
+                    _InputType.SINGLE_BATCH,
+                    _ConvertToVariableBatch.FALSE,
+                    "eager",
+                ),
             ]
         ),
     )
@@ -424,7 +473,7 @@ def test_compile_multiprocess(
             str,
             _InputType,
             _ConvertToVariableBatch,
-            str,
+            Optional[str],
         ],
     ) -> None:
         model_type, sharding_type, input_type, tovb, compile_backend = (
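
For orientation, the pattern this change adds to the test is: run the model eagerly, reduce its output to a scalar with reduce_to_scalar_loss and backprop, then run a torch.compile'd copy on a detached clone of the dense input and compare outputs within a tolerance before exercising the compiled backward. Below is a minimal, self-contained sketch of that pattern; the dense Linear model, shapes, and "aot_eager" backend here are placeholders, not the sharded TorchRec module or configuration the test actually uses.

import torch
from torch._dynamo.testing import reduce_to_scalar_loss

# Placeholder dense model standing in for the sharded module under test.
model = torch.nn.Linear(8, 4)

x = torch.randn(2, 8, requires_grad=True)
eager_out = model(x)
reduce_to_scalar_loss(eager_out).backward()  # exercise the eager backward path

# Compile the same module and feed a detached clone of the input,
# so the eager and compiled runs do not share autograd state.
opt_model = torch.compile(model, backend="aot_eager", fullgraph=True)
compile_x = x.detach().clone().requires_grad_(True)
compile_out = opt_model(compile_x)

torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
reduce_to_scalar_loss(compile_out).backward()  # exercise the compiled backward path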