Skip to content

Commit 579fe9f

Browse files
PaulZhang12facebook-github-bot
authored andcommitted
KJT permute torch export support (#1850)
Summary: Pull Request resolved: #1850 Support non-strict torch export for KJT permute method used by Sharded TorchRec Modules Reviewed By: IvanKobzarev Differential Revision: D55040353 fbshipit-source-id: 6f141a351e68611cc4268cc3346f3bb6bb370f34
1 parent e38728c commit 579fe9f

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

torchrec/distributed/tests/test_pt2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):
163163
kjt.keys(),
164164
(kjt._values, kjt._lengths, indices),
165165
test_aot_inductor=False,
166+
test_pt2_ir_export=True,
166167
)
167168

168169
def test_kjt_length_per_key(self) -> None:

torchrec/sparse/jagged_tensor.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1831,15 +1831,16 @@ def permute(
18311831
permuted_keys: List[str] = []
18321832
permuted_stride_per_key_per_rank: List[List[int]] = []
18331833
permuted_length_per_key: List[int] = []
1834-
permuted_lengths_sum = 0
1834+
permuted_length_per_key_sum = 0
18351835
for index in indices:
18361836
key = self.keys()[index]
18371837
permuted_keys.append(key)
18381838
permuted_stride_per_key_per_rank.append(
18391839
self.stride_per_key_per_rank()[index]
18401840
)
18411841
permuted_length_per_key.append(length_per_key[index])
1842-
permuted_lengths_sum += length_per_key[index]
1842+
if not is_non_strict_exporting():
1843+
permuted_length_per_key_sum += length_per_key[index]
18431844
if self.variable_stride_per_key():
18441845
length_per_key_tensor = _pin_and_move(
18451846
torch.tensor(self.length_per_key()), self.device()
@@ -1860,6 +1861,19 @@ def permute(
18601861
self.weights_or_none(),
18611862
)
18621863
else:
1864+
if not torch.jit.is_scripting() and is_non_strict_exporting():
1865+
permuted_length_per_key_sum = torch.sum(
1866+
torch._refs.tensor(
1867+
permuted_length_per_key,
1868+
dtype=torch.int32,
1869+
device=torch.device("cpu"),
1870+
pin_memory=False,
1871+
requires_grad=False,
1872+
)
1873+
).item()
1874+
1875+
torch._check(permuted_length_per_key_sum > 0)
1876+
18631877
(
18641878
permuted_lengths,
18651879
permuted_values,
@@ -1869,7 +1883,7 @@ def permute(
18691883
self.lengths().view(len(self._keys), -1),
18701884
self.values(),
18711885
self.weights_or_none(),
1872-
permuted_lengths_sum,
1886+
permuted_length_per_key_sum,
18731887
)
18741888
stride, optional_permuted_stride_per_key_per_rank = (
18751889
(None, permuted_stride_per_key_per_rank)

0 commit comments

Comments
 (0)