
Commit d0cc9a6

iamzainhuda authored and facebook-github-bot committed
Replacing ShardedTensor with DTensor for RW sharding (#2147)
Summary: Pull Request resolved: #2147

**This is the first part of migrating TorchRec state dict checkpointing from ShardedTensor to DTensor. It sets up the necessary infra to support additional sharding schemes. The general approach is to keep the ShardedTensor paths and remove them once all sharding types are supported on DTensor. This includes the ShardingPlan and ShardedTensor dataclasses such as ShardedTensorMetadata; those will be migrated in a separate diff along with ParameterSharding.**

NOTE: This version of LocalShardsWrapper does not support empty shards; that is added in the next diff, which enables CW. D57063512

**This diff includes:**
+ LocalShardsWrapper, a torch.Tensor subclass to be used with DTensor
+ Changes to TorchRec state_dict load and creation to use DTensor for the row-wise path in both EmbeddingCollection and EmbeddingBagCollection
+ Changes to DCP to support LocalShardsWrapper for saving and reading (WriteItems and ReadItems)
+ DTensor paths added at callsites where ShardedTensors are expected

**LocalShardsWrapper supports the following torch ops:**
+ torch.ops._c10d_functional.all_gather_into_tensor.default
+ aten._to_copy.default
+ aten.view.default
+ aten.equal.default
+ aten.detach.default

with extensibility to add more as required by use cases.

See https://docs.google.com/document/d/16Ptl50mGFJW2cljdF2HQ6FwsiA0scwbAbjx_4dhabJw/edit?usp=drivesdk for more info on the design and approach.

Reviewed By: XilunWu

Differential Revision: D54375878
1 parent eca606d commit d0cc9a6
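
A minimal sketch (not part of this commit) of the wrapping pattern the diff introduces for row-wise sharding: each rank keeps its local shard(s) in a LocalShardsWrapper and exposes them to state_dict/checkpointing as a DTensor built with DTensor.from_local. The single-process gloo setup, the 4 x 8 table size, the offsets, and the 1-D mesh are illustrative assumptions; in TorchRec the mesh, placements, and global shape come from the sharding environment, as the diffs below show.

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh
from torchrec.distributed.shards_wrapper import LocalShardsWrapper

# Single-process setup purely so the example runs end to end.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

# One local row-wise shard of a 4 x 8 table; the offset says where it starts in the global tensor.
local_shard = torch.randn(4, 8)
wrapper = LocalShardsWrapper(local_shards=[local_shard], local_offsets=[(0, 0)])

dt = DTensor.from_local(
    local_tensor=wrapper,
    device_mesh=mesh,
    placements=[Shard(0)],        # row-wise => sharded along dim 0
    shape=torch.Size([4, 8]),     # global size (world_size == 1 here)
    stride=(8, 1),
    run_check=False,
)

# The checkpointing code reads shards back out through to_local().
print(type(dt.to_local()))                    # LocalShardsWrapper
print(dt.to_local().local_shards()[0].shape)  # torch.Size([4, 8])
dist.destroy_process_group()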

File tree

12 files changed (+659, -76 lines changed)


torchrec/distributed/composable/tests/test_embedding.py

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,7 @@
 import torch
 import torch.nn as nn
 from hypothesis import given, settings, Verbosity
+from torch.distributed._tensor.api import DTensor
 from torch.distributed.optim import (
     _apply_optimizer_in_backward as apply_optimizer_in_backward,
 )
@@ -177,6 +178,8 @@ def _test_sharding(  # noqa C901
             )
             if isinstance(sharded_state, ShardedTensor):
                 sharded_state.gather(out=sharded_param)
+            elif isinstance(sharded_state, DTensor):
+                sharded_param = sharded_state.full_tensor()
             else:
                 sharded_param = sharded_state

torchrec/distributed/composable/tests/test_embeddingbag.py

Lines changed: 6 additions & 1 deletion
@@ -18,6 +18,7 @@
 import torch.nn as nn

 from hypothesis import assume, given, settings, Verbosity
+from torch.distributed._tensor.api import DTensor
 from torch.distributed.optim import (
     _apply_optimizer_in_backward as apply_optimizer_in_backward,
 )
@@ -238,7 +239,11 @@ def _test_sharding(  # noqa C901
                 if ctx.rank == 0
                 else None
             )
-            sharded_state.gather(out=out)
+            if isinstance(sharded_state, DTensor):
+                out = sharded_state.full_tensor()
+            else:
+                sharded_state.gather(out=out)
+
             if ctx.rank == 0:
                 torch.testing.assert_close(
                     unsharded_state,
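
The two test diffs above converge on a single read pattern for checkpointed table weights: a ShardedTensor gathers into a preallocated output on the destination rank, while a DTensor backed by LocalShardsWrapper returns the replicated table via full_tensor(). A hedged sketch of that branch, folded into a hypothetical helper (maybe_full_tensor is not part of the commit):

from typing import Optional, Union

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor


def maybe_full_tensor(
    sharded_state: Union[torch.Tensor, ShardedTensor, DTensor],
    out: Optional[torch.Tensor],
    rank: int,
) -> Optional[torch.Tensor]:
    if isinstance(sharded_state, DTensor):
        # full_tensor() all-gathers the shards on every rank; keep the result on rank 0 only.
        full = sharded_state.full_tensor()
        return full if rank == 0 else None
    if isinstance(sharded_state, ShardedTensor):
        # gather() writes into the preallocated `out` on the destination rank (rank 0 by default).
        sharded_state.gather(out=out)
        return out
    # Already a plain tensor (e.g. a data-parallel table).
    return sharded_state if rank == 0 else None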

torchrec/distributed/composable/tests/test_fsdp.py

Lines changed: 17 additions & 0 deletions
@@ -16,6 +16,7 @@
 from torch import nn
 from torch.distributed._composable import fully_shard
 from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._tensor import DTensor

 from torch.distributed.checkpoint import (
     FileSystemReader,
@@ -193,6 +194,10 @@ def _run(  # noqa
                 if not p.local_shards():
                     continue
                 p = p.local_tensor()
+            if isinstance(p, DTensor):
+                if not p.to_local().local_shards():
+                    continue
+                p = p.to_local().local_shards()[0]
             p_sum += p.sum()
             p.zero_()
             assert p.sum() == 0
@@ -205,6 +210,10 @@ def _run(  # noqa
                 if not t.local_shards():
                     continue
                 t = t.local_tensor()
+            if isinstance(t, DTensor):
+                if not t.to_local().local_shards():  # pyre-ignore[16]
+                    continue
+                t = t.to_local().local_shards()[0]
             o_sum += t.sum()
             t.zero_()
             assert t.sum() == 0
@@ -228,6 +237,10 @@ def _run(  # noqa
                     continue
                 p = p.local_tensor()
             p_sum_loaded += p.sum()
+            if isinstance(p, DTensor):
+                if not p.to_local().local_shards():
+                    continue
+                p = p.to_local().local_shards()[0]
         assert p_sum.allclose(p_sum_loaded)

         o_sum_loaded = torch.zeros(1, device=ctx.device)
@@ -239,6 +252,10 @@ def _run(  # noqa
                 if not t.local_shards():
                     continue
                 t = t.local_tensor()
+            if isinstance(t, DTensor):
+                if not t.to_local().local_shards():
+                    continue
+                t = t.to_local().local_shards()[0]
             o_sum_loaded += t.sum()
         assert o_sum.allclose(o_sum_loaded)

torchrec/distributed/embedding.py

Lines changed: 108 additions & 36 deletions
@@ -28,6 +28,7 @@
 import torch
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
+from torch.distributed._tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
@@ -55,6 +56,7 @@
 from torchrec.distributed.sharding.tw_sequence_sharding import (
     TwSequenceEmbeddingSharding,
 )
+from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 from torchrec.distributed.types import (
     Awaitable,
     EmbeddingModuleShardingPlan,
@@ -601,18 +603,20 @@ def _pre_load_state_dict_hook(
     ) -> None:
         """
         Modify the destination state_dict for model parallel
-        to transform from ShardedTensors into tensors
+        to transform from ShardedTensors/DTensors into tensors
         """
-        for (
-            table_name,
-            model_shards,
-        ) in self._model_parallel_name_to_local_shards.items():
+        for table_name in self._model_parallel_name_to_local_shards.keys():
            key = f"{prefix}embeddings.{table_name}.weight"
-
+            # gather model shards from both DTensor and ShardedTensor maps
+            model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[
+                table_name
+            ]
+            model_shards_dtensor = self._model_parallel_name_to_shards_wrapper[
+                table_name
+            ]
            # If state_dict[key] is already a ShardedTensor, use its local shards
            if isinstance(state_dict[key], ShardedTensor):
                local_shards = state_dict[key].local_shards()
-                # If no local shards, create an empty tensor
                if len(local_shards) == 0:
                    state_dict[key] = torch.empty(0)
                else:
@@ -624,27 +628,57 @@ def _pre_load_state_dict_hook(
                    ).view(-1, dim)
                else:
                    state_dict[key] = local_shards[0].tensor.view(-1, dim)
-            else:
+            elif isinstance(state_dict[key], DTensor):
+                shards_wrapper = state_dict[key].to_local()
+                local_shards = shards_wrapper.local_shards()
+                dim = shards_wrapper.local_sizes()[0][1]
+                if len(local_shards) == 0:
+                    state_dict[key] = torch.empty(0)
+                elif len(local_shards) > 1:
+                    # TODO - add multiple shards on rank support
+                    raise RuntimeError(
+                        f"Multiple shards on rank is not supported for DTensor yet, got {len(local_shards)}"
+                    )
+                else:
+                    state_dict[key] = local_shards[0].view(-1, dim)
+            elif isinstance(state_dict[key], torch.Tensor):
                local_shards = []
-                for shard in model_shards:
-                    # Extract shard size and offsets for splicing
-                    shard_sizes = shard.metadata.shard_sizes
-                    shard_offsets = shard.metadata.shard_offsets
-
-                    # Prepare tensor by splicing and placing on appropriate device
-                    spliced_tensor = state_dict[key][
-                        shard_offsets[0] : shard_offsets[0] + shard_sizes[0],
-                        shard_offsets[1] : shard_offsets[1] + shard_sizes[1],
-                    ].to(shard.tensor.get_device())
-
-                    # Append spliced tensor into local shards
-                    local_shards.append(spliced_tensor)
-
+                if model_shards_sharded_tensor:
+                    # splice according to sharded tensor metadata
+                    for shard in model_shards_sharded_tensor:
+                        # Extract shard size and offsets for splicing
+                        shard_size = shard.metadata.shard_sizes
+                        shard_offset = shard.metadata.shard_offsets
+
+                        # Prepare tensor by splicing and placing on appropriate device
+                        spliced_tensor = state_dict[key][
+                            shard_offset[0] : shard_offset[0] + shard_size[0],
+                            shard_offset[1] : shard_offset[1] + shard_size[1],
+                        ]
+
+                        # Append spliced tensor into local shards
+                        local_shards.append(spliced_tensor)
+                elif model_shards_dtensor:
+                    # splice according to dtensor metadata
+                    for tensor, shard_offset in zip(
+                        model_shards_dtensor["local_tensors"],
+                        model_shards_dtensor["local_offsets"],
+                    ):
+                        shard_size = tensor.size()
+                        spliced_tensor = state_dict[key][
+                            shard_offset[0] : shard_offset[0] + shard_size[0],
+                            shard_offset[1] : shard_offset[1] + shard_size[1],
+                        ]
+                        local_shards.append(spliced_tensor)
                state_dict[key] = (
                    torch.empty(0)
                    if not local_shards
                    else torch.cat(local_shards, dim=0)
                )
+            else:
+                raise RuntimeError(
+                    f"Unexpected state_dict key type {type(state_dict[key])} found for {key}"
+                )

        for lookup in self._lookups:
            while isinstance(lookup, DistributedDataParallel):
@@ -661,7 +695,9 @@ def _initialize_torch_state(self) -> None:  # noqa
        for table_name in self._table_names:
            self.embeddings[table_name] = nn.Module()
        self._model_parallel_name_to_local_shards = OrderedDict()
+        self._model_parallel_name_to_shards_wrapper = OrderedDict()
        self._model_parallel_name_to_sharded_tensor = OrderedDict()
+        self._model_parallel_name_to_dtensor = OrderedDict()
        model_parallel_name_to_compute_kernel: Dict[str, str] = {}
        for (
            table_name,
@@ -670,6 +706,9 @@ def _initialize_torch_state(self) -> None:  # noqa
            if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
                continue
            self._model_parallel_name_to_local_shards[table_name] = []
+            self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
+                [("local_tensors", []), ("local_offsets", [])]
+            )
            model_parallel_name_to_compute_kernel[table_name] = (
                parameter_sharding.compute_kernel
            )
@@ -691,18 +730,29 @@ def _initialize_torch_state(self) -> None:  # noqa
            # save local_shards for transforming MP params to shardedTensor
            for key, v in lookup.state_dict().items():
                table_name = key[: -len(".weight")]
-                self._model_parallel_name_to_local_shards[table_name].extend(
-                    v.local_shards()
-                )
+                if isinstance(v, DTensor):
+                    shards_wrapper = self._model_parallel_name_to_shards_wrapper[
+                        table_name
+                    ]
+                    local_shards_wrapper = v._local_tensor
+                    shards_wrapper["local_tensors"].extend(local_shards_wrapper.local_shards())  # pyre-ignore[16]
+                    shards_wrapper["local_offsets"].extend(local_shards_wrapper.local_offsets())  # pyre-ignore[16]
+                    shards_wrapper["global_size"] = v.size()
+                    shards_wrapper["global_stride"] = v.stride()
+                    shards_wrapper["placements"] = v.placements
+                elif isinstance(v, ShardedTensor):
+                    self._model_parallel_name_to_local_shards[table_name].extend(
+                        v.local_shards()
+                    )
            for (
                table_name,
                tbe_slice,
            ) in lookup.named_parameters_by_table():
                self.embeddings[table_name].register_parameter("weight", tbe_slice)
-        for (
-            table_name,
-            local_shards,
-        ) in self._model_parallel_name_to_local_shards.items():
+        for table_name in self._model_parallel_name_to_local_shards.keys():
+            local_shards = self._model_parallel_name_to_local_shards[table_name]
+            shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name]
+
            # for shards that don't exist on this rank, register with empty tensor
            if not hasattr(self.embeddings[table_name], "weight"):
                self.embeddings[table_name].register_parameter(
@@ -715,18 +765,34 @@ def _initialize_torch_state(self) -> None:  # noqa
                self.embeddings[table_name].weight._in_backward_optimizers = [
                    EmptyFusedOptimizer()
                ]
+
            if model_parallel_name_to_compute_kernel[table_name] in {
                EmbeddingComputeKernel.KEY_VALUE.value
            }:
                continue
-            # created ShardedTensors once in init, use in post_state_dict_hook
-            self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards(
-                    local_shards,
-                    self._name_to_table_size[table_name],
-                    process_group=self._env.process_group,
+
+            if shards_wrapper_map["local_tensors"]:
+                self._model_parallel_name_to_dtensor[table_name] = DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=shards_wrapper_map["local_tensors"],
+                        local_offsets=shards_wrapper_map["local_offsets"],
+                    ),
+                    device_mesh=self._env.device_mesh,
+                    placements=shards_wrapper_map["placements"],
+                    shape=shards_wrapper_map["global_size"],
+                    stride=shards_wrapper_map["global_stride"],
+                    run_check=False,
+                )
+            else:
+                # if DTensors for table do not exist, create ShardedTensor
+                # created ShardedTensors once in init, use in post_state_dict_hook
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards(
+                        local_shards,
+                        self._name_to_table_size[table_name],
+                        process_group=self._env.process_group,
+                    )
                )
-            )

        def post_state_dict_hook(
            module: ShardedEmbeddingCollection,
@@ -741,6 +807,12 @@ def post_state_dict_hook(
            ) in module._model_parallel_name_to_sharded_tensor.items():
                destination_key = f"{prefix}embeddings.{table_name}.weight"
                destination[destination_key] = sharded_t
+            for (
+                table_name,
+                d_tensor,
+            ) in module._model_parallel_name_to_dtensor.items():
+                destination_key = f"{prefix}embeddings.{table_name}.weight"
+                destination[destination_key] = d_tensor

        self.register_state_dict_pre_hook(self._pre_state_dict_hook)
        self._register_state_dict_hook(post_state_dict_hook)
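
For readability, the DTensor branch that the new _pre_load_state_dict_hook adds above can be summarized as the following hedged sketch. flatten_dtensor_weight is a hypothetical helper, not a function in the commit; it mirrors the single-shard-per-rank restriction enforced by the RuntimeError in the hook.

import torch
from torch.distributed._tensor import DTensor
from torchrec.distributed.shards_wrapper import LocalShardsWrapper


def flatten_dtensor_weight(dt: DTensor) -> torch.Tensor:
    # to_local() on these table weights returns the LocalShardsWrapper built in
    # _initialize_torch_state above.
    wrapper = dt.to_local()
    assert isinstance(wrapper, LocalShardsWrapper)
    local_shards = wrapper.local_shards()
    if len(local_shards) == 0:
        return torch.empty(0)
    if len(local_shards) > 1:
        raise RuntimeError(
            f"Multiple shards on rank is not supported for DTensor yet, got {len(local_shards)}"
        )
    # local_sizes()[0] is the (rows, cols) of the single local shard; flatten rows, keep cols.
    dim = wrapper.local_sizes()[0][1]
    return local_shards[0].view(-1, dim)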

torchrec/distributed/embedding_kernel.py

Lines changed: 35 additions & 2 deletions
@@ -15,11 +15,14 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed._tensor import DTensor
 from torchrec.distributed.embedding_types import (
+    DTensorMetadata,
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
     ShardedEmbeddingTable,
 )
+from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 from torchrec.distributed.types import Shard, ShardedTensor, ShardedTensorMetadata
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

@@ -73,6 +76,8 @@ def get_state_dict(
     """
     key_to_local_shards: Dict[str, List[Shard]] = defaultdict(list)
     key_to_global_metadata: Dict[str, ShardedTensorMetadata] = {}
+    key_to_dtensor_metadata: Dict[str, DTensorMetadata] = {}
+    key_to_local_tensor_shards: Dict[str, List[Any]] = defaultdict(list)  # pyre-ignore[33]

     def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
         return prefix + f"{embedding_table.name}.weight"
@@ -98,7 +103,16 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
            if qscale is not None:
                assert embedding_table.local_cols == param.size(1)  # pyre-ignore[16]

-            if embedding_table.global_metadata is not None and pg is not None:
+            if embedding_table.dtensor_metadata is not None and pg is not None:
+                # DTensor path
+                key_to_dtensor_metadata[key] = embedding_table.dtensor_metadata
+                key_to_local_tensor_shards[key].append(
+                    [
+                        param,
+                        embedding_table.local_metadata.shard_offsets,  # pyre-ignore[16]
+                    ]
+                )
+            elif embedding_table.global_metadata is not None and pg is not None:
                # set additional field of sharded tensor based on local tensor properties
                embedding_table.global_metadata.tensor_properties.dtype = (
                    param.dtype  # pyre-ignore[16]
@@ -133,5 +147,24 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
                process_group=pg,
            )
        )
-
+    # DTensor path
+    for key in key_to_local_tensor_shards:
+        dtensor_metadata = key_to_dtensor_metadata[key]
+        destination[key] = DTensor.from_local(
+            local_tensor=LocalShardsWrapper(
+                local_shards=[
+                    tensor_shards[0]
+                    for tensor_shards in key_to_local_tensor_shards[key]
+                ],
+                local_offsets=[
+                    tensor_shards[1]
+                    for tensor_shards in key_to_local_tensor_shards[key]
+                ],
+            ),
+            device_mesh=dtensor_metadata.mesh,
+            placements=dtensor_metadata.placements,
+            shape=torch.Size(dtensor_metadata.size),  # pyre-ignore[6]
+            stride=dtensor_metadata.stride,
+            run_check=False,
+        )
     return destination
