@@ -1656,6 +1656,9 @@ def concat(
1656
1656
length_list : List [torch .Tensor ] = []
1657
1657
stride_per_key_per_rank : List [List [int ]] = []
1658
1658
stride : Optional [int ] = None
1659
+ inv_idx_keys : List [str ] = []
1660
+ inv_idx_tensors : List [torch .Tensor ] = []
1661
+
1659
1662
variable_stride_per_key_list = [
1660
1663
kjt .variable_stride_per_key () for kjt in kjt_list
1661
1664
]
@@ -1664,7 +1667,7 @@ def concat(
1664
1667
), "variable stride per key must be consistent for all KJTs"
1665
1668
variable_stride_per_key = all (variable_stride_per_key_list )
1666
1669
1667
- for kjt in kjt_list :
1670
+ for i , kjt in enumerate ( kjt_list ) :
1668
1671
curr_is_weighted : bool = kjt .weights_or_none () is not None
1669
1672
if is_weighted != curr_is_weighted :
1670
1673
raise ValueError ("Can't merge weighted KJT with unweighted KJT" )
@@ -1686,6 +1689,16 @@ def concat(
1686
1689
stride = kjt .stride ()
1687
1690
else :
1688
1691
assert stride == kjt .stride (), "strides must be consistent for all KJTs"
1692
+ if kjt .inverse_indices_or_none () is not None :
1693
+ assert (
1694
+ len (inv_idx_tensors ) == i
1695
+ ), "inverse indices must be consistent for all KJTs"
1696
+ inv_idx_keys += kjt .inverse_indices ()[0 ]
1697
+ inv_idx_tensors .append (kjt .inverse_indices ()[1 ])
1698
+ else :
1699
+ assert (
1700
+ len (inv_idx_tensors ) == 0
1701
+ ), "inverse indices must be consistent for all KJTs"
1689
1702
1690
1703
return KeyedJaggedTensor (
1691
1704
keys = keys ,
@@ -1697,6 +1710,11 @@ def concat(
1697
1710
stride_per_key_per_rank if variable_stride_per_key else None
1698
1711
),
1699
1712
length_per_key = length_per_key if has_length_per_key else None ,
1713
+ inverse_indices = (
1714
+ (inv_idx_keys , torch .cat (inv_idx_tensors ))
1715
+ if len (inv_idx_tensors ) == len (kjt_list )
1716
+ else None
1717
+ ),
1700
1718
)
1701
1719
1702
1720
@staticmethod
0 commit comments