Skip to content

Commit bc6957d

Browse files
joshuadengfacebook-github-bot
authored andcommitted
Add VBE inverse indices pass through support to KJT concat (#2366)
Summary: Pull Request resolved: #2366 join multiple inverse indices when concatting VBE kjts with inverse indices Reviewed By: iamzainhuda Differential Revision: D60782914 fbshipit-source-id: 21ff4f16db42e443c0660957ff665e696aa96e52
1 parent ff8e26e commit bc6957d

File tree

1 file changed

+19
-1
lines changed

1 file changed

+19
-1
lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1656,6 +1656,9 @@ def concat(
16561656
length_list: List[torch.Tensor] = []
16571657
stride_per_key_per_rank: List[List[int]] = []
16581658
stride: Optional[int] = None
1659+
inv_idx_keys: List[str] = []
1660+
inv_idx_tensors: List[torch.Tensor] = []
1661+
16591662
variable_stride_per_key_list = [
16601663
kjt.variable_stride_per_key() for kjt in kjt_list
16611664
]
@@ -1664,7 +1667,7 @@ def concat(
16641667
), "variable stride per key must be consistent for all KJTs"
16651668
variable_stride_per_key = all(variable_stride_per_key_list)
16661669

1667-
for kjt in kjt_list:
1670+
for i, kjt in enumerate(kjt_list):
16681671
curr_is_weighted: bool = kjt.weights_or_none() is not None
16691672
if is_weighted != curr_is_weighted:
16701673
raise ValueError("Can't merge weighted KJT with unweighted KJT")
@@ -1686,6 +1689,16 @@ def concat(
16861689
stride = kjt.stride()
16871690
else:
16881691
assert stride == kjt.stride(), "strides must be consistent for all KJTs"
1692+
if kjt.inverse_indices_or_none() is not None:
1693+
assert (
1694+
len(inv_idx_tensors) == i
1695+
), "inverse indices must be consistent for all KJTs"
1696+
inv_idx_keys += kjt.inverse_indices()[0]
1697+
inv_idx_tensors.append(kjt.inverse_indices()[1])
1698+
else:
1699+
assert (
1700+
len(inv_idx_tensors) == 0
1701+
), "inverse indices must be consistent for all KJTs"
16891702

16901703
return KeyedJaggedTensor(
16911704
keys=keys,
@@ -1697,6 +1710,11 @@ def concat(
16971710
stride_per_key_per_rank if variable_stride_per_key else None
16981711
),
16991712
length_per_key=length_per_key if has_length_per_key else None,
1713+
inverse_indices=(
1714+
(inv_idx_keys, torch.cat(inv_idx_tensors))
1715+
if len(inv_idx_tensors) == len(kjt_list)
1716+
else None
1717+
),
17001718
)
17011719

17021720
@staticmethod

0 commit comments

Comments
 (0)