@@ -115,6 +115,22 @@ def _get_weights_or_throw(weights: Optional[torch.Tensor]) -> torch.Tensor:
115
115
return weights
116
116
117
117
118
def _get_lengths_offset_per_key_or_throw(
    lengths_offset_per_key: Optional[List[int]],
) -> List[int]:
    """Narrow an optional ``lengths_offset_per_key`` list to a concrete list.

    Args:
        lengths_offset_per_key: the possibly-missing list to unwrap.

    Returns:
        The same list object that was passed in.

    Raises:
        AssertionError: if ``lengths_offset_per_key`` is ``None``.
    """
    # An assert (rather than an explicit raise) keeps the Optional -> List
    # narrowing and the exception type exactly as callers already see them.
    assert lengths_offset_per_key is not None, (
        "This (Keyed)JaggedTensor doesn't have lengths_offset_per_key."
    )
    return lengths_offset_per_key
125
+
126
+
127
def _get_stride_per_key_or_throw(stride_per_key: Optional[List[int]]) -> List[int]:
    """Narrow an optional ``stride_per_key`` list to a concrete list.

    Args:
        stride_per_key: the possibly-missing list to unwrap.

    Returns:
        The same list object that was passed in.

    Raises:
        AssertionError: if ``stride_per_key`` is ``None``.
    """
    assert stride_per_key is not None, (
        "This (Keyed)JaggedTensor doesn't have stride_per_key."
    )
    return stride_per_key
