 # pyre-strict

 import copy
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
 from dataclasses import dataclass, field
 from typing import (
     Any,
+    Callable,
     cast,
     Dict,
     Iterator,
@@ ... @@
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
 from torch import nn, Tensor
+from torch.autograd.profiler import record_function
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_sharding import (
@@ ... @@
 )
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -378,6 +380,7 @@ class EmbeddingBagCollectionContext(Multistreamable):
     )
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None
     variable_batch_per_feature: bool = False
+    mean_pooling_callback: Optional[Callable[[KeyedTensor], KeyedTensor]] = None

     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         for ctx in self.sharding_contexts:
@@ -415,13 +418,22 @@ def __init__(
         self._embedding_bag_configs: List[EmbeddingBagConfig] = (
             module.embedding_bag_configs()
         )
-        self._table_names: List[str] = [
-            config.name for config in self._embedding_bag_configs
-        ]

-        self._table_name_to_config: Dict[str, EmbeddingBagConfig] = {
-            config.name: config for config in self._embedding_bag_configs
-        }
+        self._table_names: List[str] = []
+        self._pooling_type_to_rs_features: Dict[str, List[str]] = defaultdict(list)
+        self._table_name_to_config: Dict[str, EmbeddingBagConfig] = {}
+
+        for config in self._embedding_bag_configs:
+            self._table_names.append(config.name)
+            self._table_name_to_config[config.name] = config
+
+            if table_name_to_parameter_sharding[config.name].sharding_type in [
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.ROW_WISE.value,
+            ]:
+                self._pooling_type_to_rs_features[config.pooling.value].extend(
+                    config.feature_names
+                )

         self.module_sharding_plan: EmbeddingModuleShardingPlan = cast(
             EmbeddingModuleShardingPlan,
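
Row-wise and table-row-wise sharded tables combine their pooled values across ranks with a sum reduce-scatter, so MEAN pooling has to be re-applied after the output is assembled; the loop above only records which features that concerns. A minimal sketch of the resulting bookkeeping, using hypothetical pooling/sharding/feature tuples in place of the EmbeddingBagConfig and ParameterSharding objects:

from collections import defaultdict

# hypothetical (pooling, sharding_type, feature_names) triples standing in for the configs
tables = [
    ("MEAN", "row_wise", ["user_ids"]),
    ("SUM", "row_wise", ["ad_ids"]),
    ("MEAN", "table_wise", ["page_ids"]),  # not row-wise sharded, so not recorded
]

pooling_type_to_rs_features = defaultdict(list)
for pooling, sharding_type, feature_names in tables:
    if sharding_type in ("row_wise", "table_row_wise"):
        pooling_type_to_rs_features[pooling].extend(feature_names)

print(dict(pooling_type_to_rs_features))  # {'MEAN': ['user_ids'], 'SUM': ['ad_ids']}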
@@ -472,6 +484,16 @@ def __init__(
         self._uncombined_embedding_names: List[str] = []
         self._uncombined_embedding_dims: List[int] = []
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
+        # to support mean pooling callback hook
+        self._has_mean_pooling_callback: bool = (
+            True
+            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            else False
+        )
+        self._dim_per_key: Optional[torch.Tensor] = None
+        self._kjt_key_indices: Dict[str, int] = {}
+        self._kjt_inverse_order: Optional[torch.Tensor] = None
+        self._kt_key_ordering: Optional[torch.Tensor] = None
         # to support the FP16 hook
         self._create_output_dist()

@@ -720,6 +742,38 @@ def _create_input_dist(
             persistent=False,
         )

+    def _init_mean_pooling_callback(
+        self,
+        input_feature_names: List[str],
+        inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
+    ) -> None:
+        # account for shared features
+        feature_names: List[str] = [
+            feature_name
+            for sharding in self._sharding_type_to_sharding.values()
+            for feature_name in sharding.feature_names()
+        ]
+
+        for i, key in enumerate(feature_names):
+            if key not in self._kjt_key_indices:  # index of first occurrence
+                self._kjt_key_indices[key] = i
+
+        keyed_tensor_ordering = []
+        for key in self._embedding_names:
+            if "@" in key:
+                key = key.split("@")[0]
+            keyed_tensor_ordering.append(self._kjt_key_indices[key])
+        self._kt_key_ordering = torch.tensor(keyed_tensor_ordering, device=self._device)
+
+        if inverse_indices:
+            key_to_inverse_index = {
+                name: i for i, name in enumerate(inverse_indices[0])
+            }
+            self._kjt_inverse_order = torch.tensor(
+                [key_to_inverse_index[key] for key in feature_names],
+                device=self._device,
+            )
+
     def _create_lookups(
         self,
     ) -> None:
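
The helper above only precomputes index maps for the callback: the first KJT position of every (possibly shared) feature, the ordering that maps the output KeyedTensor's embedding names (which carry an "@<table>" suffix when a feature is shared across tables) back to those positions, and, for variable-batch inputs, the permutation from KJT key order into the inverse-indices key order. A small sketch of the first two maps, with assumed feature and embedding names:

import torch

# assumed inputs standing in for the sharded module's state
feature_names = ["f_a", "f_b", "f_a"]          # "f_a" is produced by two tables
embedding_names = ["f_a@t0", "f_b", "f_a@t1"]  # KeyedTensor keys, suffixed per table

kjt_key_indices = {}
for i, key in enumerate(feature_names):
    kjt_key_indices.setdefault(key, i)  # index of first occurrence

kt_key_ordering = torch.tensor(
    [kjt_key_indices[key.split("@")[0]] for key in embedding_names]
)
print(kjt_key_indices)  # {'f_a': 0, 'f_b': 1}
print(kt_key_ordering)  # tensor([0, 1, 0])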
@@ -737,6 +791,7 @@ def _create_output_dist(self) -> None:
             )
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+        self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
@@ -789,12 +844,31 @@ def input_dist(
             self._has_uninitialized_input_dist = False
             if ctx.variable_batch_per_feature:
                 self._create_inverse_indices_permute_indices(ctx.inverse_indices)
+            if self._has_mean_pooling_callback:
+                self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
         with torch.no_grad():
             if self._has_features_permute:
                 features = features.permute(
                     self._features_order,
                     self._features_order_tensor,
                 )
+            if self._has_mean_pooling_callback:
+                ctx.mean_pooling_callback = create_mean_pooling_callback(
+                    lengths=features.lengths(),
+                    stride=features.stride(),
+                    keys=features.keys(),
+                    pooling_type_to_rs_features=self._pooling_type_to_rs_features,
+                    stride_per_key=features.stride_per_key(),
+                    dim_per_key=self._dim_per_key,  # pyre-ignore[6]
+                    embedding_names=self._embedding_names,
+                    embedding_dims=self._embedding_dims,
+                    variable_batch_per_feature=ctx.variable_batch_per_feature,
+                    kjt_inverse_order=self._kjt_inverse_order,  # pyre-ignore[6]
+                    kjt_key_indices=self._kjt_key_indices,
+                    kt_key_ordering=self._kt_key_ordering,  # pyre-ignore[6]
+                    inverse_indices=ctx.inverse_indices,
+                )
+
             features_by_shards = features.split(
                 self._feature_splits,
             )
@@ -840,7 +914,7 @@ def output_dist(
             assert (
                 ctx.inverse_indices is not None
             ), "inverse indices must be provided from KJT if using variable batch size per feature."
-            return VariableBatchEmbeddingBagCollectionAwaitable(
+            awaitable = VariableBatchEmbeddingBagCollectionAwaitable(
                 awaitables=awaitables,
                 inverse_indices=ctx.inverse_indices,
                 inverse_indices_permute_indices=self._inverse_indices_permute_indices,
@@ -851,12 +925,18 @@ def output_dist(
                 permute_op=self._permute_op,
             )
         else:
-            return EmbeddingBagCollectionAwaitable(
+            awaitable = EmbeddingBagCollectionAwaitable(
                 awaitables=awaitables,
                 embedding_dims=self._embedding_dims,
                 embedding_names=self._embedding_names,
             )

+        # register callback if there are features that need mean pooling
+        if self._has_mean_pooling_callback:
+            awaitable.callbacks.append(ctx.mean_pooling_callback)
+
+        return awaitable
+
     def compute_and_output_dist(
         self, ctx: EmbeddingBagCollectionContext, input: KJTList
     ) -> LazyAwaitable[KeyedTensor]:
@@ -879,7 +959,7 @@ def compute_and_output_dist(
             assert (
                 ctx.inverse_indices is not None
             ), "inverse indices must be provided from KJT if using variable batch size per feature."
-            return VariableBatchEmbeddingBagCollectionAwaitable(
+            awaitable = VariableBatchEmbeddingBagCollectionAwaitable(
                 awaitables=awaitables,
                 inverse_indices=ctx.inverse_indices,
                 inverse_indices_permute_indices=self._inverse_indices_permute_indices,
@@ -890,12 +970,18 @@ def compute_and_output_dist(
                 permute_op=self._permute_op,
            )
         else:
-            return EmbeddingBagCollectionAwaitable(
+            awaitable = EmbeddingBagCollectionAwaitable(
                 awaitables=awaitables,
                 embedding_dims=self._embedding_dims,
                 embedding_names=self._embedding_names,
             )

+        # register callback if there are features that need mean pooling
+        if self._has_mean_pooling_callback:
+            awaitable.callbacks.append(ctx.mean_pooling_callback)
+
+        return awaitable
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim
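
Attaching mean pooling as an awaitable callback keeps the calling convention of output_dist and compute_and_output_dist unchanged. A toy sketch of the assumed callback mechanism (callbacks registered on the awaitable are applied to the result when it is waited on), using made-up types and values:

from typing import Callable, List

class _ToyAwaitable:
    # stand-in for the awaitable: wait() runs every registered callback on the result
    def __init__(self, value: float) -> None:
        self._value = value
        self.callbacks: List[Callable[[float], float]] = []

    def wait(self) -> float:
        result = self._value
        for callback in self.callbacks:
            result = callback(result)
        return result

a = _ToyAwaitable(6.0)
a.callbacks.append(lambda pooled_sum: pooled_sum / 3.0)  # e.g. divide a pooled sum by bag length 3
assert a.wait() == 2.0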
@@ -1166,3 +1252,82 @@ def shardable_parameters(self, module: nn.EmbeddingBag) -> Dict[str, nn.Paramete
     @property
     def module_type(self) -> Type[nn.EmbeddingBag]:
         return nn.EmbeddingBag
+
+
+def create_mean_pooling_callback(
+    lengths: torch.Tensor,
+    keys: List[str],
+    stride: int,
+    stride_per_key: List[int],
+    dim_per_key: torch.Tensor,
+    pooling_type_to_rs_features: Dict[str, List[str]],
+    embedding_names: List[str],
+    embedding_dims: List[int],
+    variable_batch_per_feature: bool,
+    kjt_inverse_order: torch.Tensor,
+    kjt_key_indices: Dict[str, int],
+    kt_key_ordering: torch.Tensor,
+    inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
+) -> Callable[[KeyedTensor], KeyedTensor]:
+    with record_function("## ebc create mean pooling callback ##"):
+        batch_size = (
+            inverse_indices[1].size(dim=1) if variable_batch_per_feature else stride  # pyre-ignore[16]
+        )
+
+        if variable_batch_per_feature:
+            device = inverse_indices[1].device
+            inverse_indices_t = inverse_indices[1]
+            if len(keys) != len(inverse_indices[0]):
+                inverse_indices_t = torch.index_select(
+                    inverse_indices[1], 0, kjt_inverse_order
+                )
+            offsets = _to_offsets(torch.tensor(stride_per_key, device=device))[
+                :-1
+            ].unsqueeze(-1)
+            indices = (inverse_indices_t + offsets).flatten()
+            lengths = torch.index_select(input=lengths, dim=0, index=indices)
+
+        # only convert the sum pooling features to be 1 lengths
+        for feature in pooling_type_to_rs_features[PoolingType.SUM.value]:
+            feature_index = kjt_key_indices[feature]
+            feature_index = feature_index * batch_size
+            lengths[feature_index : feature_index + batch_size] = 1
+
+        if len(embedding_names) != len(keys):
+            lengths = torch.index_select(
+                lengths.reshape(-1, batch_size),
+                0,
+                kt_key_ordering,
+            ).reshape(-1)
+
+        # transpose to align features with keyed tensor dim_per_key
+        lengths = lengths.reshape(-1, batch_size).T  # [batch_size, num_features]
+        output_size = sum(embedding_dims)
+
+        divisor = torch.repeat_interleave(
+            input=lengths,
+            repeats=dim_per_key,
+            dim=1,
+            output_size=output_size,
+        )
+        eps = 1e-6  # used to safeguard against 0 division
+        divisor = divisor + eps
+
+    # pyre-ignore[53]
+    def _apply_mean_pooling(keyed_tensor: KeyedTensor) -> KeyedTensor:
+        """
+        Apply mean pooling to pooled embeddings in RW/TWRW sharding schemes.
+        This function is applied as a callback to the awaitable.
+        """
+        with record_function("## ebc apply mean pooling ##"):
+            mean_pooled_values = (
+                keyed_tensor.values() / divisor
+            )  # [batch size, num_features * embedding dim]
+            return KeyedTensor(
+                keys=keyed_tensor.keys(),
+                values=mean_pooled_values,
+                length_per_key=keyed_tensor.length_per_key(),
+                key_dim=1,
+            )
+
+    return _apply_mean_pooling
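
Concretely, the callback divides every feature's pooled (summed) embedding by that sample's bag length, broadcast across the feature's embedding dimension, while SUM-pooled features keep a divisor of 1. A tiny worked sketch of the divisor construction under assumed shapes (two hypothetical features, embedding dims 2 and 3, batch size 2):

import torch

batch_size = 2
dim_per_key = torch.tensor([2, 3])  # "f_mean" has dim 2 (MEAN pooling), "f_sum" has dim 3 (SUM pooling)
lengths = torch.tensor([3.0, 4.0,   # bag lengths of f_mean for samples 0 and 1
                        2.0, 5.0])  # bag lengths of f_sum
lengths[2:4] = 1  # SUM-pooled features get a divisor of 1, mirroring the PoolingType.SUM loop

lengths = lengths.reshape(-1, batch_size).T  # [batch_size, num_features]
divisor = torch.repeat_interleave(lengths, dim_per_key, dim=1) + 1e-6
# divisor ~= [[3, 3, 1, 1, 1],
#             [4, 4, 1, 1, 1]]

pooled_sums = torch.ones(batch_size, int(dim_per_key.sum()))  # stand-in for keyed_tensor.values()
mean_pooled = pooled_sums / divisor  # f_mean's columns become ~1/3 and ~1/4; f_sum's stay ~1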