12
12
import logging
13
13
import warnings
14
14
from collections import defaultdict , deque , OrderedDict
15
- from dataclasses import dataclass , field
16
15
from itertools import accumulate
17
16
from typing import Any , cast , Dict , List , MutableMapping , Optional , Tuple , Type , Union
18
17
72
71
EmbeddingCollection ,
73
72
EmbeddingCollectionInterface ,
74
73
)
75
- from torchrec .modules .utils import construct_jagged_tensors
74
+ from torchrec .modules .utils import construct_jagged_tensors , SequenceVBEContext
76
75
from torchrec .optim .fused import EmptyFusedOptimizer , FusedOptimizerModule
77
76
from torchrec .optim .keyed import CombinedOptimizer , KeyedOptimizer
78
- from torchrec .sparse .jagged_tensor import JaggedTensor , KeyedJaggedTensor
77
+ from torchrec .sparse .jagged_tensor import (
78
+ _pin_and_move ,
79
+ _to_offsets ,
80
+ JaggedTensor ,
81
+ KeyedJaggedTensor ,
82
+ )
79
83
80
84
try :
81
85
torch .ops .load_library ("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops" )
@@ -323,6 +327,30 @@ def create_sharding_infos_by_sharding_device_group(
323
327
return sharding_type_device_group_to_sharding_infos
324
328
325
329
330
def pad_vbe_kjt_lengths(features: KeyedJaggedTensor) -> KeyedJaggedTensor:
    """Pad a variable-batch KJT so every key has the same (maximum) stride.

    Each key's lengths are copied into a zero-initialized buffer of size
    ``max_stride * num_keys``; slots beyond a key's true stride remain 0,
    so the returned KJT has a uniform stride without touching the values.
    """
    strides = features.stride_per_key()
    target_stride = max(strides)
    lengths = features.lengths()
    padded_lengths = torch.zeros(
        target_stride * len(features.keys()),
        device=features.device(),
        dtype=lengths.dtype,
    )
    offset = 0
    for key_idx, key_stride in enumerate(strides):
        start = key_idx * target_stride
        padded_lengths[start : start + key_stride] = lengths[
            offset : offset + key_stride
        ]
        offset += key_stride

    return KeyedJaggedTensor(
        keys=features.keys(),
        values=features.values(),
        lengths=padded_lengths,
        stride=target_stride,
        length_per_key=features.length_per_key(),
        offset_per_key=features.offset_per_key(),
    )
352
+
353
+
326
354
class EmbeddingCollectionContext (Multistreamable ):
327
355
# Torch Dynamo does not support default_factory=list:
328
356
# https://github.com/pytorch/pytorch/issues/120108
@@ -333,11 +361,13 @@ def __init__(
333
361
sharding_contexts : Optional [List [SequenceShardingContext ]] = None ,
334
362
input_features : Optional [List [KeyedJaggedTensor ]] = None ,
335
363
reverse_indices : Optional [List [torch .Tensor ]] = None ,
364
+ seq_vbe_ctx : Optional [List [SequenceVBEContext ]] = None ,
336
365
) -> None :
337
366
super ().__init__ ()
338
367
self .sharding_contexts : List [SequenceShardingContext ] = sharding_contexts or []
339
368
self .input_features : List [KeyedJaggedTensor ] = input_features or []
340
369
self .reverse_indices : List [torch .Tensor ] = reverse_indices or []
370
+ self .seq_vbe_ctx : List [SequenceVBEContext ] = seq_vbe_ctx or []
341
371
342
372
def record_stream (self , stream : torch .cuda .streams .Stream ) -> None :
343
373
for ctx in self .sharding_contexts :
@@ -346,6 +376,8 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
346
376
f .record_stream (stream )
347
377
for r in self .reverse_indices :
348
378
r .record_stream (stream )
379
+ for s in self .seq_vbe_ctx :
380
+ s .record_stream (stream )
349
381
350
382
351
383
class EmbeddingCollectionAwaitable (LazyAwaitable [Dict [str , JaggedTensor ]]):
@@ -385,6 +417,9 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
385
417
if i >= len (self ._ctx .reverse_indices )
386
418
else self ._ctx .reverse_indices [i ]
387
419
)
420
+ seq_vbe_ctx = (
421
+ None if i >= len (self ._ctx .seq_vbe_ctx ) else self ._ctx .seq_vbe_ctx [i ]
422
+ )
388
423
jt_dict .update (
389
424
construct_jagged_tensors (
390
425
embeddings = w .wait (),
@@ -394,6 +429,7 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
394
429
features_to_permute_indices = self ._features_to_permute_indices ,
395
430
original_features = original_features ,
396
431
reverse_indices = reverse_indices ,
432
+ seq_vbe_ctx = seq_vbe_ctx ,
397
433
)
398
434
)
399
435
return jt_dict
@@ -506,6 +542,7 @@ def __init__(
506
542
module .embedding_configs (), table_name_to_parameter_sharding
507
543
)
508
544
self ._need_indices : bool = module .need_indices ()
545
+ self ._inverse_indices_permute_per_sharding : Optional [List [torch .Tensor ]] = None
509
546
510
547
for index , (sharding , lookup ) in enumerate (
511
548
zip (
@@ -847,11 +884,9 @@ def _create_output_dist(
847
884
848
885
def _dedup_indices (
849
886
self ,
850
- input_feature_splits : List [KeyedJaggedTensor ],
851
887
ctx : EmbeddingCollectionContext ,
888
+ input_feature_splits : List [KeyedJaggedTensor ],
852
889
) -> List [KeyedJaggedTensor ]:
853
- if not self ._use_index_dedup :
854
- return input_feature_splits
855
890
with record_function ("## dedup_ec_indices ##" ):
856
891
features_by_shards = []
857
892
for i , input_feature in enumerate (input_feature_splits ):
@@ -881,30 +916,107 @@ def _dedup_indices(
881
916
882
917
return features_by_shards
883
918
919
+ def _create_inverse_indices_permute_per_sharding (
920
+ self , inverse_indices : Tuple [List [str ], torch .Tensor ]
921
+ ) -> None :
922
+ if (
923
+ len (self ._embedding_names_per_sharding ) == 1
924
+ and self ._embedding_names_per_sharding [0 ] == inverse_indices [0 ]
925
+ ):
926
+ return
927
+ index_per_name = {name : i for i , name in enumerate (inverse_indices [0 ])}
928
+ permute_per_sharding = []
929
+ for emb_names in self ._embedding_names_per_sharding :
930
+ permute = _pin_and_move (
931
+ torch .tensor (
932
+ [index_per_name [name .split ("@" )[0 ]] for name in emb_names ]
933
+ ),
934
+ inverse_indices [1 ].device ,
935
+ )
936
+ permute_per_sharding .append (permute )
937
+ self ._inverse_indices_permute_per_sharding = permute_per_sharding
938
+
939
    def _compute_sequence_vbe_context(
        self,
        ctx: EmbeddingCollectionContext,
        unpadded_features: KeyedJaggedTensor,
    ) -> None:
        """Populate ``ctx.seq_vbe_ctx`` with one ``SequenceVBEContext`` per
        sharding, built from the original (unpadded) variable-batch KJT, so
        the output side can undo the length padding and re-expand rows via
        the KJT's inverse indices.
        """
        assert (
            unpadded_features.inverse_indices_or_none() is not None
        ), "inverse indices must be provided from KJT if using variable batch size per feature."

        inverse_indices = unpadded_features.inverse_indices()
        # inverse_indices is (feature_names, tensor); the tensor has one row
        # per feature name, so numel // num_names is the per-feature stride.
        stride = inverse_indices[1].numel() // len(inverse_indices[0])
        if self._inverse_indices_permute_per_sharding is None:
            self._create_inverse_indices_permute_per_sharding(inverse_indices)

        # Apply the same feature reordering used by input_dist so the splits
        # below line up with self._feature_splits.
        if self._features_order:
            unpadded_features = unpadded_features.permute(
                self._features_order,
                self._features_order_tensor,
            )

        features_by_sharding = unpadded_features.split(self._feature_splits)
        for i, feature in enumerate(features_by_sharding):
            if self._inverse_indices_permute_per_sharding is not None:
                # Reorder inverse-indices rows into this sharding's name order.
                permute = self._inverse_indices_permute_per_sharding[i]
                permuted_indices = torch.index_select(inverse_indices[1], 0, permute)
            else:
                # Single sharding already in matching order — no permute cached.
                permuted_indices = inverse_indices[1]
            stride_per_key = _pin_and_move(
                torch.tensor(feature.stride_per_key()), feature.device()
            )
            # Convert per-key (row, column) inverse indices into flat positions
            # in the concatenated lengths tensor: per-key row offsets + column.
            offsets = _to_offsets(stride_per_key)[:-1].unsqueeze(-1)
            recat = (permuted_indices + offsets).flatten().int()

            if self._need_indices:
                # permute_1D_sparse_data reorders lengths and values together,
                # so the re-expanded KJT can also carry its original indices.
                reindexed_lengths, reindexed_values, _ = (
                    torch.ops.fbgemm.permute_1D_sparse_data(
                        recat,
                        feature.lengths(),
                        feature.values(),
                    )
                )
            else:
                # Only lengths are needed; gather them directly.
                reindexed_lengths = torch.index_select(feature.lengths(), 0, recat)
                reindexed_values = None

            # One row per key, `stride` columns; per-key totals feed
            # length_per_key of the re-expanded output.
            reindexed_lengths = reindexed_lengths.view(-1, stride)
            reindexed_length_per_key = torch.sum(reindexed_lengths, dim=1).tolist()

            ctx.seq_vbe_ctx.append(
                SequenceVBEContext(
                    recat=recat,
                    unpadded_lengths=feature.lengths(),
                    reindexed_lengths=reindexed_lengths,
                    reindexed_length_per_key=reindexed_length_per_key,
                    reindexed_values=reindexed_values,
                )
            )
996
+
884
997
# pyre-ignore [14]
885
998
def input_dist (
886
999
self ,
887
1000
ctx : EmbeddingCollectionContext ,
888
1001
features : KeyedJaggedTensor ,
889
1002
) -> Awaitable [Awaitable [KJTList ]]:
890
- if features .variable_stride_per_key ():
891
- raise ValueError (
892
- "Variable batch per feature is not supported with EmbeddingCollection"
893
- )
894
1003
if self ._has_uninitialized_input_dist :
895
1004
self ._create_input_dist (input_feature_names = features .keys ())
896
1005
self ._has_uninitialized_input_dist = False
897
1006
with torch .no_grad ():
1007
+ unpadded_features = None
1008
+ if features .variable_stride_per_key ():
1009
+ unpadded_features = features
1010
+ features = pad_vbe_kjt_lengths (unpadded_features )
1011
+
898
1012
if self ._features_order :
899
1013
features = features .permute (
900
1014
self ._features_order ,
901
1015
self ._features_order_tensor ,
902
1016
)
903
-
904
- input_feature_splits = features .split (
905
- self ._feature_splits ,
906
- )
907
- features_by_shards = self ._dedup_indices (input_feature_splits , ctx )
1017
+ features_by_shards = features .split (self ._feature_splits )
1018
+ if self ._use_index_dedup :
1019
+ features_by_shards = self ._dedup_indices (ctx , features_by_shards )
908
1020
909
1021
awaitables = []
910
1022
for input_dist , features in zip (self ._input_dists , features_by_shards ):
@@ -919,6 +1031,8 @@ def input_dist(
919
1031
),
920
1032
)
921
1033
)
1034
+ if unpadded_features is not None :
1035
+ self ._compute_sequence_vbe_context (ctx , unpadded_features )
922
1036
return KJTListSplitsAwaitable (awaitables , ctx )
923
1037
924
1038
def compute (
0 commit comments