
Commit ff18e89

joshuadeng authored and facebook-github-bot committed
enable VBE for data parallel sharding (#2093)
Summary:
Pull Request resolved: #2093

TBE now supports VBE for the dense kernel, so data parallel sharding no longer needs to reject variable batch sizes per feature.

Reviewed By: sarckk

Differential Revision: D57745137

fbshipit-source-id: 3456c29ce06f4a95ee2e4db8fdb933df372288cf
1 parent: 66f104f · commit: ff18e89

File tree

3 files changed (+16, -14 lines)
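For context, VBE (variable batch size per feature) means the input KeyedJaggedTensor can carry a different batch size for each feature key. Below is a minimal sketch of such an input; the feature names and values are made up, and while the constructor arguments used here (keys, values, lengths, stride_per_key_per_rank) are believed to match torchrec's KeyedJaggedTensor API, treat this as an illustration rather than canonical usage.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Hypothetical example: feature "f1" has a batch of 2, feature "f2" a batch of 4,
# so the stride (batch size) varies per key and variable_stride_per_key() is True.
features = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
    stride_per_key_per_rank=[[2], [4]],  # a single rank in this sketch
)
assert features.variable_stride_per_key()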

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 5 additions & 1 deletion
@@ -776,7 +776,11 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         if weights is not None and not torch.is_floating_point(weights):
             weights = None
         if features.variable_stride_per_key() and isinstance(
-            self.emb_module, SplitTableBatchedEmbeddingBagsCodegen
+            self.emb_module,
+            (
+                SplitTableBatchedEmbeddingBagsCodegen,
+                DenseTableBatchedEmbeddingBagsCodegen,
+            ),
         ):
             return self.emb_module(
                 indices=features.values().long(),
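The hunk above ends partway through the self.emb_module(...) call. For context, the VBE branch is expected to pass the KJT's per-feature batch sizes through to the TBE kernel roughly as sketched below. This continuation is reconstructed from memory rather than taken from this diff, and the batch_size_per_feature_per_rank keyword in particular is an assumption about the fbgemm TBE forward signature.

            # Sketch (assumed continuation, not shown in this hunk): both the split
            # and the dense TBE modules are assumed to accept
            # batch_size_per_feature_per_rank for the VBE path.
            return self.emb_module(
                indices=features.values().long(),
                offsets=features.offsets().long(),
                per_sample_weights=weights,
                batch_size_per_feature_per_rank=features.stride_per_key_per_rank(),
            )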

torchrec/distributed/sharding/dp_sharding.py

Lines changed: 0 additions & 4 deletions
@@ -153,10 +153,6 @@ def forward(
             Awaitable[Awaitable[SparseFeatures]]: awaitable of awaitable of SparseFeatures.
         """

-        if sparse_features.variable_stride_per_key():
-            raise ValueError(
-                "Dense TBE kernel does not support variable batch per feature"
-            )
         return NoWait(cast(Awaitable[KeyedJaggedTensor], NoWait(sparse_features)))

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 11 additions & 9 deletions
@@ -245,15 +245,11 @@ def test_sharding_rw(
                 SharderType.EMBEDDING_BAG_COLLECTION.value,
             ]
         ),
-        kernel_type=st.sampled_from(
-            [
-                EmbeddingComputeKernel.DENSE.value,
-            ],
-        ),
-        apply_optimizer_in_backward_config=st.sampled_from([None]),
+        kernel_type=st.just(EmbeddingComputeKernel.DENSE.value),
+        apply_optimizer_in_backward_config=st.just(None),
         # TODO - need to enable optimizer overlapped behavior for data_parallel tables
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=3, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None)
     def test_sharding_dp(
         self,
         sharder_type: str,

@@ -591,12 +587,13 @@ def test_sharding_twrw(
                 ShardingType.TABLE_WISE.value,
                 ShardingType.COLUMN_WISE.value,
                 ShardingType.ROW_WISE.value,
+                ShardingType.DATA_PARALLEL.value,
             ]
         ),
         global_constant_batch=st.booleans(),
         pooling=st.sampled_from([PoolingType.SUM, PoolingType.MEAN]),
     )
-    @settings(verbosity=Verbosity.verbose, max_examples=6, deadline=None)
+    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
     def test_sharding_variable_batch(
         self,
         sharding_type: str,

@@ -608,13 +605,18 @@ def test_sharding_variable_batch(
             self.skipTest(
                 "bounds_check_indices on CPU does not support variable length (batch size)"
             )
+        kernel = (
+            EmbeddingComputeKernel.DENSE.value
+            if sharding_type == ShardingType.DATA_PARALLEL.value
+            else EmbeddingComputeKernel.FUSED.value
+        )
         self._test_sharding(
             # pyre-ignore[6]
             sharders=[
                 create_test_sharder(
                     sharder_type=SharderType.EMBEDDING_BAG_COLLECTION.value,
                     sharding_type=sharding_type,
-                    kernel_type=EmbeddingComputeKernel.FUSED.value,
+                    kernel_type=kernel,
                     device=self.device,
                 ),
             ],
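A side note on the test changes: Hypothesis's st.just(x) generates exactly the single value x, which is why it can replace a one-element st.sampled_from([...]) here. A small standalone illustration of that equivalence follows; the function name and values are hypothetical, for demonstration only.

from hypothesis import given, settings, strategies as st

@settings(max_examples=1, deadline=None)
@given(kernel_type=st.just("dense"), optimizer_config=st.just(None))
def check_single_choice(kernel_type: str, optimizer_config: None) -> None:
    # st.just always yields exactly the value it was given.
    assert kernel_type == "dense"
    assert optimizer_config is None

check_single_choice()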
