
Commit 1d46d90

joshuadeng authored and facebook-github-bot committed
Add VBE KJT support to EmbeddingCollection (#2047)
Summary:
- Pad VBE KJT lengths to the final batch size so that the input is compatible with the EC kernel.
- Works with index dedup.
- Expand embeddings with VBE inverse indices.
- Remove the sync point from keyed jagged index select for permute.
- The long-term solution is to fix sequence TBE so it does not need lengths/batch-size info, just a length per key.

Differential Revision: D51600051
1 parent df78731 commit 1d46d90
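For context, a VBE ("variable batch per feature") KJT is a KeyedJaggedTensor whose keys can each carry a different batch size (stride) on a rank. A minimal sketch of constructing one, assuming the public KeyedJaggedTensor constructor arguments; the inverse-indices values here are purely illustrative:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# "f1" has a batch size of 3 on this rank, "f2" a batch size of 1.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 2, 1, 1]),  # 3 lengths for f1, 1 for f2
    stride_per_key_per_rank=[[3], [1]],  # per-key batch sizes -> VBE
    # inverse indices map each key's deduped batch rows back to the
    # final (global) batch; shape: (num_keys, final_batch_size).
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 2], [0, 0, 0]])),
)
assert kjt.variable_stride_per_key()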

File tree: 4 files changed, +225 −32 lines


torchrec/distributed/embedding.py

Lines changed: 129 additions & 15 deletions
@@ -12,7 +12,6 @@
 import logging
 import warnings
 from collections import defaultdict, deque, OrderedDict
-from dataclasses import dataclass, field
 from itertools import accumulate
 from typing import Any, cast, Dict, List, MutableMapping, Optional, Tuple, Type, Union

@@ -72,10 +71,15 @@
     EmbeddingCollection,
     EmbeddingCollectionInterface,
 )
-from torchrec.modules.utils import construct_jagged_tensors
+from torchrec.modules.utils import construct_jagged_tensors, SequenceVBEContext
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import (
+    _pin_and_move,
+    _to_offsets,
+    JaggedTensor,
+    KeyedJaggedTensor,
+)
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -323,6 +327,30 @@ def create_sharding_infos_by_sharding_device_group(
     return sharding_type_device_group_to_sharding_infos
 
 
+def pad_vbe_kjt_lengths(features: KeyedJaggedTensor) -> KeyedJaggedTensor:
+    max_stride = max(features.stride_per_key())
+    new_lengths = torch.zeros(
+        max_stride * len(features.keys()),
+        device=features.device(),
+        dtype=features.lengths().dtype,
+    )
+    cum_stride = 0
+    for i, stride in enumerate(features.stride_per_key()):
+        new_lengths[i * max_stride : i * max_stride + stride] = features.lengths()[
+            cum_stride : cum_stride + stride
+        ]
+        cum_stride += stride
+
+    return KeyedJaggedTensor(
+        keys=features.keys(),
+        values=features.values(),
+        lengths=new_lengths,
+        stride=max_stride,
+        length_per_key=features.length_per_key(),
+        offset_per_key=features.offset_per_key(),
+    )
+
+
 class EmbeddingCollectionContext(Multistreamable):
     # Torch Dynamo does not support default_factory=list:
     # https://github.com/pytorch/pytorch/issues/120108
@@ -333,11 +361,13 @@ def __init__(
         sharding_contexts: Optional[List[SequenceShardingContext]] = None,
         input_features: Optional[List[KeyedJaggedTensor]] = None,
         reverse_indices: Optional[List[torch.Tensor]] = None,
+        seq_vbe_ctx: Optional[List[SequenceVBEContext]] = None,
     ) -> None:
         super().__init__()
         self.sharding_contexts: List[SequenceShardingContext] = sharding_contexts or []
         self.input_features: List[KeyedJaggedTensor] = input_features or []
         self.reverse_indices: List[torch.Tensor] = reverse_indices or []
+        self.seq_vbe_ctx: List[SequenceVBEContext] = seq_vbe_ctx or []
 
     def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
         for ctx in self.sharding_contexts:
@@ -346,6 +376,8 @@ def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
             f.record_stream(stream)
         for r in self.reverse_indices:
             r.record_stream(stream)
+        for s in self.seq_vbe_ctx:
+            s.record_stream(stream)
 
 
 class EmbeddingCollectionAwaitable(LazyAwaitable[Dict[str, JaggedTensor]]):
@@ -385,6 +417,9 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
                 if i >= len(self._ctx.reverse_indices)
                 else self._ctx.reverse_indices[i]
             )
+            seq_vbe_ctx = (
+                None if i >= len(self._ctx.seq_vbe_ctx) else self._ctx.seq_vbe_ctx[i]
+            )
             jt_dict.update(
                 construct_jagged_tensors(
                     embeddings=w.wait(),
@@ -394,6 +429,7 @@ def _wait_impl(self) -> Dict[str, JaggedTensor]:
                     features_to_permute_indices=self._features_to_permute_indices,
                     original_features=original_features,
                     reverse_indices=reverse_indices,
+                    seq_vbe_ctx=seq_vbe_ctx,
                 )
             )
         return jt_dict
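The "expand embeddings with VBE inverse indices" step from the summary happens inside construct_jagged_tensors using the SequenceVBEContext passed here (its body is not shown in this diff). For fixed-length rows the idea reduces to a row gather; a rough conceptual sketch with hypothetical values (the jagged case additionally reorders lengths/values, which is what the recat tensor built below precomputes):

import torch

# Hypothetical deduped embedding rows for one feature: 2 unique batch rows.
deduped_embeddings = torch.tensor([[1.0, 1.0], [2.0, 2.0]])

# Inverse indices for that feature: the final batch has 4 samples;
# samples 0 and 2 share row 0, samples 1 and 3 share row 1.
inverse_indices = torch.tensor([0, 1, 0, 1])

# Expansion back to the final batch size is a gather over rows.
expanded = torch.index_select(deduped_embeddings, 0, inverse_indices)
print(expanded.shape)  # torch.Size([4, 2])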
@@ -506,6 +542,7 @@ def __init__(
             module.embedding_configs(), table_name_to_parameter_sharding
         )
         self._need_indices: bool = module.need_indices()
+        self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None
 
         for index, (sharding, lookup) in enumerate(
             zip(
@@ -847,11 +884,9 @@ def _create_output_dist(
 
     def _dedup_indices(
         self,
-        input_feature_splits: List[KeyedJaggedTensor],
         ctx: EmbeddingCollectionContext,
+        input_feature_splits: List[KeyedJaggedTensor],
     ) -> List[KeyedJaggedTensor]:
-        if not self._use_index_dedup:
-            return input_feature_splits
         with record_function("## dedup_ec_indices ##"):
             features_by_shards = []
             for i, input_feature in enumerate(input_feature_splits):
@@ -881,30 +916,107 @@
 
         return features_by_shards
 
+    def _create_inverse_indices_permute_per_sharding(
+        self, inverse_indices: Tuple[List[str], torch.Tensor]
+    ) -> None:
+        if (
+            len(self._embedding_names_per_sharding) == 1
+            and self._embedding_names_per_sharding[0] == inverse_indices[0]
+        ):
+            return
+        index_per_name = {name: i for i, name in enumerate(inverse_indices[0])}
+        permute_per_sharding = []
+        for emb_names in self._embedding_names_per_sharding:
+            permute = _pin_and_move(
+                torch.tensor(
+                    [index_per_name[name.split("@")[0]] for name in emb_names]
+                ),
+                inverse_indices[1].device,
+            )
+            permute_per_sharding.append(permute)
+        self._inverse_indices_permute_per_sharding = permute_per_sharding
+
+    def _compute_sequence_vbe_context(
+        self,
+        ctx: EmbeddingCollectionContext,
+        unpadded_features: KeyedJaggedTensor,
+    ) -> None:
+        assert (
+            unpadded_features.inverse_indices_or_none() is not None
+        ), "inverse indices must be provided from KJT if using variable batch size per feature."
+
+        inverse_indices = unpadded_features.inverse_indices()
+        stride = inverse_indices[1].numel() // len(inverse_indices[0])
+        if self._inverse_indices_permute_per_sharding is None:
+            self._create_inverse_indices_permute_per_sharding(inverse_indices)
+
+        if self._features_order:
+            unpadded_features = unpadded_features.permute(
+                self._features_order,
+                self._features_order_tensor,
+            )
+
+        features_by_sharding = unpadded_features.split(self._feature_splits)
+        for i, feature in enumerate(features_by_sharding):
+            if self._inverse_indices_permute_per_sharding is not None:
+                permute = self._inverse_indices_permute_per_sharding[i]
+                permuted_indices = torch.index_select(inverse_indices[1], 0, permute)
+            else:
+                permuted_indices = inverse_indices[1]
+            stride_per_key = _pin_and_move(
+                torch.tensor(feature.stride_per_key()), feature.device()
+            )
+            offsets = _to_offsets(stride_per_key)[:-1].unsqueeze(-1)
+            recat = (permuted_indices + offsets).flatten().int()
+
+            if self._need_indices:
+                reindexed_lengths, reindexed_values, _ = (
+                    torch.ops.fbgemm.permute_1D_sparse_data(
+                        recat,
+                        feature.lengths(),
+                        feature.values(),
+                    )
+                )
+            else:
+                reindexed_lengths = torch.index_select(feature.lengths(), 0, recat)
+                reindexed_values = None
+
+            reindexed_lengths = reindexed_lengths.view(-1, stride)
+            reindexed_length_per_key = torch.sum(reindexed_lengths, dim=1).tolist()
+
+            ctx.seq_vbe_ctx.append(
+                SequenceVBEContext(
+                    recat=recat,
+                    unpadded_lengths=feature.lengths(),
+                    reindexed_lengths=reindexed_lengths,
+                    reindexed_length_per_key=reindexed_length_per_key,
+                    reindexed_values=reindexed_values,
+                )
+            )
+
     # pyre-ignore [14]
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
         features: KeyedJaggedTensor,
     ) -> Awaitable[Awaitable[KJTList]]:
-        if features.variable_stride_per_key():
-            raise ValueError(
-                "Variable batch per feature is not supported with EmbeddingCollection"
-            )
         if self._has_uninitialized_input_dist:
             self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
         with torch.no_grad():
+            unpadded_features = None
+            if features.variable_stride_per_key():
+                unpadded_features = features
+                features = pad_vbe_kjt_lengths(unpadded_features)
+
             if self._features_order:
                 features = features.permute(
                     self._features_order,
                     self._features_order_tensor,
                 )
-
-            input_feature_splits = features.split(
-                self._feature_splits,
-            )
-            features_by_shards = self._dedup_indices(input_feature_splits, ctx)
+            features_by_shards = features.split(self._feature_splits)
+            if self._use_index_dedup:
+                features_by_shards = self._dedup_indices(ctx, features_by_shards)
 
             awaitables = []
             for input_dist, features in zip(self._input_dists, features_by_shards):
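A small numeric sketch of the recat construction above, assuming two keys with unpadded strides [3, 1] and a final batch size of 3 (all values hypothetical): per-key exclusive-prefix-sum offsets shift each key's inverse indices into that key's slice of the flattened unpadded lengths, and the flattened result is the gather order that maps lengths back to the unpadded VBE layout. _to_offsets is replicated here with plain torch ops:

import torch

# Hypothetical permuted inverse indices, one row per key, final batch size 3.
permuted_indices = torch.tensor([[0, 1, 2], [0, 0, 0]])

# Unpadded per-key strides -> exclusive prefix sums [0, 3].
stride_per_key = torch.tensor([3, 1])
offsets = torch.cumsum(
    torch.cat([torch.zeros(1, dtype=torch.int64), stride_per_key]), dim=0
)[:-1].unsqueeze(-1)  # tensor([[0], [3]])

# Each key's indices now address its own slice of the flattened lengths.
recat = (permuted_indices + offsets).flatten().int()
print(recat)  # tensor([0, 1, 2, 3, 3, 3], dtype=torch.int32)

Indexing the 4-entry unpadded lengths with this recat yields 6 lengths, which view(-1, 3) reshapes to (num_keys, stride), i.e. every key expanded to the final batch size.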
@@ -919,6 +1031,8 @@ def input_dist(
                         ),
                     )
                 )
+            if unpadded_features is not None:
+                self._compute_sequence_vbe_context(ctx, unpadded_features)
             return KJTListSplitsAwaitable(awaitables, ctx)
 
     def compute(

torchrec/distributed/tests/test_sequence_model.py

Lines changed: 2 additions & 0 deletions
@@ -305,6 +305,7 @@ def __init__(
         kernel_type: str,
         qcomms_config: Optional[QCommsConfig] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        use_index_dedup: bool = False,
     ) -> None:
         self._sharding_type = sharding_type
         self._kernel_type = kernel_type
@@ -321,6 +322,7 @@ def __init__(
         super().__init__(
             fused_params=fused_params,
             qcomm_codecs_registry=qcomm_codecs_registry,
+            use_index_dedup=use_index_dedup,
         )
 
     """

torchrec/distributed/tests/test_sequence_model_parallel.py

Lines changed: 20 additions & 7 deletions
@@ -62,8 +62,9 @@ class SequenceModelParallelTest(MultiProcessTestBase):
                 },
             ]
         ),
+        variable_batch_size=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
     def test_sharding_nccl_rw(
         self,
         sharding_type: str,
@@ -72,6 +73,7 @@ def test_sharding_nccl_rw(
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ],
+        variable_batch_size: bool,
     ) -> None:
         assume(
             apply_optimizer_in_backward_config is None
@@ -88,6 +90,7 @@ def test_sharding_nccl_rw(
             backend="nccl",
             qcomms_config=qcomms_config,
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
         )
 
     @unittest.skipIf(
@@ -152,8 +155,9 @@ def test_sharding_nccl_dp(
                 },
             ]
         ),
+        variable_batch_size=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
     def test_sharding_nccl_tw(
         self,
         sharding_type: str,
@@ -162,6 +166,7 @@ def test_sharding_nccl_tw(
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ],
+        variable_batch_size: bool,
     ) -> None:
         assume(
             apply_optimizer_in_backward_config is None
@@ -178,7 +183,7 @@ def test_sharding_nccl_tw(
             backend="nccl",
             qcomms_config=qcomms_config,
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
-            variable_batch_size=False,
+            variable_batch_size=variable_batch_size,
         )
 
     @unittest.skipIf(
@@ -203,15 +208,17 @@ def test_sharding_nccl_tw(
                 },
             ]
         ),
+        variable_batch_size=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=2, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
     def test_sharding_nccl_cw(
         self,
         sharding_type: str,
         kernel_type: str,
         apply_optimizer_in_backward_config: Optional[
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ],
+        variable_batch_size: bool,
     ) -> None:
         assume(
             apply_optimizer_in_backward_config is None
@@ -230,7 +237,7 @@ def test_sharding_nccl_cw(
                 for table in self.tables
             },
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
-            variable_batch_size=False,
+            variable_batch_size=variable_batch_size,
        )
 
     @unittest.skipIf(
@@ -246,25 +253,28 @@ def test_sharding_nccl_cw(
                 ShardingType.ROW_WISE.value,
             ]
         ),
+        index_dedup=st.booleans(),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=5, deadline=None)
     def test_sharding_variable_batch(
         self,
         sharding_type: str,
+        index_dedup: bool,
     ) -> None:
         self._test_sharding(
             sharders=[
                 TestEmbeddingCollectionSharder(
                     sharding_type=sharding_type,
                     kernel_type=EmbeddingComputeKernel.FUSED.value,
+                    use_index_dedup=index_dedup,
                 )
             ],
             backend="nccl",
             constraints={
                 table.name: ParameterConstraints(min_partition=4)
                 for table in self.tables
             },
-            variable_batch_size=True,
+            variable_batch_per_feature=True,
         )
 
     # pyre-fixme[56]
@@ -347,6 +357,7 @@ def _test_sharding(
             Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
         ] = None,
         variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
     ) -> None:
         self._run_multi_process_test(
             callable=sharding_single_rank_test,
@@ -362,4 +373,6 @@ def _test_sharding(
             qcomms_config=qcomms_config,
             apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
             variable_batch_size=variable_batch_size,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=True,
         )
