@@ -13,7 +13,6 @@
 from functools import partial
 from typing import (
     Any,
-    Callable,
     cast,
     Dict,
     Iterator,
@@ -38,7 +37,6 @@
     EmbeddingShardingInfo,
     KJTListSplitsAwaitable,
     Multistreamable,
-    USE_ONE_TBE_PER_TABLE,
 )
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
@@ -75,7 +73,6 @@
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
-    BaseEmbeddingConfig,
     EmbeddingBagConfig,
     EmbeddingTableConfig,
     PoolingType,
@@ -144,6 +141,7 @@ def replace_placement_with_meta_device(
 
 
 def create_embedding_bag_sharding(
+    sharding_type: str,
     sharding_infos: List[EmbeddingShardingInfo],
     env: ShardingEnv,
     device: Optional[torch.device] = None,
@@ -152,7 +150,6 @@ def create_embedding_bag_sharding(
 ) -> EmbeddingSharding[
     EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
 ]:
-    sharding_type = sharding_infos[0].param_sharding.sharding_type
     if device is not None and device.type == "meta":
         replace_placement_with_meta_device(sharding_infos)
     if sharding_type == ShardingType.TABLE_WISE.value:
@@ -198,48 +195,12 @@ def create_embedding_bag_sharding(
         raise ValueError(f"Sharding type not supported {sharding_type}")
 
 
-def get_sharding_group(
-    config: BaseEmbeddingConfig,
-    param_sharding: ParameterSharding,
-    fused_params: Optional[Dict[str, Any]] = None,
-) -> str:
-    if fused_params and fused_params.get(USE_ONE_TBE_PER_TABLE, False):
-        return config.name
-    if param_sharding.sharding_type in {
-        ShardingType.COLUMN_WISE.value,
-        ShardingType.TABLE_COLUMN_WISE.value,
-    }:
-        assert param_sharding.ranks
-        num_ranks = len(param_sharding.ranks)
-        assert config.embedding_dim % num_ranks == 0
-        dim = config.embedding_dim // num_ranks
-    else:
-        dim = config.embedding_dim
-
-    group = f"{param_sharding.sharding_type}@{param_sharding.compute_kernel}"
-    if (
-        param_sharding.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value
-        and (
-            (fused_params and fused_params.get("prefetch_pipeline", False))
-            or (
-                param_sharding.cache_params
-                and param_sharding.cache_params.prefetch_pipeline
-            )
-        )
-    ):
-        group += f"@{dim}"
-    return group
-
-
-def create_sharding_infos_by_group(
+def create_sharding_infos_by_sharding(
     module: EmbeddingBagCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
     prefix: str,
     fused_params: Optional[Dict[str, Any]],
     suffix: Optional[str] = "weight",
-    group_fn: Optional[
-        Callable[[EmbeddingBagConfig, ParameterSharding, Optional[Dict[str, Any]]], str]
-    ] = None,
 ) -> Dict[str, List[EmbeddingShardingInfo]]:
 
     if fused_params is None:
@@ -255,7 +216,7 @@ def create_sharding_infos_by_group(
             else:
                 shared_feature[feature_name] = True
 
-    group_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = defaultdict(list)
+    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {}
 
     # state_dict returns parameter.Tensor, which loses parameter level attributes
     parameter_by_name = dict(module.named_parameters())
@@ -288,6 +249,9 @@ def create_sharding_infos_by_group(
         assert param_name in parameter_by_name or param_name in state_dict
         param = parameter_by_name.get(param_name, state_dict[param_name])
 
+        if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos:
+            sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = []
+
         optimizer_params = getattr(param, "_optimizer_kwargs", [{}])
         optimizer_classes = getattr(param, "_optimizer_classes", [None])
 
@@ -309,32 +273,28 @@ def create_sharding_infos_by_group(
         )
         per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
 
-        group = (
-            group_fn(config, parameter_sharding, fused_params)
-            if group_fn is not None
-            else parameter_sharding.sharding_type
-        )
-        sharding_info = EmbeddingShardingInfo(
-            embedding_config=EmbeddingTableConfig(
-                num_embeddings=config.num_embeddings,
-                embedding_dim=config.embedding_dim,
-                name=config.name,
-                data_type=config.data_type,
-                feature_names=copy.deepcopy(config.feature_names),
-                pooling=config.pooling,
-                is_weighted=module.is_weighted(),
-                has_feature_processor=False,
-                embedding_names=embedding_names,
-                weight_init_max=config.weight_init_max,
-                weight_init_min=config.weight_init_min,
-                pruning_indices_remapping=config.pruning_indices_remapping,
-            ),
-            param_sharding=parameter_sharding,
-            param=param,
-            fused_params=per_table_fused_params,
+        sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
+            EmbeddingShardingInfo(
+                embedding_config=EmbeddingTableConfig(
+                    num_embeddings=config.num_embeddings,
+                    embedding_dim=config.embedding_dim,
+                    name=config.name,
+                    data_type=config.data_type,
+                    feature_names=copy.deepcopy(config.feature_names),
+                    pooling=config.pooling,
+                    is_weighted=module.is_weighted(),
+                    has_feature_processor=False,
+                    embedding_names=embedding_names,
+                    weight_init_max=config.weight_init_max,
+                    weight_init_min=config.weight_init_min,
+                    pruning_indices_remapping=config.pruning_indices_remapping,
+                ),
+                param_sharding=parameter_sharding,
+                param=param,
+                fused_params=per_table_fused_params,
+            )
         )
-        group_to_sharding_infos[group].append(sharding_info)
-    return group_to_sharding_infos
+    return sharding_type_to_sharding_infos
 
 
 def create_sharding_infos_by_sharding_device_group(
@@ -611,30 +571,31 @@ def __init__(
         )
         self._env = env
 
-        group_to_sharding_infos = create_sharding_infos_by_group(
+        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
             module,
             table_name_to_parameter_sharding,
             "embedding_bags.",
             fused_params,
-            group_fn=get_sharding_group,
         )
-        self._embedding_shardings: List[
+        self._sharding_type_to_sharding: Dict[
+            str,
             EmbeddingSharding[
                 EmbeddingShardingContext,
                 KeyedJaggedTensor,
                 torch.Tensor,
                 torch.Tensor,
-            ]
-        ] = [
-            create_embedding_bag_sharding(
+            ],
+        ] = {
+            sharding_type: create_embedding_bag_sharding(
+                sharding_type,
                 embedding_configs,
                 env,
                 device,
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in group_to_sharding_infos.values()
-        ]
+            for sharding_type, embedding_configs in sharding_type_to_sharding_infos.items()
+        }
 
         self._is_weighted: bool = module.is_weighted()
         self._device = device
@@ -679,12 +640,15 @@ def __init__(
                     optims.append(("", tbe_module.fused_optimizer))
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
 
-        for i, (sharding, lookup) in enumerate(
-            zip(self._embedding_shardings, self._lookups)
+        for index, (sharding, lookup) in enumerate(
+            zip(
+                self._sharding_type_to_sharding.values(),
+                self._lookups,
+            )
         ):
             # TODO: can move this into DpPooledEmbeddingSharding once all modules are composable
             if isinstance(sharding, DpPooledEmbeddingSharding):
-                self._lookups[i] = DistributedDataParallel(
+                self._lookups[index] = DistributedDataParallel(
                     module=lookup,
                     device_ids=(
                         [device]
@@ -806,8 +770,10 @@ def _initialize_torch_state(self) -> None: # noqa
                 table.embedding_dim,
             )
 
-        for lookup, sharding in zip(self._lookups, self._embedding_shardings):
-            if isinstance(sharding, DpPooledEmbeddingSharding):
+        for sharding_type, lookup in zip(
+            self._sharding_type_to_sharding.keys(), self._lookups
+        ):
+            if sharding_type == ShardingType.DATA_PARALLEL.value:
                 # unwrap DDP
                 lookup = lookup.module
             else:
@@ -898,7 +864,7 @@ def _create_input_dist(
         input_feature_names: List[str],
     ) -> None:
         feature_names: List[str] = []
-        for sharding in self._embedding_shardings:
+        for sharding in self._sharding_type_to_sharding.values():
             self._input_dists.append(sharding.create_input_dist())
             feature_names.extend(sharding.feature_names())
             self._feature_splits.append(len(sharding.feature_names()))
@@ -924,7 +890,7 @@ def _init_mean_pooling_callback(
         # account for shared features
         feature_names: List[str] = [
             feature_name
-            for sharding in self._embedding_shardings
+            for sharding in self._sharding_type_to_sharding.values()
             for feature_name in sharding.feature_names()
         ]
 
@@ -951,12 +917,12 @@ def _init_mean_pooling_callback(
     def _create_lookups(
         self,
     ) -> None:
-        for sharding in self._embedding_shardings:
+        for sharding in self._sharding_type_to_sharding.values():
             self._lookups.append(sharding.create_lookup())
 
     def _create_output_dist(self) -> None:
         embedding_shard_metadata: List[Optional[ShardMetadata]] = []
-        for sharding in self._embedding_shardings:
+        for sharding in self._sharding_type_to_sharding.values():
             self._output_dists.append(sharding.create_output_dist(device=self._device))
             self._embedding_names.extend(sharding.embedding_names())
             self._embedding_dims.extend(sharding.embedding_dims())
@@ -1270,6 +1236,7 @@ def __init__(
         self._embedding_sharding: EmbeddingSharding[
             EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
         ] = create_embedding_bag_sharding(
+            sharding_type=self.parameter_sharding.sharding_type,
             sharding_infos=[
                 EmbeddingShardingInfo(
                     embedding_config=embedding_table_config,
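
The net effect of the diff is that sharding infos are once again bucketed strictly by sharding type, rather than by the composite group key that the removed get_sharding_group() produced, and create_embedding_bag_sharding() receives that sharding type explicitly. Below is a minimal, self-contained sketch of just that bucketing behavior; it uses plain-Python stand-ins (the table names and sharding-type strings are invented for illustration) and is not part of this change or of torchrec itself.

from typing import Dict, List, Tuple

# Stand-ins for (table name, ParameterSharding.sharding_type) pairs that
# create_sharding_infos_by_sharding() would derive from the sharding plan.
tables: List[Tuple[str, str]] = [
    ("table_0", "table_wise"),
    ("table_1", "column_wise"),
    ("table_2", "table_wise"),
]

# One bucket per sharding type, mirroring the updated helper above:
# a plain dict keyed by sharding type, initialized lazily.
sharding_type_to_infos: Dict[str, List[str]] = {}
for name, sharding_type in tables:
    if sharding_type not in sharding_type_to_infos:
        sharding_type_to_infos[sharding_type] = []
    sharding_type_to_infos[sharding_type].append(name)

print(sharding_type_to_infos)
# {'table_wise': ['table_0', 'table_2'], 'column_wise': ['table_1']}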