@@ -1831,15 +1831,16 @@ def permute(
1831
1831
permuted_keys : List [str ] = []
1832
1832
permuted_stride_per_key_per_rank : List [List [int ]] = []
1833
1833
permuted_length_per_key : List [int ] = []
1834
- permuted_lengths_sum = 0
1834
+ permuted_length_per_key_sum = 0
1835
1835
for index in indices :
1836
1836
key = self .keys ()[index ]
1837
1837
permuted_keys .append (key )
1838
1838
permuted_stride_per_key_per_rank .append (
1839
1839
self .stride_per_key_per_rank ()[index ]
1840
1840
)
1841
1841
permuted_length_per_key .append (length_per_key [index ])
1842
- permuted_lengths_sum += length_per_key [index ]
1842
+ if not is_non_strict_exporting ():
1843
+ permuted_length_per_key_sum += length_per_key [index ]
1843
1844
if self .variable_stride_per_key ():
1844
1845
length_per_key_tensor = _pin_and_move (
1845
1846
torch .tensor (self .length_per_key ()), self .device ()
@@ -1860,6 +1861,19 @@ def permute(
1860
1861
self .weights_or_none (),
1861
1862
)
1862
1863
else :
1864
+ if not torch .jit .is_scripting () and is_non_strict_exporting ():
1865
+ permuted_length_per_key_sum = torch .sum (
1866
+ torch ._refs .tensor (
1867
+ permuted_length_per_key ,
1868
+ dtype = torch .int32 ,
1869
+ device = torch .device ("cpu" ),
1870
+ pin_memory = False ,
1871
+ requires_grad = False ,
1872
+ )
1873
+ ).item ()
1874
+
1875
+ torch ._check (permuted_length_per_key_sum > 0 )
1876
+
1863
1877
(
1864
1878
permuted_lengths ,
1865
1879
permuted_values ,
@@ -1869,7 +1883,7 @@ def permute(
1869
1883
self .lengths ().view (len (self ._keys ), - 1 ),
1870
1884
self .values (),
1871
1885
self .weights_or_none (),
1872
- permuted_lengths_sum ,
1886
+ permuted_length_per_key_sum ,
1873
1887
)
1874
1888
stride , optional_permuted_stride_per_key_per_rank = (
1875
1889
(None , permuted_stride_per_key_per_rank )
0 commit comments