Skip to content

Commit fb0692c

Browse files
PaulZhang12 authored and facebook-github-bot committed
KJT split torch export support
Summary: Add torch.export support for KJT split.
Differential Revision: D53545161
1 parent b7aeef1 commit fb0692c

File tree

2 files changed

+78
-33
lines changed

2 files changed

+78
-33
lines changed

torchrec/distributed/tests/test_pt2.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,16 +137,18 @@ def _test_kjt_input_module(
137137

138138
def test_kjt_split(self) -> None:
139139
class M(torch.nn.Module):
140-
def forward(self, kjt: KeyedJaggedTensor, segments: List[int]):
141-
return kjt.split(segments)
140+
def forward(self, kjt: KeyedJaggedTensor):
141+
return kjt.split([1, 2, 1])
142142

143143
kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
144144
segments: List[int] = [1, 2, 1]
145145
self._test_kjt_input_module(
146146
M(),
147147
kjt.keys(),
148-
(kjt._values, kjt._lengths, segments),
148+
(kjt._values, kjt._lengths),
149149
test_aot_inductor=False,
150+
test_dynamo=False,
151+
test_pt2_ir_export=True,
150152
)
151153

152154
def test_kjt_permute(self) -> None:

torchrec/sparse/jagged_tensor.py

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,39 +1729,82 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
17291729
else:
17301730
split_length_per_key = _length_per_key[start:end]
17311731

1732-
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1733-
# Checks for dynamo dynamic shapes tracing
1734-
torch._check_is_size(start_offset)
1735-
torch._check_is_size(end_offset)
1736-
torch._check_is_size(end_offset - start_offset)
1732+
if is_non_strict_exporting():
1733+
sz = sum(split_length_per_key)
1734+
1735+
torch._check_is_size(sz)
17371736
torch._check(start_offset <= self._values.size(0))
1738-
torch._check(end_offset <= self._values.size(0))
1739-
torch._check(end_offset >= start_offset)
1737+
torch._check(sz <= self._values.size(0))
1738+
torch._check_is_size(start_offset)
17401739

1741-
split_list.append(
1742-
KeyedJaggedTensor(
1743-
keys=keys,
1744-
values=self._values[start_offset:end_offset],
1745-
weights=(
1746-
None
1747-
if self.weights_or_none() is None
1748-
else self.weights()[start_offset:end_offset]
1749-
),
1750-
lengths=self.lengths()[
1751-
self.lengths_offset_per_key()[
1752-
start
1753-
] : self.lengths_offset_per_key()[end]
1754-
],
1755-
offsets=None,
1756-
stride=stride,
1757-
stride_per_key_per_rank=stride_per_key_per_rank,
1758-
length_per_key=split_length_per_key,
1759-
offset_per_key=None,
1760-
index_per_key=None,
1761-
jt_dict=None,
1762-
inverse_indices=None,
1740+
# Why are below 3 needed?
1741+
torch._check(sz != 0)
1742+
torch._check(sz != -1)
1743+
torch._check(sz != 1)
1744+
1745+
torch._check(start_offset + sz <= self._values.size(0))
1746+
torch._check(start_offset + sz >= 0)
1747+
1748+
lengths_start = self.lengths_offset_per_key()[start]
1749+
lengths_sz = self.lengths_offset_per_key()[end] - lengths_start
1750+
1751+
_lengths = torch.narrow(
1752+
self.lengths(), 0, lengths_start, lengths_sz
1753+
)
1754+
split_list.append(
1755+
KeyedJaggedTensor(
1756+
keys=keys,
1757+
values=torch.narrow(self._values, 0, start_offset, sz),
1758+
weights=(
1759+
None
1760+
if self.weights_or_none() is None
1761+
else torch.narrow(self.weights(), 0, start_offset, sz)
1762+
),
1763+
lengths=_lengths,
1764+
offsets=None,
1765+
stride=stride,
1766+
stride_per_key_per_rank=stride_per_key_per_rank,
1767+
length_per_key=split_length_per_key,
1768+
offset_per_key=None,
1769+
index_per_key=None,
1770+
jt_dict=None,
1771+
inverse_indices=None,
1772+
)
1773+
)
1774+
else:
1775+
if not torch.jit.is_scripting() and is_torchdynamo_compiling():
1776+
# Checks for dynamo dynamic shapes tracing
1777+
torch._check_is_size(start_offset)
1778+
torch._check_is_size(end_offset)
1779+
torch._check_is_size(end_offset - start_offset)
1780+
torch._check(start_offset <= self._values.size(0))
1781+
torch._check(end_offset <= self._values.size(0))
1782+
torch._check(end_offset >= start_offset)
1783+
1784+
split_list.append(
1785+
KeyedJaggedTensor(
1786+
keys=keys,
1787+
values=self._values[start_offset:end_offset],
1788+
weights=(
1789+
None
1790+
if self.weights_or_none() is None
1791+
else self.weights()[start_offset:end_offset]
1792+
),
1793+
lengths=self.lengths()[
1794+
self.lengths_offset_per_key()[
1795+
start
1796+
] : self.lengths_offset_per_key()[end]
1797+
],
1798+
offsets=None,
1799+
stride=stride,
1800+
stride_per_key_per_rank=stride_per_key_per_rank,
1801+
length_per_key=split_length_per_key,
1802+
offset_per_key=None,
1803+
index_per_key=None,
1804+
jt_dict=None,
1805+
inverse_indices=None,
1806+
)
17631807
)
1764-
)
17651808
start = end
17661809
start_offset = end_offset
17671810
return split_list

0 commit comments

Comments
 (0)