@@ -1729,39 +1729,82 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
1729
1729
else :
1730
1730
split_length_per_key = _length_per_key [start :end ]
1731
1731
1732
- if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
1733
- # Checks for dynamo dynamic shapes tracing
1734
- torch ._check_is_size (start_offset )
1735
- torch ._check_is_size (end_offset )
1736
- torch ._check_is_size (end_offset - start_offset )
1732
+ if is_non_strict_exporting ():
1733
+ sz = sum (split_length_per_key )
1734
+
1735
+ torch ._check_is_size (sz )
1737
1736
torch ._check (start_offset <= self ._values .size (0 ))
1738
- torch ._check (end_offset <= self ._values .size (0 ))
1739
- torch ._check ( end_offset >= start_offset )
1737
+ torch ._check (sz <= self ._values .size (0 ))
1738
+ torch ._check_is_size ( start_offset )
1740
1739
1741
- split_list .append (
1742
- KeyedJaggedTensor (
1743
- keys = keys ,
1744
- values = self ._values [start_offset :end_offset ],
1745
- weights = (
1746
- None
1747
- if self .weights_or_none () is None
1748
- else self .weights ()[start_offset :end_offset ]
1749
- ),
1750
- lengths = self .lengths ()[
1751
- self .lengths_offset_per_key ()[
1752
- start
1753
- ] : self .lengths_offset_per_key ()[end ]
1754
- ],
1755
- offsets = None ,
1756
- stride = stride ,
1757
- stride_per_key_per_rank = stride_per_key_per_rank ,
1758
- length_per_key = split_length_per_key ,
1759
- offset_per_key = None ,
1760
- index_per_key = None ,
1761
- jt_dict = None ,
1762
- inverse_indices = None ,
1740
+ # Why are below 3 needed?
1741
+ torch ._check (sz != 0 )
1742
+ torch ._check (sz != - 1 )
1743
+ torch ._check (sz != 1 )
1744
+
1745
+ torch ._check (start_offset + sz <= self ._values .size (0 ))
1746
+ torch ._check (start_offset + sz >= 0 )
1747
+
1748
+ lengths_start = self .lengths_offset_per_key ()[start ]
1749
+ lengths_sz = self .lengths_offset_per_key ()[end ] - lengths_start
1750
+
1751
+ _lengths = torch .narrow (
1752
+ self .lengths (), 0 , lengths_start , lengths_sz
1753
+ )
1754
+ split_list .append (
1755
+ KeyedJaggedTensor (
1756
+ keys = keys ,
1757
+ values = torch .narrow (self ._values , 0 , start_offset , sz ),
1758
+ weights = (
1759
+ None
1760
+ if self .weights_or_none () is None
1761
+ else torch .narrow (self .weights (), 0 , start_offset , sz )
1762
+ ),
1763
+ lengths = _lengths ,
1764
+ offsets = None ,
1765
+ stride = stride ,
1766
+ stride_per_key_per_rank = stride_per_key_per_rank ,
1767
+ length_per_key = split_length_per_key ,
1768
+ offset_per_key = None ,
1769
+ index_per_key = None ,
1770
+ jt_dict = None ,
1771
+ inverse_indices = None ,
1772
+ )
1773
+ )
1774
+ else :
1775
+ if not torch .jit .is_scripting () and is_torchdynamo_compiling ():
1776
+ # Checks for dynamo dynamic shapes tracing
1777
+ torch ._check_is_size (start_offset )
1778
+ torch ._check_is_size (end_offset )
1779
+ torch ._check_is_size (end_offset - start_offset )
1780
+ torch ._check (start_offset <= self ._values .size (0 ))
1781
+ torch ._check (end_offset <= self ._values .size (0 ))
1782
+ torch ._check (end_offset >= start_offset )
1783
+
1784
+ split_list .append (
1785
+ KeyedJaggedTensor (
1786
+ keys = keys ,
1787
+ values = self ._values [start_offset :end_offset ],
1788
+ weights = (
1789
+ None
1790
+ if self .weights_or_none () is None
1791
+ else self .weights ()[start_offset :end_offset ]
1792
+ ),
1793
+ lengths = self .lengths ()[
1794
+ self .lengths_offset_per_key ()[
1795
+ start
1796
+ ] : self .lengths_offset_per_key ()[end ]
1797
+ ],
1798
+ offsets = None ,
1799
+ stride = stride ,
1800
+ stride_per_key_per_rank = stride_per_key_per_rank ,
1801
+ length_per_key = split_length_per_key ,
1802
+ offset_per_key = None ,
1803
+ index_per_key = None ,
1804
+ jt_dict = None ,
1805
+ inverse_indices = None ,
1806
+ )
1763
1807
)
1764
- )
1765
1808
start = end
1766
1809
start_offset = end_offset
1767
1810
return split_list
0 commit comments