132
+
133
+
118
134
def _get_inverse_indices_or_throw (
119
135
inverse_indices : Optional [Tuple [List [str ], torch .Tensor ]],
120
136
) -> Tuple [List [str ], torch .Tensor ]:
@@ -891,9 +907,9 @@ def _jt_flatten_spec(t: JaggedTensor, spec: TreeSpec) -> List[Optional[torch.Ten
891
907
892
908
893
909
def _assert_tensor_has_no_elements_or_has_integers (
894
- tensor : torch .Tensor , tensor_name : str
910
+ tensor : Optional [ torch .Tensor ] , tensor_name : str
895
911
) -> None :
896
- if is_torchdynamo_compiling ():
912
+ if is_torchdynamo_compiling () or tensor is None :
897
913
# Skipping the check tensor.numel() == 0 to not guard on pt2 symbolic shapes.
898
914
# TODO(ivankobzarev): Use guard_size_oblivious to pass tensor.numel() == 0 once it is torch scriptable.
899
915
return
@@ -921,10 +937,13 @@ def _maybe_compute_stride_kjt(
921
937
stride : Optional [int ],
922
938
lengths : Optional [torch .Tensor ],
923
939
offsets : Optional [torch .Tensor ],
940
+ stride_per_key_per_rank : Optional [List [List [int ]]],
924
941
) -> int :
925
942
if stride is None :
926
943
if len (keys ) == 0 :
927
944
stride = 0
945
+ elif stride_per_key_per_rank is not None and len (stride_per_key_per_rank ) > 0 :
946
+ stride = max ([sum (s ) for s in stride_per_key_per_rank ])
928
947
elif offsets is not None and offsets .numel () > 0 :
929
948
stride = (offsets .numel () - 1 ) // len (keys )
930
949
elif lengths is not None :
@@ -1467,6 +1486,50 @@ def _check_attributes(
1467
1486
return True
1468
1487
1469
1488
1489
def _maybe_compute_lengths_offset_per_key(
    lengths_offset_per_key: Optional[List[int]],
    stride_per_key: Optional[List[int]],
    stride: Optional[int],
    keys: List[str],
) -> Optional[List[int]]:
    """Resolve ``lengths_offset_per_key`` from the first available source.

    Sources are tried in this order:
      1. ``lengths_offset_per_key`` itself, returned unchanged when present;
      2. ``_cumsum(stride_per_key)`` when per-key strides are given;
      3. ``_cumsum([stride] * len(keys))`` when only one stride is given.

    Returns:
        The resolved list, or ``None`` when none of the three inputs is set.
    """
    if lengths_offset_per_key is not None:
        # Already known: hand it back without recomputing.
        return lengths_offset_per_key
    if stride_per_key is not None:
        return _cumsum(stride_per_key)
    if stride is not None:
        # Every key shares the same stride, so repeat it once per key.
        return _cumsum([stride] * len(keys))
    return None
1503
+
1504
+
1505
def _maybe_compute_stride_per_key(
    stride_per_key: Optional[List[int]],
    stride_per_key_per_rank: Optional[List[List[int]]],
    stride: Optional[int],
    keys: List[str],
) -> Optional[List[int]]:
    """Resolve ``stride_per_key`` from the first available source.

    Sources are tried in this order:
      1. ``stride_per_key`` itself, returned unchanged when present;
      2. the sum of each inner list of ``stride_per_key_per_rank``;
      3. ``stride`` repeated once per entry in ``keys``.

    Returns:
        The resolved list, or ``None`` when none of the three inputs is set.
    """
    if stride_per_key is not None:
        return stride_per_key
    if stride_per_key_per_rank is not None:
        # Collapse each inner list into a single total.
        return [sum(per_rank) for per_rank in stride_per_key_per_rank]
    if stride is not None:
        return [stride] * len(keys)
    return None
1519
+
1520
+
1521
def _maybe_compute_variable_stride_per_key(
    variable_stride_per_key: Optional[bool],
    stride_per_key_per_rank: Optional[List[List[int]]],
) -> bool:
    """Resolve the ``variable_stride_per_key`` flag.

    An explicit ``variable_stride_per_key`` value wins. Otherwise the flag is
    ``True`` exactly when ``stride_per_key_per_rank`` is provided.
    """
    if variable_stride_per_key is not None:
        return variable_stride_per_key
    # No explicit flag: the presence of per-rank strides decides the answer.
    return stride_per_key_per_rank is not None
1531
+
1532
+
1470
1533
class KeyedJaggedTensor (Pipelineable , metaclass = JaggedTensorMeta ):
1471
1534
"""Represents an (optionally weighted) keyed jagged tensor.
1472
1535
@@ -1540,62 +1603,57 @@ def __init__(
1540
1603
stride : Optional [int ] = None ,
1541
1604
stride_per_key_per_rank : Optional [List [List [int ]]] = None ,
1542
1605
# Below exposed to ensure torch.script-able
1606
+ stride_per_key : Optional [List [int ]] = None ,
1543
1607
length_per_key : Optional [List [int ]] = None ,
1608
+ lengths_offset_per_key : Optional [List [int ]] = None ,
1544
1609
offset_per_key : Optional [List [int ]] = None ,
1545
1610
index_per_key : Optional [Dict [str , int ]] = None ,
1546
1611
jt_dict : Optional [Dict [str , JaggedTensor ]] = None ,
1547
1612
inverse_indices : Optional [Tuple [List [str ], torch .Tensor ]] = None ,
1548
1613
) -> None :
1614
+ """
1615
+ This constructor for KeyedJaggedTensor is jit.scriptable and PT2 compatible.
1616
+ It is important only to assign attributes here or do input checks to support various
1617
+ internal inference optimizations. By convention, each attribute is named the same as its input arg, just
1618
+ with leading underscore
1619
+ """
1549
1620
self ._keys : List [str ] = keys
1550
1621
self ._values : torch .Tensor = values
1551
1622
self ._weights : Optional [torch .Tensor ] = weights
1552
- if offsets is not None :
1553
- _assert_tensor_has_no_elements_or_has_integers (offsets , "offsets" )
1554
- if lengths is not None :
1555
- _assert_tensor_has_no_elements_or_has_integers (lengths , "lengths" )
1556
1623
self ._lengths : Optional [torch .Tensor ] = lengths
1557
1624
self ._offsets : Optional [torch .Tensor ] = offsets
1558
-
1559
- self ._stride_per_key_per_rank : List [List [int ]] = []
1560
- self ._stride_per_key : List [int ] = []
1561
- self ._variable_stride_per_key : bool = False
1562
- self ._stride : int = - 1
1563
-
1564
- if stride_per_key_per_rank is not None :
1565
- self ._stride_per_key_per_rank = stride_per_key_per_rank
1566
- self ._stride_per_key = [sum (s ) for s in self ._stride_per_key_per_rank ]
1567
- self ._variable_stride_per_key = True
1568
- if stride is not None :
1569
- self ._stride = stride
1570
- else :
1571
- self ._stride = (
1572
- max (self ._stride_per_key ) if len (self ._stride_per_key ) > 0 else 0
1573
- )
1574
- else :
1575
- stride = _maybe_compute_stride_kjt (keys , stride , lengths , offsets )
1576
- self ._stride = stride
1577
- self ._stride_per_key_per_rank = [[stride ]] * len (self ._keys )
1578
- self ._stride_per_key = [sum (s ) for s in self ._stride_per_key_per_rank ]
1579
-
1580
- # lazy fields
1625
+ self ._stride : Optional [int ] = stride
1626
+ self ._stride_per_key_per_rank : Optional [List [List [int ]]] = (
1627
+ stride_per_key_per_rank
1628
+ )
1629
+ self ._stride_per_key : Optional [List [int ]] = stride_per_key
1581
1630
self ._length_per_key : Optional [List [int ]] = length_per_key
1582
1631
self ._offset_per_key : Optional [List [int ]] = offset_per_key
1632
+ self ._lengths_offset_per_key : Optional [List [int ]] = lengths_offset_per_key
1583
1633
self ._index_per_key : Optional [Dict [str , int ]] = index_per_key
1584
1634
self ._jt_dict : Optional [Dict [str , JaggedTensor ]] = jt_dict
1585
1635
self ._inverse_indices : Optional [Tuple [List [str ], torch .Tensor ]] = (
1586
1636
inverse_indices
1587
1637
)
1588
- self ._lengths_offset_per_key : List [int ] = []
1589
1638
1590
- self ._init_pt2_checks ()
1639
+ # legacy attribute, for backward compatibility
1640
+ self ._variable_stride_per_key : Optional [bool ] = None
1641
+
1642
+ # validation logic
1643
+ if not torch .jit .is_scripting ():
1644
+ _assert_tensor_has_no_elements_or_has_integers (offsets , "offsets" )
1645
+ _assert_tensor_has_no_elements_or_has_integers (lengths , "lengths" )
1646
+ self ._init_pt2_checks ()
1591
1647
1592
1648
def _init_pt2_checks(self) -> None:
    """Run ``pt2_checks_all_is_size`` on the stride lists that are present.

    Does nothing when scripting, or when not compiling under torchdynamo.
    """
    if torch.jit.is_scripting() or not is_torchdynamo_compiling():
        return

    # Bind to locals so each None-check and the use that follows it refer to
    # the same object.
    per_key = self._stride_per_key
    if per_key is not None:
        pt2_checks_all_is_size(per_key)

    per_key_per_rank = self._stride_per_key_per_rank
    if per_key_per_rank is not None:
        for rank_strides in per_key_per_rank:
            pt2_checks_all_is_size(rank_strides)
1599
1657
1600
1658
@staticmethod
1601
1659
def from_offsets_sync (
@@ -1863,16 +1921,34 @@ def weights_or_none(self) -> Optional[torch.Tensor]:
1863
1921
return self ._weights
1864
1922
1865
1923
def stride(self) -> int:
    """Return the stride, resolving it on demand and caching the result.

    Resolution is delegated to ``_maybe_compute_stride_kjt``, which receives
    the stored keys, stride, lengths, offsets and per-key-per-rank strides.
    The resolved value is written back to ``self._stride``.
    """
    resolved = _maybe_compute_stride_kjt(
        self._keys,
        self._stride,
        self._lengths,
        self._offsets,
        self._stride_per_key_per_rank,
    )
    # Cache the resolved value on the instance.
    self._stride = resolved
    return resolved
1867
1933
1868
1934
def stride_per_key(self) -> List[int]:
    """Return the per-key stride list, resolving and caching it on demand.

    Resolution is delegated to ``_maybe_compute_stride_per_key``. The result
    is stored on ``self._stride_per_key`` before being unwrapped.

    Raises:
        AssertionError: via ``_get_stride_per_key_or_throw`` when no value
            could be resolved.
    """
    resolved = _maybe_compute_stride_per_key(
        self._stride_per_key,
        self._stride_per_key_per_rank,
        self.stride(),
        self._keys,
    )
    self._stride_per_key = resolved
    return _get_stride_per_key_or_throw(resolved)
1870
1943
1871
1944
def stride_per_key_per_rank(self) -> List[List[int]]:
    """Return the stored per-key-per-rank strides, or ``[]`` when unset."""
    per_rank = self._stride_per_key_per_rank
    if per_rank is not None:
        return per_rank
    # Unset: return a fresh empty list instead of None.
    no_strides: List[List[int]] = []
    return no_strides
1873
1947
1874
1948
def variable_stride_per_key(self) -> bool:
    """Report whether this instance uses variable stride per key.

    The stored ``_variable_stride_per_key`` value is returned when it is set.
    Otherwise the answer is whether ``_stride_per_key_per_rank`` is set.
    """
    legacy_flag = self._variable_stride_per_key
    if legacy_flag is not None:
        return legacy_flag
    return self._stride_per_key_per_rank is not None
1876
1952
1877
1953
def inverse_indices (self ) -> Tuple [List [str ], torch .Tensor ]:
1878
1954
return _get_inverse_indices_or_throw (self ._inverse_indices )
@@ -1925,9 +2001,20 @@ def offset_per_key_or_none(self) -> Optional[List[int]]:
1925
2001
return self ._offset_per_key
1926
2002
1927
2003
def lengths_offset_per_key(self) -> List[int]:
    """Return ``lengths_offset_per_key``, resolving and caching it on demand.

    With variable stride per key, the per-key strides are passed to
    ``_maybe_compute_lengths_offset_per_key``. Otherwise the single stride is
    passed. The result is stored on ``self._lengths_offset_per_key``.

    Raises:
        AssertionError: via ``_get_lengths_offset_per_key_or_throw`` when no
            value could be resolved.
    """
    # Exactly one of these two is filled in, depending on the stride mode.
    per_key_strides: Optional[List[int]] = None
    uniform_stride: Optional[int] = None
    if self.variable_stride_per_key():
        per_key_strides = self.stride_per_key()
    else:
        uniform_stride = self.stride()

    resolved = _maybe_compute_lengths_offset_per_key(
        self._lengths_offset_per_key,
        per_key_strides,
        uniform_stride,
        self._keys,
    )
    self._lengths_offset_per_key = resolved
    return _get_lengths_offset_per_key_or_throw(resolved)
1931
2018
1932
2019
def index_per_key (self ) -> Dict [str , int ]:
1933
2020
return self ._key_indices ()
@@ -1958,7 +2045,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
1958
2045
offsets = self ._offsets ,
1959
2046
stride = self ._stride ,
1960
2047
stride_per_key_per_rank = stride_per_key_per_rank ,
2048
+ stride_per_key = None ,
1961
2049
length_per_key = self ._length_per_key ,
2050
+ lengths_offset_per_key = None ,
1962
2051
offset_per_key = self ._offset_per_key ,
1963
2052
index_per_key = self ._index_per_key ,
1964
2053
jt_dict = self ._jt_dict ,
@@ -1992,7 +2081,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
1992
2081
),
1993
2082
stride = self ._stride ,
1994
2083
stride_per_key_per_rank = stride_per_key_per_rank ,
2084
+ stride_per_key = None ,
1995
2085
length_per_key = None ,
2086
+ lengths_offset_per_key = None ,
1996
2087
offset_per_key = None ,
1997
2088
index_per_key = None ,
1998
2089
jt_dict = None ,
@@ -2036,7 +2127,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2036
2127
offsets = None ,
2037
2128
stride = self ._stride ,
2038
2129
stride_per_key_per_rank = stride_per_key_per_rank ,
2130
+ stride_per_key = None ,
2039
2131
length_per_key = split_length_per_key ,
2132
+ lengths_offset_per_key = None ,
2040
2133
offset_per_key = None ,
2041
2134
index_per_key = None ,
2042
2135
jt_dict = None ,
@@ -2070,7 +2163,9 @@ def split(self, segments: List[int]) -> List["KeyedJaggedTensor"]:
2070
2163
offsets = None ,
2071
2164
stride = self ._stride ,
2072
2165
stride_per_key_per_rank = stride_per_key_per_rank ,
2166
+ stride_per_key = None ,
2073
2167
length_per_key = split_length_per_key ,
2168
+ lengths_offset_per_key = None ,
2074
2169
offset_per_key = None ,
2075
2170
index_per_key = None ,
2076
2171
jt_dict = None ,
@@ -2098,10 +2193,11 @@ def permute(
2098
2193
for index in indices :
2099
2194
key = self .keys ()[index ]
2100
2195
permuted_keys .append (key )
2101
- permuted_stride_per_key_per_rank .append (
2102
- self .stride_per_key_per_rank ()[index ]
2103
- )
2104
2196
permuted_length_per_key .append (length_per_key [index ])
2197
+ if self .variable_stride_per_key ():
2198
+ permuted_stride_per_key_per_rank .append (
2199
+ self .stride_per_key_per_rank ()[index ]
2200
+ )
2105
2201
2106
2202
permuted_length_per_key_sum = sum (permuted_length_per_key )
2107
2203
if not torch .jit .is_scripting () and is_non_strict_exporting ():
@@ -2164,7 +2260,9 @@ def permute(
2164
2260
offsets = None ,
2165
2261
stride = self ._stride ,
2166
2262
stride_per_key_per_rank = stride_per_key_per_rank ,
2263
+ stride_per_key = None ,
2167
2264
length_per_key = permuted_length_per_key if len (permuted_keys ) > 0 else None ,
2265
+ lengths_offset_per_key = None ,
2168
2266
offset_per_key = None ,
2169
2267
index_per_key = None ,
2170
2268
jt_dict = None ,
@@ -2184,7 +2282,9 @@ def flatten_lengths(self) -> "KeyedJaggedTensor":
2184
2282
offsets = None ,
2185
2283
stride = self ._stride ,
2186
2284
stride_per_key_per_rank = stride_per_key_per_rank ,
2285
+ stride_per_key = None ,
2187
2286
length_per_key = self .length_per_key (),
2287
+ lengths_offset_per_key = None ,
2188
2288
offset_per_key = None ,
2189
2289
index_per_key = None ,
2190
2290
jt_dict = None ,
@@ -2304,8 +2404,10 @@ def to(
2304
2404
self ._stride_per_key_per_rank if self .variable_stride_per_key () else None
2305
2405
)
2306
2406
length_per_key = self ._length_per_key
2407
+ lengths_offset_per_key = self ._lengths_offset_per_key
2307
2408
offset_per_key = self ._offset_per_key
2308
2409
index_per_key = self ._index_per_key
2410
+ stride_per_key = self ._stride_per_key
2309
2411
jt_dict = self ._jt_dict
2310
2412
inverse_indices = self ._inverse_indices
2311
2413
if inverse_indices is not None :
@@ -2337,7 +2439,9 @@ def to(
2337
2439
),
2338
2440
stride = self ._stride ,
2339
2441
stride_per_key_per_rank = stride_per_key_per_rank ,
2442
+ stride_per_key = stride_per_key ,
2340
2443
length_per_key = length_per_key ,
2444
+ lengths_offset_per_key = lengths_offset_per_key ,
2341
2445
offset_per_key = offset_per_key ,
2342
2446
index_per_key = index_per_key ,
2343
2447
jt_dict = jt_dict ,
@@ -2387,7 +2491,9 @@ def pin_memory(self) -> "KeyedJaggedTensor":
2387
2491
offsets = offsets .pin_memory () if offsets is not None else None ,
2388
2492
stride = self ._stride ,
2389
2493
stride_per_key_per_rank = stride_per_key_per_rank ,
2494
+ stride_per_key = self ._stride_per_key ,
2390
2495
length_per_key = self ._length_per_key ,
2496
+ lengths_offset_per_key = self ._lengths_offset_per_key ,
2391
2497
offset_per_key = self ._offset_per_key ,
2392
2498
index_per_key = self ._index_per_key ,
2393
2499
jt_dict = None ,
0 commit comments