Skip to content

add PT2 support for permute_multi_embedding #2381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions torchrec/distributed/tests/test_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
kjt_for_pt2_tracing,
register_fake_classes,
)
from torchrec.sparse.jagged_tensor import _kt_regroup_arguments

try:
# pyre-ignore
Expand Down Expand Up @@ -842,6 +843,33 @@ def test_permute_pooled_embs_split(self) -> None:
inp = torch.randn(12, 3)
_test_compile_fwd_bwd(m, inp, device)

@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs, this test requires at least two GPUs",
)
def test_permute_multi_embedding(self) -> None:
device = "cuda"
batch_size = 16

def func(values, permutes, in_shapes, out_shapes, out_lengths):
return torch.ops.fbgemm.permute_multi_embedding(
values, permutes, in_shapes, out_shapes, out_lengths.tolist()
)

keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[3, 4], [5, 6, 7], [8]]
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
values = [torch.randn(batch_size, sum(L), device=device) for L in lengths]
for embs in values:
torch._dynamo.mark_dynamic(embs, 0)
torch._dynamo.mark_dynamic(embs, 1)
permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
values[0], keys, lengths, groups
)
out_lengths = torch.tensor(out_lengths, device=device, dtype=torch.int32)
inp = (values, permutes, in_shapes, out_shapes, out_lengths)
_test_compile_fwd_bwd(func, inp, device, unpack_inp=True)

@unittest.skipIf(
torch.cuda.device_count() < 1,
"Not enough GPUs, this test requires at least one GPU",
Expand Down
Loading