 37 |  37 |     kjt_for_pt2_tracing,
 38 |  38 |     register_fake_classes,
 39 |  39 | )
    |  40 | +from torchrec.sparse.jagged_tensor import _kt_regroup_arguments
 40 |  41 |
 41 |  42 | try:
 42 |  43 |     # pyre-ignore

@@ -842,6 +843,33 @@ def test_permute_pooled_embs_split(self) -> None:

842 | 843 |         inp = torch.randn(12, 3)
843 | 844 |         _test_compile_fwd_bwd(m, inp, device)
844 | 845 |
    | 846 | +    @unittest.skipIf(
    | 847 | +        torch.cuda.device_count() <= 1,
    | 848 | +        "Not enough GPUs, this test requires at least two GPUs",
    | 849 | +    )
    | 850 | +    def test_permute_multi_embedding(self) -> None:
    | 851 | +        device = "cuda"
    | 852 | +        batch_size = 16
    | 853 | +
    | 854 | +        def func(values, permutes, in_shapes, out_shapes, out_lengths):
    | 855 | +            return torch.ops.fbgemm.permute_multi_embedding(
    | 856 | +                values, permutes, in_shapes, out_shapes, out_lengths.tolist()
    | 857 | +            )
    | 858 | +
    | 859 | +        keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
    | 860 | +        lengths = [[3, 4], [5, 6, 7], [8]]
    | 861 | +        groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
    | 862 | +        values = [torch.randn(batch_size, sum(L), device=device) for L in lengths]
    | 863 | +        for embs in values:
    | 864 | +            torch._dynamo.mark_dynamic(embs, 0)
    | 865 | +            torch._dynamo.mark_dynamic(embs, 1)
    | 866 | +        permutes, in_shapes, out_shapes, out_lengths = _kt_regroup_arguments(
    | 867 | +            values[0], keys, lengths, groups
    | 868 | +        )
    | 869 | +        out_lengths = torch.tensor(out_lengths, device=device, dtype=torch.int32)
    | 870 | +        inp = (values, permutes, in_shapes, out_shapes, out_lengths)
    | 871 | +        _test_compile_fwd_bwd(func, inp, device, unpack_inp=True)
    | 872 | +
845 | 873 |     @unittest.skipIf(
846 | 874 |         torch.cuda.device_count() < 1,
847 | 875 |         "Not enough GPUs, this test requires at least one GPU",
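
For context on what the new test exercises: torch.ops.fbgemm.permute_multi_embedding regroups the per-key column slices of the input embedding tensors into the requested output groups, using the permute/shape arguments produced by _kt_regroup_arguments. The sketch below is a minimal eager-mode reference for that regrouping, written only for illustration; the regroup_reference name and its slicing layout are assumptions and not part of the PR or the fbgemm kernel.

# Hypothetical eager-mode reference (illustration only, not the fused kernel):
# regroup per-key column slices of `values` according to `groups`.
import torch

def regroup_reference(values, keys, lengths, groups):
    # Map each key name to (source tensor, column offset, width).
    slices = {}
    for tensor, ks, ls in zip(values, keys, lengths):
        offset = 0
        for key, width in zip(ks, ls):
            slices[key] = (tensor, offset, width)
            offset += width
    # For every output group, gather its slices and concatenate along dim 1.
    outputs = []
    for group in groups:
        parts = [t[:, off : off + w] for t, off, w in (slices[k] for k in group)]
        outputs.append(torch.cat(parts, dim=1))
    return outputs

With the test's keys, lengths, and groups, this would produce four tensors of widths 3+5, 4, 6+3+8, and 3+7 per batch row, which is the grouping the compiled fused op is checked against for forward and backward.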