
Commit c252dd9

iamzainhuda authored and facebook-github-bot committed
mean pooling in EBC/VBE (#1772)
Summary: This diff adds support for mean pooling under the Row Wise (RW) and Table Row Wise (TWRW) sharding schemes. Mean pooling is applied after the reduce-scatter collective, at the point where the KeyedTensor awaitable is created, and is implemented as a callback registered on that awaitable.

Differential Revision: D54656612
1 parent 01b5d34 commit c252dd9
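The core idea, illustrated with toy tensors (a sketch of the arithmetic only, not TorchRec API): under RW/TWRW sharding each rank holds only a slice of every embedding bag, so the lookup kernel sum-pools locally and the reduce scatter produces the full per-bag sum; the mean is then recovered by dividing by the per-sample bag lengths, which is what the callback in this diff does at full scale.

    import torch

    lengths = torch.tensor([2, 0, 3])             # bag sizes for one feature, batch of 3
    summed = torch.tensor([[4.0, 6.0],            # sum-pooled embeddings as they look
                           [0.0, 0.0],            # after the reduce scatter (dim=2)
                           [9.0, 3.0]])
    eps = 1e-6                                    # guards the empty bag (length 0)
    mean = summed / (lengths.unsqueeze(1) + eps)  # -> [[2, 3], [0, 0], [3, 1]]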

File tree

7 files changed (+291, -25 lines)

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 9 additions & 3 deletions
@@ -50,6 +50,7 @@
     Shard,
     ShardedTensor,
     ShardedTensorMetadata,
+    ShardingType,
     ShardMetadata,
     TensorProperties,
 )
@@ -720,13 +721,16 @@ def __init__(
         config: GroupedEmbeddingConfig,
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
+        sharding_type: Optional[ShardingType] = None,
     ) -> None:
         super().__init__()
         torch._C._log_api_usage_once(f"torchrec.distributed.{self.__class__.__name__}")
         self._config = config
         self._pg = pg
 
-        self._pooling: PoolingMode = pooling_type_to_pooling_mode(config.pooling)
+        self._pooling: PoolingMode = pooling_type_to_pooling_mode(
+            config.pooling, sharding_type
+        )
 
         self._local_rows: List[int] = []
         self._weight_init_mins: List[float] = []
@@ -859,8 +863,9 @@ def __init__(
         config: GroupedEmbeddingConfig,
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
+        sharding_type: Optional[ShardingType] = None,
     ) -> None:
-        super().__init__(config, pg, device)
+        super().__init__(config, pg, device, sharding_type)
 
         managed: List[EmbeddingLocation] = []
         compute_devices: List[ComputeDevice] = []
@@ -962,8 +967,9 @@ def __init__(
         config: GroupedEmbeddingConfig,
         pg: Optional[dist.ProcessGroup] = None,
        device: Optional[torch.device] = None,
+        sharding_type: Optional[ShardingType] = None,
     ) -> None:
-        super().__init__(config, pg, device)
+        super().__init__(config, pg, device, sharding_type)
 
         weights_precision = data_type_to_sparse_type(config.data_type)
         fused_params = config.fused_params or {}
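The call above now threads sharding_type into pooling_type_to_pooling_mode; the helper itself is not shown in these hunks. Below is a hedged sketch, under the assumption that for RW/TWRW the helper maps MEAN onto a SUM kernel mode so that the kernel emits partial sums and the mean is finished by the post-reduce-scatter callback. Return values are plain strings standing in for fbgemm's PoolingMode enum, and sketch_pooling_mode is a made-up name, not the torchrec function.

    from typing import Optional

    from torchrec.distributed.types import ShardingType
    from torchrec.modules.embedding_configs import PoolingType


    def sketch_pooling_mode(
        pooling: PoolingType, sharding_type: Optional[ShardingType] = None
    ) -> str:
        """Illustrative only: approximates what a sharding-aware mapping enables."""
        if pooling == PoolingType.SUM:
            return "SUM"
        if pooling == PoolingType.MEAN:
            if sharding_type in (ShardingType.ROW_WISE, ShardingType.TABLE_ROW_WISE):
                # kernel sum-pools; the mean is applied after the reduce scatter
                return "SUM"
            return "MEAN"
        raise ValueError(f"unsupported pooling type {pooling}")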

torchrec/distributed/embedding_lookup.py

Lines changed: 6 additions & 2 deletions
@@ -52,7 +52,7 @@
     QuantBatchedEmbedding,
     QuantBatchedEmbeddingBag,
 )
-from torchrec.distributed.types import ShardedTensor
+from torchrec.distributed.types import ShardedTensor, ShardingType
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -344,23 +344,27 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
         scale_weight_gradients: bool = True,
+        sharding_type: Optional[ShardingType] = None,
     ) -> None:
         # TODO rename to _create_embedding_kernel
         def _create_lookup(
             config: GroupedEmbeddingConfig,
             device: Optional[torch.device] = None,
+            sharding_type: Optional[ShardingType] = None,
         ) -> BaseEmbedding:
             if config.compute_kernel == EmbeddingComputeKernel.DENSE:
                 return BatchedDenseEmbeddingBag(
                     config=config,
                     pg=pg,
                     device=device,
+                    sharding_type=sharding_type,
                 )
             elif config.compute_kernel == EmbeddingComputeKernel.FUSED:
                 return BatchedFusedEmbeddingBag(
                     config=config,
                     pg=pg,
                     device=device,
+                    sharding_type=sharding_type,
                 )
             else:
                 raise ValueError(
@@ -370,7 +374,7 @@ def _create_lookup(
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         for config in grouped_configs:
-            self._emb_modules.append(_create_lookup(config, device))
+            self._emb_modules.append(_create_lookup(config, device, sharding_type))
 
         self._feature_splits: List[int] = []
         for config in grouped_configs:

torchrec/distributed/embeddingbag.py

Lines changed: 177 additions & 12 deletions
@@ -8,10 +8,11 @@
 # pyre-strict
 
 import copy
-from collections import OrderedDict
+from collections import defaultdict, OrderedDict
 from dataclasses import dataclass, field
 from typing import (
     Any,
+    Callable,
     cast,
     Dict,
     Iterator,
@@ -27,6 +28,7 @@
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
 from torch import nn, Tensor
+from torch.autograd.profiler import record_function
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_sharding import (
@@ -79,7 +81,7 @@
 )
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -378,6 +380,7 @@ class EmbeddingBagCollectionContext(Multistreamable):
     )
     inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None
     variable_batch_per_feature: bool = False
+    mean_pooling_callback: Optional[Callable[[KeyedTensor], KeyedTensor]] = None
 
     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         for ctx in self.sharding_contexts:
@@ -415,13 +418,22 @@ def __init__(
         self._embedding_bag_configs: List[EmbeddingBagConfig] = (
             module.embedding_bag_configs()
         )
-        self._table_names: List[str] = [
-            config.name for config in self._embedding_bag_configs
-        ]
 
-        self._table_name_to_config: Dict[str, EmbeddingBagConfig] = {
-            config.name: config for config in self._embedding_bag_configs
-        }
+        self._table_names: List[str] = []
+        self._pooling_type_to_rs_features: Dict[str, List[str]] = defaultdict(list)
+        self._table_name_to_config: Dict[str, EmbeddingBagConfig] = {}
+
+        for config in self._embedding_bag_configs:
+            self._table_names.append(config.name)
+            self._table_name_to_config[config.name] = config
+
+            if table_name_to_parameter_sharding[config.name].sharding_type in [
+                ShardingType.TABLE_ROW_WISE.value,
+                ShardingType.ROW_WISE.value,
+            ]:
+                self._pooling_type_to_rs_features[config.pooling.value].extend(
+                    config.feature_names
+                )
 
         self.module_sharding_plan: EmbeddingModuleShardingPlan = cast(
             EmbeddingModuleShardingPlan,
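For orientation, a toy illustration of what the loop above accumulates (table and feature names are hypothetical): only tables sharded ROW_WISE or TABLE_ROW_WISE, i.e. those whose outputs go through a reduce scatter, are recorded, keyed by their pooling type, so the callback later knows which features still need a mean and which are already final sums.

    from collections import defaultdict

    pooling_type_to_rs_features = defaultdict(list)
    toy_plan = {
        # table: (sharding_type, pooling, feature_names)
        "t_clicks": ("row_wise", "MEAN", ["f_clicks"]),     # needs mean post reduce scatter
        "t_likes": ("table_row_wise", "SUM", ["f_likes"]),  # sum is already final
        "t_geo": ("table_wise", "MEAN", ["f_geo"]),         # not reduce-scattered; skipped
    }
    for _table, (sharding, pooling, feature_names) in toy_plan.items():
        if sharding in ("row_wise", "table_row_wise"):
            pooling_type_to_rs_features[pooling].extend(feature_names)
    # -> {"MEAN": ["f_clicks"], "SUM": ["f_likes"]}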
@@ -472,6 +484,16 @@ def __init__(
         self._uncombined_embedding_names: List[str] = []
         self._uncombined_embedding_dims: List[int] = []
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
+        # to support mean pooling callback hook
+        self._has_mean_pooling_callback: bool = (
+            True
+            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            else False
+        )
+        self._dim_per_key: Optional[torch.Tensor] = None
+        self._kjt_key_indices: Dict[str, int] = {}
+        self._kjt_inverse_order: Optional[torch.Tensor] = None
+        self._kt_key_ordering: Optional[torch.Tensor] = None
         # to support the FP16 hook
         self._create_output_dist()
 
@@ -720,6 +742,38 @@ def _create_input_dist(
             persistent=False,
         )
 
+    def _init_mean_pooling_callback(
+        self,
+        input_feature_names: List[str],
+        inverse_indices: Optional[Tuple[List[str], torch.Tensor]],
+    ) -> None:
+        # account for shared features
+        feature_names: List[str] = [
+            feature_name
+            for sharding in self._sharding_type_to_sharding.values()
+            for feature_name in sharding.feature_names()
+        ]
+
+        for i, key in enumerate(feature_names):
+            if key not in self._kjt_key_indices:  # index of first occurrence
+                self._kjt_key_indices[key] = i
+
+        keyed_tensor_ordering = []
+        for key in self._embedding_names:
+            if "@" in key:
+                key = key.split("@")[0]
+            keyed_tensor_ordering.append(self._kjt_key_indices[key])
+        self._kt_key_ordering = torch.tensor(keyed_tensor_ordering, device=self._device)
+
+        if inverse_indices:
+            key_to_inverse_index = {
+                name: i for i, name in enumerate(inverse_indices[0])
+            }
+            self._kjt_inverse_order = torch.tensor(
+                [key_to_inverse_index[key] for key in feature_names],
+                device=self._device,
+            )
+
     def _create_lookups(
         self,
     ) -> None:
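A toy walk-through of the key-ordering bookkeeping in the _init_mean_pooling_callback hunk above (feature and table names are hypothetical): a feature shared by two tables shows up once per table in the KeyedTensor as "<feature>@<table>", but its lengths live at the feature's first occurrence in the KJT, which is what _kjt_key_indices and _kt_key_ordering record.

    import torch

    feature_names = ["f_a", "f_b", "f_a"]                # KJT order; f_a shared by two tables
    embedding_names = ["f_a@t_one", "f_b", "f_a@t_two"]  # KeyedTensor order

    kjt_key_indices = {}
    for i, key in enumerate(feature_names):
        if key not in kjt_key_indices:                   # index of first occurrence
            kjt_key_indices[key] = i

    kt_key_ordering = torch.tensor(
        [kjt_key_indices[k.split("@")[0] if "@" in k else k] for k in embedding_names]
    )
    # -> tensor([0, 1, 0])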
@@ -737,6 +791,7 @@ def _create_output_dist(self) -> None:
             )
             self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+        self._dim_per_key = torch.tensor(self._embedding_dims, device=self._device)
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
             for meta in embedding_shard_metadata
@@ -789,12 +844,31 @@ def input_dist(
             self._has_uninitialized_input_dist = False
             if ctx.variable_batch_per_feature:
                 self._create_inverse_indices_permute_indices(ctx.inverse_indices)
+            if self._has_mean_pooling_callback:
+                self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
         with torch.no_grad():
             if self._has_features_permute:
                 features = features.permute(
                     self._features_order,
                     self._features_order_tensor,
                 )
+            if self._has_mean_pooling_callback:
+                ctx.mean_pooling_callback = create_mean_pooling_callback(
+                    lengths=features.lengths(),
+                    stride=features.stride(),
+                    keys=features.keys(),
+                    pooling_type_to_rs_features=self._pooling_type_to_rs_features,
+                    stride_per_key=features.stride_per_key(),
+                    dim_per_key=self._dim_per_key,  # pyre-ignore[6]
+                    embedding_names=self._embedding_names,
+                    embedding_dims=self._embedding_dims,
+                    variable_batch_per_feature=ctx.variable_batch_per_feature,
+                    kjt_inverse_order=self._kjt_inverse_order,  # pyre-ignore[6]
+                    kjt_key_indices=self._kjt_key_indices,
+                    kt_key_ordering=self._kt_key_ordering,  # pyre-ignore[6]
+                    inverse_indices=ctx.inverse_indices,
+                )
+
             features_by_shards = features.split(
                 self._feature_splits,
             )
@@ -840,7 +914,7 @@ def output_dist(
             assert (
                 ctx.inverse_indices is not None
             ), "inverse indices must be provided from KJT if using variable batch size per feature."
-            return VariableBatchEmbeddingBagCollectionAwaitable(
+            awaitable = VariableBatchEmbeddingBagCollectionAwaitable(
                 awaitables=awaitables,
                 inverse_indices=ctx.inverse_indices,
                 inverse_indices_permute_indices=self._inverse_indices_permute_indices,
@@ -851,12 +925,18 @@
                 permute_op=self._permute_op,
             )
         else:
-            return EmbeddingBagCollectionAwaitable(
+            awaitable = EmbeddingBagCollectionAwaitable(
                 awaitables=awaitables,
                 embedding_dims=self._embedding_dims,
                 embedding_names=self._embedding_names,
             )
 
+        # register callback if there are features that need mean pooling
+        if self._has_mean_pooling_callback:
+            awaitable.callbacks.append(ctx.mean_pooling_callback)
+
+        return awaitable
+
     def compute_and_output_dist(
         self, ctx: EmbeddingBagCollectionContext, input: KJTList
     ) -> LazyAwaitable[KeyedTensor]:
@@ -879,7 +959,7 @@ def compute_and_output_dist(
             assert (
                 ctx.inverse_indices is not None
             ), "inverse indices must be provided from KJT if using variable batch size per feature."
-            return VariableBatchEmbeddingBagCollectionAwaitable(
+            awaitable = VariableBatchEmbeddingBagCollectionAwaitable(
                 awaitables=awaitables,
                 inverse_indices=ctx.inverse_indices,
                 inverse_indices_permute_indices=self._inverse_indices_permute_indices,
@@ -890,12 +970,18 @@
                 permute_op=self._permute_op,
             )
         else:
-            return EmbeddingBagCollectionAwaitable(
+            awaitable = EmbeddingBagCollectionAwaitable(
                 awaitables=awaitables,
                 embedding_dims=self._embedding_dims,
                 embedding_names=self._embedding_names,
             )
 
+        # register callback if there are features that need mean pooling
+        if self._has_mean_pooling_callback:
+            awaitable.callbacks.append(ctx.mean_pooling_callback)
+
+        return awaitable
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim
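Both output_dist and compute_and_output_dist now build the awaitable first and append ctx.mean_pooling_callback before returning it. The mock below only shows the contract being relied on, assuming (as the diff implies) that torchrec awaitables apply their registered callbacks to the result when it is waited on; it is not the torchrec Awaitable implementation.

    from typing import Callable, Generic, List, TypeVar

    W = TypeVar("W")


    class MockAwaitable(Generic[W]):
        """Stand-in that mimics the callbacks-on-wait contract used above."""

        def __init__(self, value: W) -> None:
            self._value = value
            self.callbacks: List[Callable[[W], W]] = []

        def wait(self) -> W:
            ret = self._value
            for cb in self.callbacks:   # each callback transforms the awaited result
                ret = cb(ret)
            return ret


    aw = MockAwaitable(10)
    aw.callbacks.append(lambda x: x * 2)  # e.g. the mean-pooling callback
    assert aw.wait() == 20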
@@ -1166,3 +1252,82 @@ def shardable_parameters(self, module: nn.EmbeddingBag) -> Dict[str, nn.Parameter]:
     @property
     def module_type(self) -> Type[nn.EmbeddingBag]:
         return nn.EmbeddingBag
+
+
+def create_mean_pooling_callback(
+    lengths: torch.Tensor,
+    keys: List[str],
+    stride: int,
+    stride_per_key: List[int],
+    dim_per_key: torch.Tensor,
+    pooling_type_to_rs_features: Dict[str, List[str]],
+    embedding_names: List[str],
+    embedding_dims: List[int],
+    variable_batch_per_feature: bool,
+    kjt_inverse_order: torch.Tensor,
+    kjt_key_indices: Dict[str, int],
+    kt_key_ordering: torch.Tensor,
+    inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
+) -> Callable[[KeyedTensor], KeyedTensor]:
+    with record_function("## ebc create mean pooling callback ##"):
+        batch_size = (
+            inverse_indices[1].size(dim=1) if variable_batch_per_feature else stride  # pyre-ignore[16]
+        )
+
+        if variable_batch_per_feature:
+            device = inverse_indices[1].device
+            inverse_indices_t = inverse_indices[1]
+            if len(keys) != len(inverse_indices[0]):
+                inverse_indices_t = torch.index_select(
+                    inverse_indices[1], 0, kjt_inverse_order
+                )
+            offsets = _to_offsets(torch.tensor(stride_per_key, device=device))[
+                :-1
+            ].unsqueeze(-1)
+            indices = (inverse_indices_t + offsets).flatten()
+            lengths = torch.index_select(input=lengths, dim=0, index=indices)
+
+        # only the sum-pooled features get their lengths forced to 1
+        for feature in pooling_type_to_rs_features[PoolingType.SUM.value]:
+            feature_index = kjt_key_indices[feature]
+            feature_index = feature_index * batch_size
+            lengths[feature_index : feature_index + batch_size] = 1
+
+        if len(embedding_names) != len(keys):
+            lengths = torch.index_select(
+                lengths.reshape(-1, batch_size),
+                0,
+                kt_key_ordering,
+            ).reshape(-1)
+
+        # transpose to align features with keyed tensor dim_per_key
+        lengths = lengths.reshape(-1, batch_size).T  # [batch_size, num_features]
+        output_size = sum(embedding_dims)
+
+        divisor = torch.repeat_interleave(
+            input=lengths,
+            repeats=dim_per_key,
+            dim=1,
+            output_size=output_size,
+        )
+        eps = 1e-6  # used to safeguard against division by 0
+        divisor = divisor + eps
+
+    # pyre-ignore[53]
+    def _apply_mean_pooling(keyed_tensor: KeyedTensor) -> KeyedTensor:
+        """
+        Apply mean pooling to pooled embeddings in RW/TWRW sharding schemes.
+        This function is applied as a callback to the awaitable.
+        """
+        with record_function("## ebc apply mean pooling ##"):
+            mean_pooled_values = (
+                keyed_tensor.values() / divisor
+            )  # [batch_size, num_features * embedding_dim]
+            return KeyedTensor(
+                keys=keyed_tensor.keys(),
+                values=mean_pooled_values,
+                length_per_key=keyed_tensor.length_per_key(),
+                key_dim=1,
+            )
+
+    return _apply_mean_pooling
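A worked toy example of the divisor construction in create_mean_pooling_callback above (shapes are hypothetical): two features with embedding dims 2 and 3, batch_size=2, where the second feature is SUM-pooled and therefore has its lengths forced to 1 so the division leaves it untouched.

    import torch

    batch_size = 2
    dim_per_key = torch.tensor([2, 3])    # per-feature embedding dims
    lengths = torch.tensor([3, 1, 4, 2])  # KJT lengths, feature-major: f0=[3,1], f1=[4,2]

    lengths[2:4] = 1                      # SUM-pooled feature: divide by 1 (no-op)

    lengths = lengths.reshape(-1, batch_size).T               # [batch_size, num_features]
    divisor = torch.repeat_interleave(
        input=lengths,
        repeats=dim_per_key,
        dim=1,
        output_size=int(dim_per_key.sum()),
    ) + 1e-6                              # [[3, 3, 1, 1, 1], [1, 1, 1, 1, 1]] (plus eps)

    pooled = torch.ones(batch_size, int(dim_per_key.sum()))   # stand-in for keyed_tensor.values()
    mean_pooled = pooled / divisor        # feature 0 averaged, feature 1 unchanged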
