Commit dfadd39

TroyGarden authored and facebook-github-bot committed
add PT2 support for permute_multi_embedding (#2381)
Summary: Pull Request resolved: #2381. It looks like test_pt2 already passed; it is unclear why that test did not catch the PT2 incompatibility in the op. Graph breaks: P1557581728, https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmphgx6wM/rank_0/failures_and_restarts.html

Differential Revision: D62226292
1 parent 0bc1baa commit dfadd39
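For orientation, here is a minimal eager-mode sketch (illustrative, not part of the commit) of what the op does: it regroups per-key embedding slices from the input tensors into new output groups, using the _kt_regroup_arguments helper that the new test imports. It assumes fbgemm_gpu and torchrec are installed and a CUDA device is available, mirroring the test's setup, and that the helper returns plain Python lengths, as the test's torch.tensor(out_lengths, ...) call suggests.

# Eager-mode sketch (illustrative). Requires fbgemm_gpu and torchrec,
# plus a CUDA device, as in the test below.
import torch
from torchrec.sparse.jagged_tensor import _kt_regroup_arguments

device = "cuda"
batch_size = 16
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]   # keys per input tensor
lengths = [[3, 4], [5, 6, 7], [8]]                  # embedding width per key
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]  # output regrouping

values = [torch.randn(batch_size, sum(ls), device=device) for ls in lengths]
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
    values[0], keys, lengths, groups
)
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values, permutes, in_shapes, out_shapes, out_lengths
)
for group, out in zip(groups, outputs):
    # e.g. ["f1", "f3"] -> width 3 + 5 = 8, so shape (16, 8)
    print(group, tuple(out.shape))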

File tree

1 file changed: +28 -0 lines changed

torchrec/distributed/tests/test_pt2.py

Lines changed: 28 additions & 0 deletions
@@ -37,6 +37,7 @@
     kjt_for_pt2_tracing,
     register_fake_classes,
 )
+from torchrec.sparse.jagged_tensor import _kt_regroup_arguments
 
 try:
     # pyre-ignore
@@ -842,6 +843,33 @@ def test_permute_pooled_embs_split(self) -> None:
         inp = torch.randn(12, 3)
         _test_compile_fwd_bwd(m, inp, device)
 
+    @unittest.skipIf(
+        torch.cuda.device_count() <= 1,
+        "Not enough GPUs, this test requires at least two GPUs",
+    )
+    def test_permute_multi_embedding(self) -> None:
+        device = "cuda"
+        batch_size = 16
+
+        def func(values, permutes, in_shapes, out_shapes, out_lengths):
+            return torch.ops.fbgemm.permute_multi_embedding(
+                values, permutes, in_shapes, out_shapes, out_lengths.tolist()
+            )
+
+        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
+        lengths = [[3, 4], [5, 6, 7], [8]]
+        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
+        values = [torch.randn(batch_size, sum(L), device=device) for L in lengths]
+        for embs in values:
+            torch._dynamo.mark_dynamic(embs, 0)
+            torch._dynamo.mark_dynamic(embs, 1)
+        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
+            values[0], keys, lengths, groups
+        )
+        out_lengths = torch.tensor(out_lengths, device=device, dtype=torch.int32)
+        inp = (values, permutes, in_shapes, out_shapes, out_lengths)
+        _test_compile_fwd_bwd(func, inp, device, unpack_inp=True)
+
     @unittest.skipIf(
         torch.cuda.device_count() < 1,
         "Not enough GPUs, this test requires at least one GPU",
