Skip to content

Commit e9fd2c9

Browse files
dstaay-fbfacebook-github-bot
authored andcommitted
Faster KJT init (#2369)
Summary: Pull Request resolved: #2369 Pull Request resolved: #2231 To improve inference, we want to make creating a KJT as cheap as possible, which means the init method is nothing more than a attribute setter. All other fields are calculated lazily. This is practicularly important wrt jit script and moving between compilation units. Differential Revision: D62312329
1 parent 77e8b52 commit e9fd2c9

File tree

2 files changed

+154
-48
lines changed

2 files changed

+154
-48
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 151 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,22 @@ def _get_weights_or_throw(weights: Optional[torch.Tensor]) -> torch.Tensor:
115115
return weights
116116

117117

118+
def _get_lengths_offset_per_key_or_throw(
119+
lengths_offset_per_key: Optional[List[int]],
120+
) -> List[int]:
121+
assert (
122+
lengths_offset_per_key is not None
123+
), "This (Keyed)JaggedTensor doesn't have lengths_offset_per_key."
124+
return lengths_offset_per_key
125+
126+
127+
def _get_stride_per_key_or_throw(stride_per_key: Optional[List[int]]) -> List[int]:
128+
assert (
129+
stride_per_key is not None
130+
), "This (Keyed)JaggedTensor doesn't have stride_per_key."
131+
return stride_per_key
132+
133+
118134
def _get_inverse_indices_or_throw(
119135
inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
120136
) -> Tuple[List[str], torch.Tensor]:
@@ -891,9 +907,9 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten
891907

892908

893909
def _assert_tensor_has_no_elements_or_has_integers(
894-
tensor: torch.Tensor, tensor_name: str
910+
tensor: Optional[torch.Tensor], tensor_name: str
895911
) -> None:
896-
if is_torchdynamo_compiling():
912+
if is_torchdynamo_compiling() or tensor is None:
897913
# Skipping the check tensor.numel() == 0 to not guard on pt2 symbolic shapes.
898914
# TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
899915
return
@@ -921,10 +937,13 @@ def _maybe_compute_stride_kjt(
921937
stride: Optional[int],
922938
lengths: Optional[torch.Tensor],
923939
offsets: Optional[torch.Tensor],
940+
stride_per_key_per_rank: Optional[List[List[int]]],
924941
) -> int:
925942
if stride is None:
926943
if len(keys) == 0:
927944
stride = 0
945+
elif stride_per_key_per_rank is not None and len(stride_per_key_per_rank) > 0:
946+
stride = max([sum(s) for s in stride_per_key_per_rank])
928947
elif offsets is not None and offsets.numel() > 0:
929948
stride = (offsets.numel() - 1) // len(keys)
930949
elif lengths is not None:
@@ -1467,6 +1486,50 @@ def _check_attributes(
14671486
return True
14681487

14691488

1489+
def _maybe_compute_lengths_offset_per_key(
1490+
lengths_offset_per_key: Optional[List[int]],
1491+
stride_per_key: Optional[List[int]],
1492+
stride: Optional[int],
1493+
keys: List[str],
1494+
) -> Optional[List[int]]:
1495+
if lengths_offset_per_key is not None:
1496+
return lengths_offset_per_key
1497+
elif stride_per_key is not None:
1498+
return _cumsum(stride_per_key)
1499+
elif stride is not None:
1500+
return _cumsum([stride] * len(keys))
1501+
else:
1502+
return None
1503+
1504+
1505+
def _maybe_compute_stride_per_key(
1506+
stride_per_key: Optional[List[int]],
1507+
stride_per_key_per_rank: Optional[List[List[int]]],
1508+
stride: Optional[int],
1509+
keys: List[str],
1510+
) -> Optional[List[int]]:
1511+
if stride_per_key is not None:
1512+
return stride_per_key
1513+
elif stride_per_key_per_rank is not None:
1514+
return [sum(s) for s in stride_per_key_per_rank]
1515+
elif stride is not None:
1516+
return [stride] * len(keys)
1517+
else:
1518+
return None
1519+
1520+
1521+
def _maybe_compute_variable_stride_per_key(
1522+
variable_stride_per_key: Optional[bool],
1523+
stride_per_key_per_rank: Optional[List[List[int]]],
1524+
) -> bool:
1525+
if variable_stride_per_key is not None:
1526+
return variable_stride_per_key
1527+
elif stride_per_key_per_rank is not None:
1528+
return True
1529+
else:
1530+
return False
1531+
1532+
14701533
class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
14711534
"""Represents an (optionally weighted) keyed jagged tensor.
14721535
@@ -1540,62 +1603,57 @@ def __init__(
15401603
stride: Optional[int] = None,
15411604
stride_per_key_per_rank: Optional[List[List[int]]] = None,
15421605
# Below exposed to ensure torch.script-able
1606+
stride_per_key: Optional[List[int]] = None,
15431607
length_per_key: Optional[List[int]] = None,
1608+
lengths_offset_per_key: Optional[List[int]] = None,
15441609
offset_per_key: Optional[List[int]] = None,
15451610
index_per_key: Optional[Dict[str, int]] = None,
15461611
jt_dict: Optional[Dict[str, JaggedTensor]] = None,
15471612
inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
15481613
) -> None:
1614+
"""
1615+
This is the constructor for KeyedJaggedTensor is jit.scriptable and PT2 compatible.
1616+
It is important only to assign attributes here or do input checks to support various
1617+
internal inference optimizations. By convention the attirbute is named same as input arg, just
1618+
with leading underscore
1619+
"""
15491620
self._keys: List[str] = keys
15501621
self._values: torch.Tensor = values
15511622
self._weights: Optional[torch.Tensor] = weights
1552-
if offsets is not None:
1553-
_assert_tensor_has_no_elements_or_has_integers(offsets, "offsets")
1554-
if lengths is not None:
1555-
_assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
15561623
self._lengths: Optional[torch.Tensor] = lengths
15571624
self._offsets: Optional[torch.Tensor] = offsets
1558-
1559-
self._stride_per_key_per_rank: List[List[int]] = []
1560-
self._stride_per_key: List[int] = []
1561-
self._variable_stride_per_key: bool = False
1562-
self._stride: int = -1
1563-
1564-
if stride_per_key_per_rank is not None:
1565-
self._stride_per_key_per_rank = stride_per_key_per_rank
1566-
self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]
1567-
self._variable_stride_per_key = True
1568-
if stride is not None:
1569-
self._stride = stride
1570-
else:
1571-
self._stride = (
1572-
max(self._stride_per_key) if len(self._stride_per_key) > 0 else 0
1573-
)
1574-
else:
1575-
stride = _maybe_compute_stride_kjt(keys, stride, lengths, offsets)
1576-
self._stride = stride
1577-
self._stride_per_key_per_rank = [[stride]] * len(self._keys)
1578-
self._stride_per_key = [sum(s) for s in self._stride_per_key_per_rank]
1579-
1580-
# lazy fields
1625+
self._stride: Optional[int] = stride
1626+
self._stride_per_key_per_rank: Optional[List[List[int]]] = (
1627+
stride_per_key_per_rank
1628+
)
1629+
self._stride_per_key: Optional[List[int]] = stride_per_key
15811630
self._length_per_key: Optional[List[int]] = length_per_key
15821631
self._offset_per_key: Optional[List[int]] = offset_per_key
1632+
self._lengths_offset_per_key: Optional[List[int]] = lengths_offset_per_key
15831633
self._index_per_key: Optional[Dict[str, int]] = index_per_key
15841634
self._jt_dict: Optional[Dict[str, JaggedTensor]] = jt_dict
15851635
self._inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = (
15861636
inverse_indices
15871637
)
1588-
self._lengths_offset_per_key: List[int] = []
15891638

1590-
self._init_pt2_checks()
1639+
# legacy attribute, for backward compatabilibity
1640+
self._variable_stride_per_key: Optional[bool] = None
1641+
1642+
# validation logic
1643+
if not torch.jit.is_scripting():
1644+
_assert_tensor_has_no_elements_or_has_integers(offsets, "offsets")
1645+
_assert_tensor_has_no_elements_or_has_integers(lengths, "lengths")
1646+
self._init_pt2_checks()
15911647

15921648
def _init_pt2_checks(self) -> None:
15931649
if torch.jit.is_scripting() or not is_torchdynamo_compiling():
15941650
return
1595-
1596-
pt2_checks_all_is_size(self._stride_per_key)
1597-
for s in self._stride_per_key_per_rank:
1598-
pt2_checks_all_is_size(s)
1651+
if self._stride_per_key is not None:
1652+
pt2_checks_all_is_size(self._stride_per_key)
1653+
if self._stride_per_key_per_rank is not None:
1654+
# pyre-ignore [16]
1655+
for s in self._stride_per_key_per_rank:
1656+
pt2_checks_all_is_size(s)
15991657

16001658
@staticmethod
16011659
def from_offsets_sync(
@@ -1863,16 +1921,34 @@ def weights_or_none(self) -> Optional[torch.Tensor]:
18631921
return self._weights
18641922

18651923
def stride(self) -> int:
1866-
return self._stride
1924+
stride = _maybe_compute_stride_kjt(
1925+
self._keys,
1926+
self._stride,
1927+
self._lengths,
1928+
self._offsets,
1929+
self._stride_per_key_per_rank,
1930+
)
1931+
self._stride = stride
1932+
return stride
18671933

18681934
def stride_per_key(self) -> List[int]:
1869-
return self._stride_per_key
1935+
stride_per_key = _maybe_compute_stride_per_key(
1936+
self._stride_per_key,
1937+
self._stride_per_key_per_rank,
1938+
self.stride(),
1939+
self._keys,
1940+
)
1941+
self._stride_per_key = stride_per_key
1942+
return _get_stride_per_key_or_throw(stride_per_key)
18701943

18711944
def stride_per_key_per_rank(self) -> List[List[int]]:
1872-
return self._stride_per_key_per_rank
1945+
stride_per_key_per_rank = self._stride_per_key_per_rank
1946+
return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
18731947

18741948
def variable_stride_per_key(self) -> bool:
1875-
return self._variable_stride_per_key
1949+
if self._variable_stride_per_key is not None:
1950+
return self._variable_stride_per_key
1951+
return self._stride_per_key_per_rank is not None
18761952

18771953
def inverse_indices(self) -> Tuple[List[str], torch.Tensor]:
18781954
return _get_inverse_indices_or_throw(self._inverse_indices)
@@ -1925,9 +2001,20 @@ def offset_per_key_or_none(self) -> Optional[List[int]]:
19252001
return self._offset_per_key
19262002

19272003
def lengths_offset_per_key(self) -> List[int]:
1928-
if not self._lengths_offset_per_key:
1929-
self._lengths_offset_per_key = _cumsum(self.stride_per_key())
1930-
return self._lengths_offset_per_key
2004+
if self.variable_stride_per_key():
2005+
_lengths_offset_per_key = _maybe_compute_lengths_offset_per_key(
2006+
self._lengths_offset_per_key,
2007+
self.stride_per_key(),
2008+
None,
2009+
self._keys,
2010+
)
2011+
else:
2012+
_lengths_offset_per_key = _maybe_compute_lengths_offset_per_key(
2013+
self._lengths_offset_per_key, None, self.stride(), self._keys
2014+
)
2015+
2016+
self._lengths_offset_per_key = _lengths_offset_per_key
2017+
return _get_lengths_offset_per_key_or_throw(_lengths_offset_per_key)
19312018

19322019
def index_per_key(self) -> Dict[str, int]:
19332020
return self._key_indices()
@@ -1958,7 +2045,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
19582045
offsets=self._offsets,
19592046
stride=self._stride,
19602047
stride_per_key_per_rank=stride_per_key_per_rank,
2048+
stride_per_key=None,
19612049
length_per_key=self._length_per_key,
2050+
lengths_offset_per_key=None,
19622051
offset_per_key=self._offset_per_key,
19632052
index_per_key=self._index_per_key,
19642053
jt_dict=self._jt_dict,
@@ -1992,7 +2081,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
19922081
),
19932082
stride=self._stride,
19942083
stride_per_key_per_rank=stride_per_key_per_rank,
2084+
stride_per_key=None,
19952085
length_per_key=None,
2086+
lengths_offset_per_key=None,
19962087
offset_per_key=None,
19972088
index_per_key=None,
19982089
jt_dict=None,
@@ -2036,7 +2127,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
20362127
offsets=None,
20372128
stride=self._stride,
20382129
stride_per_key_per_rank=stride_per_key_per_rank,
2130+
stride_per_key=None,
20392131
length_per_key=split_length_per_key,
2132+
lengths_offset_per_key=None,
20402133
offset_per_key=None,
20412134
index_per_key=None,
20422135
jt_dict=None,
@@ -2070,7 +2163,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
20702163
offsets=None,
20712164
stride=self._stride,
20722165
stride_per_key_per_rank=stride_per_key_per_rank,
2166+
stride_per_key=None,
20732167
length_per_key=split_length_per_key,
2168+
lengths_offset_per_key=None,
20742169
offset_per_key=None,
20752170
index_per_key=None,
20762171
jt_dict=None,
@@ -2098,10 +2193,11 @@ def permute(
20982193
for index in indices:
20992194
key = self.keys()[index]
21002195
permuted_keys.append(key)
2101-
permuted_stride_per_key_per_rank.append(
2102-
self.stride_per_key_per_rank()[index]
2103-
)
21042196
permuted_length_per_key.append(length_per_key[index])
2197+
if self.variable_stride_per_key():
2198+
permuted_stride_per_key_per_rank.append(
2199+
self.stride_per_key_per_rank()[index]
2200+
)
21052201

21062202
permuted_length_per_key_sum = sum(permuted_length_per_key)
21072203
if not torch.jit.is_scripting() and is_non_strict_exporting():
@@ -2164,7 +2260,9 @@ def permute(
21642260
offsets=None,
21652261
stride=self._stride,
21662262
stride_per_key_per_rank=stride_per_key_per_rank,
2263+
stride_per_key=None,
21672264
length_per_key=permuted_length_per_key if len(permuted_keys) > 0 else None,
2265+
lengths_offset_per_key=None,
21682266
offset_per_key=None,
21692267
index_per_key=None,
21702268
jt_dict=None,
@@ -2184,7 +2282,9 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
21842282
offsets=None,
21852283
stride=self._stride,
21862284
stride_per_key_per_rank=stride_per_key_per_rank,
2285+
stride_per_key=None,
21872286
length_per_key=self.length_per_key(),
2287+
lengths_offset_per_key=None,
21882288
offset_per_key=None,
21892289
index_per_key=None,
21902290
jt_dict=None,
@@ -2304,8 +2404,10 @@ def to(
23042404
self._stride_per_key_per_rank if self.variable_stride_per_key() else None
23052405
)
23062406
length_per_key = self._length_per_key
2407+
lengths_offset_per_key = self._lengths_offset_per_key
23072408
offset_per_key = self._offset_per_key
23082409
index_per_key = self._index_per_key
2410+
stride_per_key = self._stride_per_key
23092411
jt_dict = self._jt_dict
23102412
inverse_indices = self._inverse_indices
23112413
if inverse_indices is not None:
@@ -2337,7 +2439,9 @@ def to(
23372439
),
23382440
stride=self._stride,
23392441
stride_per_key_per_rank=stride_per_key_per_rank,
2442+
stride_per_key=stride_per_key,
23402443
length_per_key=length_per_key,
2444+
lengths_offset_per_key=lengths_offset_per_key,
23412445
offset_per_key=offset_per_key,
23422446
index_per_key=index_per_key,
23432447
jt_dict=jt_dict,
@@ -2387,7 +2491,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
23872491
offsets=offsets.pin_memory() if offsets is not None else None,
23882492
stride=self._stride,
23892493
stride_per_key_per_rank=stride_per_key_per_rank,
2494+
stride_per_key=self._stride_per_key,
23902495
length_per_key=self._length_per_key,
2496+
lengths_offset_per_key=self._lengths_offset_per_key,
23912497
offset_per_key=self._offset_per_key,
23922498
index_per_key=self._index_per_key,
23932499
jt_dict=None,

torchrec/sparse/tests/test_jagged_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2121,13 +2121,13 @@ def forward(
21212121
lengths=input.lengths(),
21222122
offsets=input.offsets(),
21232123
)
2124-
return output, output._stride
2124+
return output, output.stride()
21252125

21262126
# Case 3: KeyedJaggedTensor is used as both an input and an output of the root module.
21272127
m = ModuleUseKeyedJaggedTensorAsInputAndOutput()
21282128
gm = symbolic_trace(m)
21292129
FileCheck().check("KeyedJaggedTensor").check("keys()").check("values()").check(
2130-
"._stride"
2130+
"stride"
21312131
).run(gm.code)
21322132
input = KeyedJaggedTensor.from_offsets_sync(
21332133
values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
@@ -2185,7 +2185,7 @@ def forward(
21852185
lengths: torch.Tensor,
21862186
) -> Tuple[KeyedJaggedTensor, int]:
21872187
output = KeyedJaggedTensor(keys, values, weights, lengths)
2188-
return output, output._stride
2188+
return output, output.stride()
21892189

21902190
# Case 1: KeyedJaggedTensor is only used as an output of the root module.
21912191
m = ModuleUseKeyedJaggedTensorAsOutput()

0 commit comments

Comments
 (0)