Commit d7cee41

dstaay-fb authored and facebook-github-bot committed
Back out "make EBC group on sharding type and compute kernel" (#2113)
Summary:
Pull Request resolved: #2113

Original commit changeset: 920673d619e9
Original Phabricator Diff: D58160613

Jobs on MVAI failing: https://www.internalfb.com/mlhub/pipelines/runs/mast/fire-dstaay-cfr_roo_testing_v3_pp_dedic_strm?job_attempt=0&version=0&env=PRODUCTION&referrer=MAST_JOB_NOTIFICATION_BOT

Backout job is fine: fire-dstaay-cfr_roo_testing_v3_pp

Reviewed By: sarckk
Differential Revision: D58571298
fbshipit-source-id: 77078a9401803a26258a506d25bb9d4d7334dae4
1 parent 3ab0351 commit d7cee41
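
For context: the reverted change grouped embedding tables into one group per (sharding type, compute kernel) pair via a get_sharding_group helper, while this backout returns to one group per sharding type only (see the embeddingbag.py diff below). The following minimal Python sketch contrasts the two grouping keys; the TableSharding dataclass and the sample tables are hypothetical stand-ins for torchrec's ParameterSharding / EmbeddingBagConfig objects, not the actual API.

from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List

# Hypothetical stand-in for torchrec's per-table sharding metadata.
@dataclass
class TableSharding:
    name: str
    sharding_type: str   # e.g. "table_wise", "column_wise", "data_parallel"
    compute_kernel: str  # e.g. "fused", "fused_uvm_caching"

def group_by_sharding_type(tables: List[TableSharding]) -> Dict[str, List[str]]:
    # Grouping restored by this backout: one group per sharding type.
    groups: Dict[str, List[str]] = defaultdict(list)
    for t in tables:
        groups[t.sharding_type].append(t.name)
    return dict(groups)

def group_by_sharding_type_and_kernel(tables: List[TableSharding]) -> Dict[str, List[str]]:
    # Grouping used by the reverted change: the key also includes the compute
    # kernel, similar in spirit to the removed get_sharding_group().
    groups: Dict[str, List[str]] = defaultdict(list)
    for t in tables:
        groups[f"{t.sharding_type}@{t.compute_kernel}"].append(t.name)
    return dict(groups)

tables = [
    TableSharding("t0", "table_wise", "fused"),
    TableSharding("t1", "table_wise", "fused_uvm_caching"),
    TableSharding("t2", "column_wise", "fused"),
]
print(group_by_sharding_type(tables))
# {'table_wise': ['t0', 't1'], 'column_wise': ['t2']}
print(group_by_sharding_type_and_kernel(tables))
# {'table_wise@fused': ['t0'], 'table_wise@fused_uvm_caching': ['t1'], 'column_wise@fused': ['t2']}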

File tree

6 files changed: +168 −227 lines changed


torchrec/distributed/embeddingbag.py

Lines changed: 51 additions & 84 deletions
@@ -13,7 +13,6 @@
 from functools import partial
 from typing import (
     Any,
-    Callable,
     cast,
     Dict,
     Iterator,
@@ -38,7 +37,6 @@
     EmbeddingShardingInfo,
     KJTListSplitsAwaitable,
     Multistreamable,
-    USE_ONE_TBE_PER_TABLE,
 )
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
@@ -75,7 +73,6 @@
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
-    BaseEmbeddingConfig,
     EmbeddingBagConfig,
     EmbeddingTableConfig,
     PoolingType,
@@ -144,6 +141,7 @@ def replace_placement_with_meta_device(
 
 
 def create_embedding_bag_sharding(
+    sharding_type: str,
     sharding_infos: List[EmbeddingShardingInfo],
     env: ShardingEnv,
     device: Optional[torch.device] = None,
@@ -152,7 +150,6 @@ def create_embedding_bag_sharding(
 ) -> EmbeddingSharding[
     EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
 ]:
-    sharding_type = sharding_infos[0].param_sharding.sharding_type
     if device is not None and device.type == "meta":
         replace_placement_with_meta_device(sharding_infos)
     if sharding_type == ShardingType.TABLE_WISE.value:
@@ -198,48 +195,12 @@ def create_embedding_bag_sharding(
         raise ValueError(f"Sharding type not supported {sharding_type}")
 
 
-def get_sharding_group(
-    config: BaseEmbeddingConfig,
-    param_sharding: ParameterSharding,
-    fused_params: Optional[Dict[str, Any]] = None,
-) -> str:
-    if fused_params and fused_params.get(USE_ONE_TBE_PER_TABLE, False):
-        return config.name
-    if param_sharding.sharding_type in {
-        ShardingType.COLUMN_WISE.value,
-        ShardingType.TABLE_COLUMN_WISE.value,
-    }:
-        assert param_sharding.ranks
-        num_ranks = len(param_sharding.ranks)
-        assert config.embedding_dim % num_ranks == 0
-        dim = config.embedding_dim // num_ranks
-    else:
-        dim = config.embedding_dim
-
-    group = f"{param_sharding.sharding_type}@{param_sharding.compute_kernel}"
-    if (
-        param_sharding.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value
-        and (
-            (fused_params and fused_params.get("prefetch_pipeline", False))
-            or (
-                param_sharding.cache_params
-                and param_sharding.cache_params.prefetch_pipeline
-            )
-        )
-    ):
-        group += f"@{dim}"
-    return group
-
-
-def create_sharding_infos_by_group(
+def create_sharding_infos_by_sharding(
     module: EmbeddingBagCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
     prefix: str,
     fused_params: Optional[Dict[str, Any]],
     suffix: Optional[str] = "weight",
-    group_fn: Optional[
-        Callable[[EmbeddingBagConfig, ParameterSharding, Optional[Dict[str, Any]]], str]
-    ] = None,
 ) -> Dict[str, List[EmbeddingShardingInfo]]:
 
     if fused_params is None:
@@ -255,7 +216,7 @@ def create_sharding_infos_by_group(
             else:
                 shared_feature[feature_name] = True
 
-    group_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = defaultdict(list)
+    sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = {}
 
     # state_dict returns parameter.Tensor, which loses parameter level attributes
    parameter_by_name = dict(module.named_parameters())
@@ -288,6 +249,9 @@ def create_sharding_infos_by_group(
         assert param_name in parameter_by_name or param_name in state_dict
         param = parameter_by_name.get(param_name, state_dict[param_name])
 
+        if parameter_sharding.sharding_type not in sharding_type_to_sharding_infos:
+            sharding_type_to_sharding_infos[parameter_sharding.sharding_type] = []
+
         optimizer_params = getattr(param, "_optimizer_kwargs", [{}])
         optimizer_classes = getattr(param, "_optimizer_classes", [None])
 
@@ -309,32 +273,28 @@ def create_sharding_infos_by_group(
         )
         per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
 
-        group = (
-            group_fn(config, parameter_sharding, fused_params)
-            if group_fn is not None
-            else parameter_sharding.sharding_type
-        )
-        sharding_info = EmbeddingShardingInfo(
-            embedding_config=EmbeddingTableConfig(
-                num_embeddings=config.num_embeddings,
-                embedding_dim=config.embedding_dim,
-                name=config.name,
-                data_type=config.data_type,
-                feature_names=copy.deepcopy(config.feature_names),
-                pooling=config.pooling,
-                is_weighted=module.is_weighted(),
-                has_feature_processor=False,
-                embedding_names=embedding_names,
-                weight_init_max=config.weight_init_max,
-                weight_init_min=config.weight_init_min,
-                pruning_indices_remapping=config.pruning_indices_remapping,
-            ),
-            param_sharding=parameter_sharding,
-            param=param,
-            fused_params=per_table_fused_params,
+        sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
+            EmbeddingShardingInfo(
+                embedding_config=EmbeddingTableConfig(
+                    num_embeddings=config.num_embeddings,
+                    embedding_dim=config.embedding_dim,
+                    name=config.name,
+                    data_type=config.data_type,
+                    feature_names=copy.deepcopy(config.feature_names),
+                    pooling=config.pooling,
+                    is_weighted=module.is_weighted(),
+                    has_feature_processor=False,
+                    embedding_names=embedding_names,
+                    weight_init_max=config.weight_init_max,
+                    weight_init_min=config.weight_init_min,
+                    pruning_indices_remapping=config.pruning_indices_remapping,
+                ),
+                param_sharding=parameter_sharding,
+                param=param,
+                fused_params=per_table_fused_params,
+            )
         )
-        group_to_sharding_infos[group].append(sharding_info)
-    return group_to_sharding_infos
+    return sharding_type_to_sharding_infos
 
 
 def create_sharding_infos_by_sharding_device_group(
@@ -611,30 +571,31 @@ def __init__(
         )
         self._env = env
 
-        group_to_sharding_infos = create_sharding_infos_by_group(
+        sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
             module,
             table_name_to_parameter_sharding,
             "embedding_bags.",
             fused_params,
-            group_fn=get_sharding_group,
         )
-        self._embedding_shardings: List[
+        self._sharding_type_to_sharding: Dict[
+            str,
             EmbeddingSharding[
                 EmbeddingShardingContext,
                 KeyedJaggedTensor,
                 torch.Tensor,
                 torch.Tensor,
-            ]
-        ] = [
-            create_embedding_bag_sharding(
+            ],
+        ] = {
+            sharding_type: create_embedding_bag_sharding(
+                sharding_type,
                 embedding_configs,
                 env,
                 device,
                 permute_embeddings=True,
                 qcomm_codecs_registry=self.qcomm_codecs_registry,
             )
-            for embedding_configs in group_to_sharding_infos.values()
-        ]
+            for sharding_type, embedding_configs in sharding_type_to_sharding_infos.items()
+        }
 
         self._is_weighted: bool = module.is_weighted()
         self._device = device
@@ -679,12 +640,15 @@ def __init__(
                 optims.append(("", tbe_module.fused_optimizer))
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
 
-        for i, (sharding, lookup) in enumerate(
-            zip(self._embedding_shardings, self._lookups)
+        for index, (sharding, lookup) in enumerate(
+            zip(
+                self._sharding_type_to_sharding.values(),
+                self._lookups,
+            )
         ):
            # TODO: can move this into DpPooledEmbeddingSharding once all modules are composable
            if isinstance(sharding, DpPooledEmbeddingSharding):
-                self._lookups[i] = DistributedDataParallel(
+                self._lookups[index] = DistributedDataParallel(
                     module=lookup,
                     device_ids=(
                         [device]
@@ -806,8 +770,10 @@ def _initialize_torch_state(self) -> None: # noqa
                 table.embedding_dim,
             )
 
-        for lookup, sharding in zip(self._lookups, self._embedding_shardings):
-            if isinstance(sharding, DpPooledEmbeddingSharding):
+        for sharding_type, lookup in zip(
+            self._sharding_type_to_sharding.keys(), self._lookups
+        ):
+            if sharding_type == ShardingType.DATA_PARALLEL.value:
                 # unwrap DDP
                 lookup = lookup.module
             else:
@@ -898,7 +864,7 @@ def _create_input_dist(
         input_feature_names: List[str],
     ) -> None:
         feature_names: List[str] = []
-        for sharding in self._embedding_shardings:
+        for sharding in self._sharding_type_to_sharding.values():
             self._input_dists.append(sharding.create_input_dist())
             feature_names.extend(sharding.feature_names())
             self._feature_splits.append(len(sharding.feature_names()))
@@ -924,7 +890,7 @@ def _init_mean_pooling_callback(
         # account for shared features
         feature_names: List[str] = [
             feature_name
-            for sharding in self._embedding_shardings
+            for sharding in self._sharding_type_to_sharding.values()
             for feature_name in sharding.feature_names()
         ]
 
@@ -951,12 +917,12 @@ def _init_mean_pooling_callback(
     def _create_lookups(
         self,
     ) -> None:
-        for sharding in self._embedding_shardings:
+        for sharding in self._sharding_type_to_sharding.values():
            self._lookups.append(sharding.create_lookup())
 
     def _create_output_dist(self) -> None:
         embedding_shard_metadata: List[Optional[ShardMetadata]] = []
-        for sharding in self._embedding_shardings:
+        for sharding in self._sharding_type_to_sharding.values():
             self._output_dists.append(sharding.create_output_dist(device=self._device))
             self._embedding_names.extend(sharding.embedding_names())
             self._embedding_dims.extend(sharding.embedding_dims())
@@ -1270,6 +1236,7 @@ def __init__(
         self._embedding_sharding: EmbeddingSharding[
             EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
         ] = create_embedding_bag_sharding(
+            sharding_type=self.parameter_sharding.sharding_type,
             sharding_infos=[
                 EmbeddingShardingInfo(
                     embedding_config=embedding_table_config,
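
The net effect of the embeddingbag.py changes above is that ShardedEmbeddingBagCollection again builds one EmbeddingSharding object (and hence one lookup) per sharding type, keyed in an ordered dict, with the sharding type passed explicitly into create_embedding_bag_sharding. Below is a rough, self-contained sketch of that restored flow; the ShardingInfo/Sharding aliases and the *_sketch helpers are simplified stand-ins, not the real torchrec signatures.

from typing import Dict, List

# Simplified stand-ins for EmbeddingShardingInfo and EmbeddingSharding objects.
ShardingInfo = dict
Sharding = str

def create_sharding_infos_by_sharding_sketch(
    table_to_sharding_type: Dict[str, str],
) -> Dict[str, List[ShardingInfo]]:
    # Mimics the restored grouping: bucket per-table infos by sharding type.
    out: Dict[str, List[ShardingInfo]] = {}
    for table, sharding_type in table_to_sharding_type.items():
        out.setdefault(sharding_type, []).append({"table": table})
    return out

def create_embedding_bag_sharding_sketch(
    sharding_type: str, infos: List[ShardingInfo]
) -> Sharding:
    # Mimics the restored factory: the caller now passes sharding_type explicitly.
    return f"{sharding_type}:{len(infos)} tables"

infos_by_type = create_sharding_infos_by_sharding_sketch(
    {"t0": "table_wise", "t1": "table_wise", "t2": "data_parallel"}
)
# Mirrors the restored dict comprehension in ShardedEmbeddingBagCollection.__init__.
sharding_type_to_sharding: Dict[str, Sharding] = {
    sharding_type: create_embedding_bag_sharding_sketch(sharding_type, infos)
    for sharding_type, infos in infos_by_type.items()
}
print(sharding_type_to_sharding)
# {'table_wise': 'table_wise:2 tables', 'data_parallel': 'data_parallel:1 tables'}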

torchrec/distributed/fused_embeddingbag.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ def __init__(
         )
 
         for index, (sharding, lookup) in enumerate(
-            zip(self._embedding_shardings, self._lookups)
+            zip(self._sharding_type_to_sharding.values(), self._lookups)
         ):
             if isinstance(sharding, DpPooledEmbeddingSharding):
                 self._lookups[index] = DistributedDataParallel(
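
As in embeddingbag.py, the restored loop pairs self._sharding_type_to_sharding.values() with self._lookups by position, relying on the dict preserving the insertion order in which the lookups were created, and wraps data-parallel lookups in DistributedDataParallel. A minimal sketch of that pairing pattern, with placeholder strings standing in for torchrec sharding and lookup modules:

from typing import Dict, List

# Placeholder objects standing in for EmbeddingSharding and lookup modules.
sharding_type_to_sharding: Dict[str, str] = {
    "data_parallel": "DpPooledEmbeddingSharding",
    "table_wise": "TwPooledEmbeddingSharding",
}
# The lookups are created by iterating the same dict, so the lists line up by position.
lookups: List[str] = [f"lookup_for_{s}" for s in sharding_type_to_sharding.values()]

for index, (sharding, lookup) in enumerate(
    zip(sharding_type_to_sharding.values(), lookups)
):
    # In the real module, data-parallel lookups are wrapped in DistributedDataParallel here.
    if sharding == "DpPooledEmbeddingSharding":
        lookups[index] = f"DDP({lookup})"

print(lookups)
# ['DDP(lookup_for_DpPooledEmbeddingSharding)', 'lookup_for_TwPooledEmbeddingSharding']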

torchrec/distributed/mc_embedding_modules.py

Lines changed: 2 additions & 6 deletions
@@ -109,18 +109,14 @@ def __init__(
         # TODO: This is a hack since _embedding_module doesn't need input
         # dist, so eliminating it so all fused a2a will ignore it.
         self._embedding_module._has_uninitialized_input_dist = False
-        embedding_shardings = (
-            self._embedding_module._embedding_shardings
-            if isinstance(self._embedding_module, ShardedEmbeddingBagCollection)
-            else list(self._embedding_module._sharding_type_to_sharding.values())
-        )
         self._managed_collision_collection: ShardedManagedCollisionCollection = (
             mc_sharder.shard(
                 module._managed_collision_collection,
                 table_name_to_parameter_sharding,
                 env=env,
                 device=device,
-                embedding_shardings=embedding_shardings,
+                # pyre-ignore
+                sharding_type_to_sharding=self._embedding_module._sharding_type_to_sharding,
             )
         )
         self._return_remapped_features: bool = module._return_remapped_features
