revert sharding grouping logic for vbe (#2216)

Closed · wants to merge 2 commits
torchrec/distributed/embedding_lookup.py (34 changes: 32 additions, 2 deletions)
@@ -46,7 +46,6 @@
EmbeddingComputeKernel,
GroupedEmbeddingConfig,
InputDistOutputs,
-KJTList,
)
from torchrec.distributed.fused_params import (
get_tbes_to_register_from_iterable,
@@ -442,8 +441,9 @@ def _create_lookup(
self.grouped_configs = grouped_configs
self._feature_processor = feature_processor

+self._world_size: int = dist.get_world_size(pg)
self._scale_gradient_factor: int = (
-dist.get_world_size(pg)
+self._world_size
if scale_weight_gradients and get_gradient_division()
else 1
)
@@ -487,11 +487,24 @@ def _need_prefetch(config: GroupedEmbeddingConfig) -> bool:
),
)

+def _merge_variable_batch_embeddings(
+self, embeddings: List[torch.Tensor], splits: List[List[int]]
+) -> List[torch.Tensor]:
+split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
+combined_embs = [
+emb
+for rank in range(self._world_size)
+for n, embs in zip(self._feature_splits, split_embs)
+for emb in embs[n * rank : n * rank + n]
+]
+return [torch.cat(combined_embs)]

def forward(
self,
sparse_features: KeyedJaggedTensor,
) -> torch.Tensor:
embeddings: List[torch.Tensor] = []
+vbe_splits = []
if len(self._emb_modules) > 0:
assert sparse_features is not None
features_by_group = sparse_features.split(
@@ -514,6 +527,23 @@ def forward(

embeddings.append(emb_op(features))

+if features.variable_stride_per_key():
+stride_per_rank_per_key = list(
+zip(*features.stride_per_key_per_rank())
+)
+vbe_splits.append(
+[
+stride * dim
+for stride_per_rank in stride_per_rank_per_key
+for stride, dim in zip(
+stride_per_rank, config.embedding_dims()
+)
+]
+)

+if sparse_features.variable_stride_per_key():
+embeddings = self._merge_variable_batch_embeddings(embeddings, vbe_splits)

dummy_embedding = (
self._dummy_embs_tensor
if sparse_features.variable_stride_per_key()
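The `_merge_variable_batch_embeddings` helper added above interleaves the per-lookup VBE outputs back into rank-major order before concatenating them. Below is a minimal standalone sketch of that reordering with made-up sizes (two ranks, two grouped lookups with one feature each, embedding dim 4); `world_size`, `feature_splits`, and the tensor values are illustrative stand-ins for the module's `self._world_size` and `self._feature_splits`, not code from this PR.

```python
import torch

# Illustrative sizes: 2 ranks, 2 grouped lookups, 1 feature per lookup, dim 4,
# and a per-rank batch of 2 rows per feature (2 rows * dim 4 = 8 values per chunk).
world_size = 2
feature_splits = [1, 1]  # number of features handled by each grouped lookup

embeddings = [torch.arange(16.0), torch.arange(100.0, 116.0)]  # one flat tensor per lookup
splits = [[8, 8], [8, 8]]  # stride * dim per (rank, feature), rank-major within each lookup


def merge_variable_batch_embeddings(embeddings, splits):
    # Same shape of logic as the helper in this diff, without the module state.
    split_embs = [e.split(s) for e, s in zip(embeddings, splits)]
    combined = [
        emb
        for rank in range(world_size)
        for n, embs in zip(feature_splits, split_embs)
        for emb in embs[n * rank : n * rank + n]
    ]
    return [torch.cat(combined)]


merged = merge_variable_batch_embeddings(embeddings, splits)
# Rank 0's chunks from every lookup come first, then rank 1's chunks:
# [lookup0/rank0, lookup1/rank0, lookup0/rank1, lookup1/rank1]
print(merged[0].shape)  # torch.Size([32])
```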
torchrec/distributed/embeddingbag.py (49 changes: 10 additions, 39 deletions)
@@ -13,7 +13,6 @@
from functools import partial
from typing import (
Any,
-Callable,
cast,
Dict,
Iterator,
@@ -39,7 +38,6 @@
EmbeddingShardingInfo,
KJTListSplitsAwaitable,
Multistreamable,
-USE_ONE_TBE_PER_TABLE,
)
from torchrec.distributed.embedding_types import (
BaseEmbeddingSharder,
@@ -77,7 +75,6 @@
optimizer_type_to_emb_opt_type,
)
from torchrec.modules.embedding_configs import (
-BaseEmbeddingConfig,
EmbeddingBagConfig,
EmbeddingTableConfig,
PoolingType,
@@ -200,36 +197,7 @@ def create_embedding_bag_sharding(
raise ValueError(f"Sharding type not supported {sharding_type}")


-def get_sharding_group(
-config: BaseEmbeddingConfig,
-param_sharding: ParameterSharding,
-fused_params: Optional[Dict[str, Any]] = None,
-) -> str:
-if fused_params and fused_params.get(USE_ONE_TBE_PER_TABLE, False):
-return config.name
-if param_sharding.sharding_type in {
-ShardingType.COLUMN_WISE.value,
-ShardingType.TABLE_COLUMN_WISE.value,
-}:
-assert param_sharding.ranks
-num_ranks = len(param_sharding.ranks)
-assert config.embedding_dim % num_ranks == 0
-dim = config.embedding_dim // num_ranks
-else:
-dim = config.embedding_dim

-group = f"{param_sharding.sharding_type}"
-if param_sharding.compute_kernel == EmbeddingComputeKernel.FUSED_UVM_CACHING.value:
-group += f"@{param_sharding.compute_kernel}"
-if (fused_params and fused_params.get("prefetch_pipeline", False)) or (
-param_sharding.cache_params
-and param_sharding.cache_params.prefetch_pipeline
-):
-group += f"@{dim}"
-return group


-def create_sharding_infos_by_group(
+def create_sharding_infos_by_sharding(
module: EmbeddingBagCollectionInterface,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
prefix: str,
@@ -250,7 +218,9 @@ def create_sharding_infos_by_group(
else:
shared_feature[feature_name] = True

-group_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = defaultdict(list)
+sharding_type_to_sharding_infos: Dict[str, List[EmbeddingShardingInfo]] = (
+defaultdict(list)
+)

# state_dict returns parameter.Tensor, which loses parameter level attributes
parameter_by_name = dict(module.named_parameters())
@@ -304,7 +274,6 @@ def create_sharding_infos_by_group(
)
per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)

-group = get_sharding_group(config, parameter_sharding, fused_params)
sharding_info = EmbeddingShardingInfo(
embedding_config=EmbeddingTableConfig(
num_embeddings=config.num_embeddings,
@@ -324,8 +293,10 @@
param=param,
fused_params=per_table_fused_params,
)
-group_to_sharding_infos[group].append(sharding_info)
-return group_to_sharding_infos
+sharding_type_to_sharding_infos[parameter_sharding.sharding_type].append(
+sharding_info
+)
+return sharding_type_to_sharding_infos


def create_sharding_infos_by_sharding_device_group(
@@ -602,7 +573,7 @@ def __init__(
)
self._env = env

-group_to_sharding_infos = create_sharding_infos_by_group(
+sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
module,
table_name_to_parameter_sharding,
"embedding_bags.",
@@ -623,7 +594,7 @@ def __init__(
permute_embeddings=True,
qcomm_codecs_registry=self.qcomm_codecs_registry,
)
-for embedding_configs in group_to_sharding_infos.values()
+for embedding_configs in sharding_type_to_sharding_infos.values()
]

self._is_weighted: bool = module.is_weighted()
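With the grouping logic reverted, `create_sharding_infos_by_sharding` buckets tables purely by `ParameterSharding.sharding_type`, whereas the removed `get_sharding_group` could split the same sharding type into several groups (by compute kernel and, with prefetch, per-shard dim). A rough sketch of the new bucketing with stand-in types; `ParamShardingStub`, `InfoStub`, and the table names are hypothetical, only the keying mirrors the diff.

```python
from collections import defaultdict
from typing import Dict, List, NamedTuple


class ParamShardingStub(NamedTuple):
    # Stand-in for ParameterSharding; only the fields used for keying.
    sharding_type: str
    compute_kernel: str


class InfoStub(NamedTuple):
    # Stand-in for EmbeddingShardingInfo.
    table: str
    sharding_type: str


def bucket_by_sharding_type(plan: Dict[str, ParamShardingStub]) -> Dict[str, List[InfoStub]]:
    # Reverted behavior: one bucket per sharding type, so each bucket feeds
    # exactly one EmbeddingSharding instance downstream.
    buckets: Dict[str, List[InfoStub]] = defaultdict(list)
    for table, ps in plan.items():
        buckets[ps.sharding_type].append(InfoStub(table, ps.sharding_type))
    return buckets


plan = {
    "t0": ParamShardingStub("column_wise", "fused"),
    "t1": ParamShardingStub("column_wise", "fused_uvm_caching"),
    "t2": ParamShardingStub("table_wise", "fused"),
}
print({k: [i.table for i in v] for k, v in bucket_by_sharding_type(plan).items()})
# {'column_wise': ['t0', 't1'], 'table_wise': ['t2']}
# The removed get_sharding_group would have put t0 and t1 into different groups
# ("column_wise" vs "column_wise@fused_uvm_caching").
```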
torchrec/distributed/planner/tests/test_embeddingbag_utils.py (12 changes: 6 additions, 6 deletions)
@@ -11,7 +11,7 @@
import unittest

from torchrec.distributed.embeddingbag import (
-create_sharding_infos_by_group,
+create_sharding_infos_by_sharding,
EmbeddingBagCollectionSharder,
)
from torchrec.distributed.planner import (
@@ -79,7 +79,7 @@ def setUp(self) -> None:
)
self.expected_plan = planner.plan(self.model, [self.sharder]) # pyre-ignore[6]

-self.expected_sharding_infos = create_sharding_infos_by_group(
+self.expected_sharding_infos = create_sharding_infos_by_sharding(
self.model,
self.expected_plan.get_plan_for_module(""), # pyre-ignore[6]
prefix="embedding_bags.",
@@ -93,7 +93,7 @@ def test_create_sharding_infos_by_group_override(self) -> None:

# with sharder fused params that will get overridden
sharder_fused_params = {"enforce_hbm": False}
-overriden_sharding_infos = create_sharding_infos_by_group(
+overriden_sharding_infos = create_sharding_infos_by_sharding(
self.model,
self.expected_plan.get_plan_for_module(""),
prefix="embedding_bags.",
@@ -106,7 +106,7 @@ def test_create_sharding_infos_by_group_override(self) -> None:

# with sharder fused params that won't get overridden
sharder_fused_params = {"ABC": True}
-not_overriden_sharding_infos = create_sharding_infos_by_group(
+not_overriden_sharding_infos = create_sharding_infos_by_sharding(
self.model,
self.expected_plan.get_plan_for_module(""),
prefix="embedding_bags.",
@@ -141,7 +141,7 @@ def test_create_sharding_infos_by_group_combine(self) -> None:
# provide that two fused params from sharder
sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": False}

-combined_sharding_infos = create_sharding_infos_by_group(
+combined_sharding_infos = create_sharding_infos_by_sharding(
self.model,
new_plan.get_plan_for_module(""), # pyre-ignore[6]
prefix="embedding_bags.",
@@ -156,7 +156,7 @@

# provide that two fused params from sharder wrongly
sharder_fused_params = {"enforce_hbm": True, "stochastic_rounding": True}
-wrong_combined_sharding_infos = create_sharding_infos_by_group(
+wrong_combined_sharding_infos = create_sharding_infos_by_sharding(
self.model,
new_plan.get_plan_for_module(""), # pyre-ignore[6]
prefix="embedding_bags.",
torchrec/distributed/test_utils/test_model_parallel.py (3 changes: 2 additions, 1 deletion)
@@ -653,9 +653,10 @@ def test_sharding_multiple_kernels(self, sharding_type: str) -> None:
)
for i, table in enumerate(self.tables)
}
+fused_params = {"prefetch_pipeline": True}
self._test_sharding(
# pyre-ignore[6]
-sharders=[EmbeddingBagCollectionSharder()],
+sharders=[EmbeddingBagCollectionSharder(fused_params=fused_params)],
backend=self.backend,
constraints=constraints,
variable_batch_per_feature=True,
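The multiple-kernels test now enables the prefetch pipeline through the sharder's fused params, so the UVM-caching kernels are exercised without the reverted per-group keying. A minimal sketch of that construction pattern, using only what appears in this diff; wiring it into `_test_sharding` and the planner is omitted.

```python
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

# Fused params set at sharder construction apply to every table this sharder shards.
fused_params = {"prefetch_pipeline": True}
sharder = EmbeddingBagCollectionSharder(fused_params=fused_params)
# `sharder` would then be passed to the planner / DistributedModelParallel,
# as in test_sharding_multiple_kernels above.
```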
torchrec/distributed/tests/test_pt2_multiprocess.py (6 changes: 4 additions, 2 deletions)
@@ -518,15 +518,17 @@ def disable_cuda_tf32(self) -> bool:
ShardingType.TABLE_WISE.value,
_InputType.SINGLE_BATCH,
_ConvertToVariableBatch.TRUE,
"inductor",
# TODO: Revert to "inductor" once https://github.com/pytorch/pytorch/pull/130431 is landed
"eager",
_TestConfig(),
),
(
_ModelType.EBC,
ShardingType.COLUMN_WISE.value,
_InputType.SINGLE_BATCH,
_ConvertToVariableBatch.TRUE,
"inductor",
# TODO: Revert to "inductor" once https://github.com/pytorch/pytorch/pull/130431 is landed
"eager",
_TestConfig(),
),
(
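Both VBE cases above temporarily compile with the eager backend instead of inductor until the referenced PyTorch fix lands. For context, a small illustration of what that backend argument selects in `torch.compile`; the toy module is made up and unrelated to the test harness.

```python
import torch


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


# "eager" runs the captured graph with stock PyTorch kernels (handy for isolating
# tracing/sharding issues), while "inductor" additionally code-generates fused kernels.
compiled_eager = torch.compile(Toy(), backend="eager")
compiled_inductor = torch.compile(Toy(), backend="inductor")

x = torch.randn(4)
assert torch.allclose(compiled_eager(x), compiled_inductor(x))
```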