
Commit 4bc15d6

Ivan Kobzarev authored and facebook-github-bot committed
Enable inductor compilation for EBC-VB (#2125)
Summary: Pull Request resolved: #2125

Enables inductor compilation tests for the VB path and adds non-VB testing for CW sharding. (Non-VB inductor compilation needs more changes to land.)

Reviewed By: PaulZhang12

Differential Revision: D58672604
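For context, the pattern being enabled is torch.compile with the inductor backend over a torchrec EmbeddingBagCollection forward. The code below is a minimal, single-process sketch only, with a made-up table config and CPU device; the test in this diff instead shards the module with DistributedModelParallel across ranks and prepares the KeyedJaggedTensor input with kjt_for_pt2_tracing.

    # A minimal sketch, assuming a single process and a made-up table config;
    # the real test shards the module and traces the KJT input before compiling.
    import torch
    from torchrec import EmbeddingBagCollection, EmbeddingBagConfig, KeyedJaggedTensor

    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"]
            )
        ],
        device=torch.device("cpu"),
    )

    # Two samples for feature "f1", two ids each.
    kjt = KeyedJaggedTensor.from_lengths_sync(
        keys=["f1"],
        values=torch.tensor([1, 2, 3, 4]),
        lengths=torch.tensor([2, 2]),
    )

    eager_out = ebc(kjt).values()

    # Compile the same module with inductor; without fullgraph=True any graph
    # break simply falls back to eager, so this sketch stays runnable.
    opt_ebc = torch.compile(ebc, backend="inductor")
    compile_out = opt_ebc(kjt).values()
    torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)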
1 parent b0adab6 commit 4bc15d6

File tree

1 file changed (+64, -15 lines)


torchrec/distributed/tests/test_pt2_multiprocess.py

Lines changed: 64 additions & 15 deletions
@@ -19,6 +19,7 @@
 import torchrec
 import torchrec.pt2.checks
 from hypothesis import given, settings, strategies as st, Verbosity
+from torch._dynamo.testing import reduce_to_scalar_loss
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.fbgemm_qcomm_codec import QCommsConfig
@@ -56,6 +57,7 @@
 from torchrec.pt2.utils import kjt_for_pt2_tracing
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor

+
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@@ -139,17 +141,22 @@ def sharding_types(self, compute_device_type: str) -> List[str]:


 def _gen_model(test_model_type: _ModelType, mi: TestModelInfo) -> torch.nn.Module:
+    emb_dim: int = max(t.embedding_dim for t in mi.tables)
     if test_model_type == _ModelType.EBC:

         class M_ebc(torch.nn.Module):
             def __init__(self, ebc: EmbeddingBagCollection) -> None:
                 super().__init__()
                 self._ebc = ebc
+                self._linear = torch.nn.Linear(
+                    mi.num_float_features, emb_dim, device=mi.dense_device
+                )

-            def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
+            def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
                 kt: KeyedTensor = self._ebc(x)
                 v = kt.values()
-                return torch.sigmoid(torch.mean(v, dim=1))
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_ebc(
             EmbeddingBagCollection(
@@ -164,10 +171,15 @@ class M_fpebc(torch.nn.Module):
             def __init__(self, fpebc: FeatureProcessedEmbeddingBagCollection) -> None:
                 super().__init__()
                 self._fpebc = fpebc
+                self._linear = torch.nn.Linear(
+                    mi.num_float_features, emb_dim, device=mi.dense_device
+                )

-            def forward(self, x: KeyedJaggedTensor) -> torch.Tensor:
+            def forward(self, x: KeyedJaggedTensor, y: torch.Tensor) -> torch.Tensor:
                 kt: KeyedTensor = self._fpebc(x)
-                return kt.values()
+                v = kt.values()
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_fpebc(
             FeatureProcessedEmbeddingBagCollection(
@@ -187,9 +199,13 @@ def __init__(self, ec: EmbeddingCollection) -> None:
                 super().__init__()
                 self._ec = ec

-            def forward(self, x: KeyedJaggedTensor) -> List[JaggedTensor]:
+            def forward(
+                self, x: KeyedJaggedTensor, y: torch.Tensor
+            ) -> List[JaggedTensor]:
                 d: Dict[str, JaggedTensor] = self._ec(x)
-                return list(d.values())
+                v = torch.stack(d.values(), dim=0).sum(dim=0)
+                y = self._linear(y)
+                return torch.mul(torch.mean(v, dim=1), torch.mean(y, dim=1))

         return M_ec(
             EmbeddingCollection(
@@ -307,6 +323,7 @@ def _test_compile_rank_fn(
            # pyre-ignore
            sharders=sharders,
            device=device,
+           init_data_parallel=False,
        )

        if input_type == _InputType.VARIABLE_BATCH:
@@ -336,19 +353,27 @@ def _test_compile_rank_fn(
        local_model_input = local_model_inputs[0].to(device)

        kjt = local_model_input.idlist_features
+       ff = local_model_input.float_features
+       ff.requires_grad = True
        kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)

+       compile_input_ff = ff.clone().detach()
+
        torchrec.distributed.comm_ops.set_use_sync_collectives(True)
        torchrec.pt2.checks.set_use_torchdynamo_compiling_path(True)

        dmp.train(True)

-       eager_out = dmp(kjt_ft)
+       eager_out = dmp(kjt_ft, ff)
+
+       eager_loss = reduce_to_scalar_loss(eager_out)
+       eager_loss.backward()

        if torch_compile_backend is None:
            return

        ##### COMPILE #####
+       run_compile_backward: bool = torch_compile_backend in ["aot_eager", "inductor"]
        with dynamo_skipfiles_allow("torchrec"):
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.capture_dynamic_output_shape_ops = True
@@ -357,8 +382,14 @@ def _test_compile_rank_fn(
                backend=torch_compile_backend,
                fullgraph=True,
            )
-           compile_out = opt_fn(kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb))
-           torch.testing.assert_close(eager_out, compile_out)
+           compile_out = opt_fn(
+               kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb), compile_input_ff
+           )
+           torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)
+           if run_compile_backward:
+               loss = reduce_to_scalar_loss(compile_out)
+               loss.backward()
+
        ##### COMPILE END #####

        ##### NUMERIC CHECK #####
@@ -368,9 +399,20 @@ def _test_compile_rank_fn(
            local_model_input = local_model_inputs[1 + i].to(device)
            kjt = local_model_input.idlist_features
            kjt_ft = kjt_for_pt2_tracing(kjt, convert_to_vb=convert_to_vb)
-           eager_out_i = dmp(kjt_ft)
-           compile_out_i = opt_fn(kjt_ft)
-           torch.testing.assert_close(eager_out_i, compile_out_i)
+           ff = local_model_input.float_features
+           ff.requires_grad = True
+           eager_out_i = dmp(kjt_ft, ff)
+           eager_loss_i = reduce_to_scalar_loss(eager_out_i)
+           eager_loss_i.backward()
+
+           compile_input_ff = ff.detach().clone()
+           compile_out_i = opt_fn(kjt_ft, ff)
+           torch.testing.assert_close(
+               eager_out_i, compile_out_i, atol=1e-3, rtol=1e-3
+           )
+           if run_compile_backward:
+               loss_i = torch._dynamo.testing.reduce_to_scalar_loss(compile_out_i)
+               loss_i.backward()
        ##### NUMERIC CHECK END #####


@@ -396,14 +438,14 @@ def disable_cuda_tf32(self) -> bool:
                    ShardingType.TABLE_WISE.value,
                    _InputType.SINGLE_BATCH,
                    _ConvertToVariableBatch.TRUE,
-                   "eager",
+                   "inductor",
                ),
                (
                    _ModelType.EBC,
                    ShardingType.COLUMN_WISE.value,
                    _InputType.SINGLE_BATCH,
                    _ConvertToVariableBatch.TRUE,
-                   "eager",
+                   "inductor",
                ),
                (
                    _ModelType.EBC,
@@ -412,6 +454,13 @@ def disable_cuda_tf32(self) -> bool:
                    _ConvertToVariableBatch.FALSE,
                    "eager",
                ),
+               (
+                   _ModelType.EBC,
+                   ShardingType.COLUMN_WISE.value,
+                   _InputType.SINGLE_BATCH,
+                   _ConvertToVariableBatch.FALSE,
+                   "eager",
+               ),
            ]
        ),
    )
@@ -424,7 +473,7 @@ def test_compile_multiprocess(
            str,
            _InputType,
            _ConvertToVariableBatch,
-           str,
+           Optional[str],
        ],
    ) -> None:
        model_type, sharding_type, input_type, tovb, compile_backend = (
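The eager-vs-compiled parity check and the new backward pass in this diff follow the pattern sketched below. This is purely illustrative: a toy nn.Linear stands in for the DistributedModelParallel-wrapped model, and a random dense tensor stands in for the traced KJT plus float features.

    # Illustrative sketch of the parity-plus-backward pattern used in the test,
    # with a toy module standing in for the sharded model.
    import torch
    from torch._dynamo.testing import reduce_to_scalar_loss

    model = torch.nn.Linear(4, 4)
    ff = torch.randn(2, 4, requires_grad=True)
    compile_input_ff = ff.clone().detach().requires_grad_(True)

    # Eager forward and backward through a reduced scalar loss.
    eager_out = model(ff)
    reduce_to_scalar_loss(eager_out).backward()

    # Compiled forward, numeric parity check with the same tolerances as the test.
    opt_model = torch.compile(model, backend="inductor")
    compile_out = opt_model(compile_input_ff)
    torch.testing.assert_close(eager_out, compile_out, atol=1e-3, rtol=1e-3)

    # Backward through the compiled graph, mirroring the run_compile_backward branch.
    reduce_to_scalar_loss(compile_out).backward()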

0 commit comments
