
Commit 6cb7df9

iamzainhuda authored and facebook-github-bot committed
add DTensor to optimizer state dict (#2585)
Summary: To support 2D parallelism checkpointing, we introduce DTensor to the optimizer state dict. It is enabled through fused_params["output_dtensor"] = True: when table shards are output as DTensor, so are the optimizer shards. This diff lets us leverage N-dimensional device meshes with support for arbitrary replication/sharding groups, which makes checkpointing easy, since DCP/Modelstore support replicated/sharded placements on a device mesh (something ShardedTensor does not support). Differential Revision: D65555455
1 parent fad795e commit 6cb7df9
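
For context, a minimal sketch of how the flag from the summary is passed, assuming the usual TorchRec sharder wiring (the sharder class and surrounding setup here are illustrative, not part of this commit):

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder

# fused_params["output_dtensor"] is the switch named in the summary; with this
# commit it also controls whether fused-optimizer shards come back as DTensor.
sharder = EmbeddingBagCollectionSharder(fused_params={"output_dtensor": True})
# The sharder is then handed to DistributedModelParallel / the planner as usual;
# the fused optimizer's state dict will then contain DTensor (not ShardedTensor) states.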

File tree

torchrec/distributed/batched_embedding_kernel.py
torchrec/distributed/shards_wrapper.py
torchrec/optim/keyed.py

3 files changed: +106 -22 lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 47 additions & 8 deletions
@@ -46,15 +46,18 @@
     PartiallyMaterializedTensor,
 )
 from torch import nn
+from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
 from torchrec.distributed.comm import get_local_rank, get_node_group_size
 from torchrec.distributed.composable.table_batched_embedding_slice import (
     TableBatchedEmbeddingSlice,
 )
 from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
 from torchrec.distributed.embedding_types import (
     compute_kernel_to_embedding_location,
+    DTensorMetadata,
     GroupedEmbeddingConfig,
 )
+from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 from torchrec.distributed.types import (
     Shard,
     ShardedTensor,
@@ -213,6 +216,7 @@ class ShardParams:
             optimizer_states: List[Optional[Tuple[torch.Tensor]]]
             local_metadata: List[ShardMetadata]
             embedding_weights: List[torch.Tensor]
+            dtensor_metadata: List[DTensorMetadata]
 
         def get_optimizer_single_value_shard_metadata_and_global_metadata(
             table_global_metadata: ShardedTensorMetadata,
@@ -389,7 +393,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
                 continue
             if table_config.name not in table_to_shard_params:
                 table_to_shard_params[table_config.name] = ShardParams(
-                    optimizer_states=[], local_metadata=[], embedding_weights=[]
+                    optimizer_states=[],
+                    local_metadata=[],
+                    embedding_weights=[],
+                    dtensor_metadata=[],
                 )
             optimizer_state_values = None
             if optimizer_states:
@@ -410,6 +417,9 @@
             table_to_shard_params[table_config.name].local_metadata.append(
                 local_metadata
             )
+            table_to_shard_params[table_config.name].dtensor_metadata.append(
+                table_config.dtensor_metadata
+            )
             table_to_shard_params[table_config.name].embedding_weights.append(weight)
 
         seen_tables = set()
@@ -474,7 +484,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
         # pyre-ignore
         def get_sharded_optim_state(
             momentum_idx: int, state_key: str
-        ) -> ShardedTensor:
+        ) -> Union[ShardedTensor, DTensor]:
             assert momentum_idx > 0
             momentum_local_shards: List[Shard] = []
             optimizer_sharded_tensor_metadata: ShardedTensorMetadata
@@ -528,12 +538,41 @@
                     )
                 )
 
-            # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
-            return ShardedTensor._init_from_local_shards_and_global_metadata(
-                local_shards=momentum_local_shards,
-                sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
-                process_group=self._pg,
-            )
+            # Convert optimizer state to DTensor if enabled
+            if table_config.dtensor_metadata:
+                # if rowwise state we do Shard(0), regardless of how the table is sharded
+                if optim_state.dim() == 1:
+                    stride = (1,)
+                    placements = (
+                        (Replicate(), DTensorShard(0))
+                        if table_config.dtensor_metadata.mesh.ndim == 2
+                        else (DTensorShard(0),)
+                    )
+                else:
+                    stride = table_config.dtensor_metadata.stride
+                    placements = table_config.dtensor_metadata.placements
+
+                return DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=[x.tensor for x in momentum_local_shards],
+                        local_offsets=[  # pyre-ignore[6]
+                            x.metadata.shard_offsets
+                            for x in momentum_local_shards
+                        ],
+                    ),
+                    device_mesh=table_config.dtensor_metadata.mesh,
+                    placements=placements,
+                    shape=optimizer_sharded_tensor_metadata.size,
+                    stride=stride,
+                    run_check=False,
+                )
+            else:
+                # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
+                return ShardedTensor._init_from_local_shards_and_global_metadata(
+                    local_shards=momentum_local_shards,
+                    sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
+                    process_group=self._pg,
+                )
 
         num_states: int = min(
             # pyre-ignore
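
To summarize the placement logic added to get_sharded_optim_state above, here is an illustrative-only helper that mirrors it (not part of the commit), assuming dtensor_metadata carries the table's mesh, placements, and stride as in the diff:

from typing import Any, Tuple

import torch
from torch.distributed._tensor import Replicate, Shard


def optim_state_placements(optim_state: torch.Tensor, dtensor_metadata: Any) -> Tuple[tuple, tuple]:
    # Row-wise (1-D) optimizer state is always Shard(0); on a 2-D mesh the extra
    # dimension is Replicate(), so each replica group holds the same rows.
    if optim_state.dim() == 1:
        stride = (1,)
        placements = (
            (Replicate(), Shard(0))
            if dtensor_metadata.mesh.ndim == 2
            else (Shard(0),)
        )
    else:
        # Point-wise (table-shaped) state reuses the table's own DTensor layout.
        stride = dtensor_metadata.stride
        placements = dtensor_metadata.placements
    return placements, stride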

torchrec/distributed/shards_wrapper.py

Lines changed: 39 additions & 13 deletions
@@ -68,10 +68,15 @@ def __new__(
 
         # we calculate the total tensor size by "concat" on second tensor dimension
         cat_tensor_shape = list(local_shards[0].size())
-        if len(local_shards) > 1:  # column-wise sharding
+        if len(local_shards) > 1 and local_shards[0].ndim == 2:  # column-wise sharding
             for shard in local_shards[1:]:
                 cat_tensor_shape[1] += shard.size()[1]
 
+        # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension
+        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
+            for shard in local_shards[1:]:
+                cat_tensor_shape[0] += shard.size()[0]
+
         wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
         wrapper_shape = torch.Size(cat_tensor_shape)
         chunks_meta = [
@@ -110,6 +115,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             aten.equal.default: cls.handle_equal,
             aten.detach.default: cls.handle_detach,
             aten.clone.default: cls.handle_clone,
+            aten.new_empty.default: cls.handle_new_empty,
         }
 
         if func in dispatcher:
@@ -153,18 +159,28 @@ def handle_to_copy(args, kwargs):
     def handle_view(args, kwargs):
         view_shape = args[1]
         res_shards_list = []
-        if (
-            len(args[0].local_shards()) > 1
-            and args[0].storage_metadata().size[0] == view_shape[0]
-            and args[0].storage_metadata().size[1] == view_shape[1]
-        ):
-            # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on
-            # init calls view_as() on the global tensor shape
-            # will fail because the view shape is not applicable to individual shards.
-            res_shards_list = [
-                aten.view.default(shard, shard.shape, **kwargs)
-                for shard in args[0].local_shards()
-            ]
+        if len(args[0].local_shards()) > 1:
+            if args[0].local_shards()[0].ndim == 2:
+                assert (
+                    args[0].storage_metadata().size[0] == view_shape[0]
+                    and args[0].storage_metadata().size[1] == view_shape[1]
+                )
+                # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on
+                # init calls view_as() on the global tensor shape
+                # will fail because the view shape is not applicable to individual shards.
+                res_shards_list = [
+                    aten.view.default(shard, shard.shape, **kwargs)
+                    for shard in args[0].local_shards()
+                ]
+            elif args[0].local_shards()[0].ndim == 1:
+                assert args[0].storage_metadata().size[0] == view_shape[0]
+                # This case is for optimizer sharding, as regardless of sharding type, optimizer state is row-wise sharded
+                res_shards_list = [
+                    aten.view.default(shard, shard.shape, **kwargs)
+                    for shard in args[0].local_shards()
+                ]
+            else:
+                raise NotImplementedError("No support for view on tensors ndim > 2")
         else:
             # view is called per shard
             res_shards_list = [
@@ -220,6 +236,16 @@ def handle_clone(args, kwargs):
             ]
         return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())
 
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_new_empty(args, kwargs):
+        self_ls = args[0]
+        return LocalShardsWrapper(
+            [torch.empty_like(shard) for shard in self_ls._local_shards],
+            self_ls.local_offsets(),
+        )
+
     @property
     def device(self) -> torch._C.device:  # type: ignore[override]
         return (
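
A quick sanity sketch of the new shape logic above; the shard sizes are made up, and it assumes LocalShardsWrapper.size() reports the concatenated global shape computed in __new__:

import torch
from torchrec.distributed.shards_wrapper import LocalShardsWrapper

# Two 2-D shards on one rank: sizes are concatenated on dim 1 (column-wise table sharding).
two_d = LocalShardsWrapper(
    local_shards=[torch.zeros(4, 8), torch.zeros(4, 8)],
    local_offsets=[(0, 0), (0, 8)],
)
print(two_d.size())  # expected: torch.Size([4, 16])

# Two 1-D shards (row-wise optimizer state): sizes are concatenated on dim 0.
one_d = LocalShardsWrapper(
    local_shards=[torch.zeros(4), torch.zeros(4)],
    local_offsets=[(0,), (4,)],
)
print(one_d.size())  # expected: torch.Size([8])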

torchrec/optim/keyed.py

Lines changed: 20 additions & 1 deletion
@@ -27,7 +27,7 @@
 
 from torch import optim
 from torch.distributed._shard.sharded_tensor import ShardedTensor
-
+from torch.distributed.tensor import DTensor
 
 OptimizerFactory = Callable[[List[Union[torch.Tensor, ShardedTensor]]], optim.Optimizer]
 
@@ -111,6 +111,8 @@ def _update_param_state_dict_object(
         param_state_dict_to_load: Dict[str, Any],
         parent_keys: List[Union[str, int, float, bool, None]],
     ) -> None:
+        # Import at function level to avoid circular dependency.
+        from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 
         for k, v in current_param_state_dict.items():
             new_v = param_state_dict_to_load[k]
@@ -134,6 +136,23 @@ def _update_param_state_dict_object(
                 )
                 for shard, new_shard in zip(v.local_shards(), new_v.local_shards()):
                     shard.tensor.detach().copy_(new_shard.tensor)
+            elif isinstance(v, DTensor):
+                assert isinstance(new_v, DTensor)
+                if isinstance(v.to_local(), LocalShardsWrapper):
+                    assert isinstance(new_v.to_local(), LocalShardsWrapper)
+                    num_shards = len(v.to_local().local_shards())  # pyre-ignore[16]
+                    num_new_shards = len(new_v.to_local().local_shards())
+                    if num_shards != num_new_shards:
+                        raise ValueError(
+                            f"Different number of shards {num_shards} vs {num_new_shards} for the path of {json.dumps(parent_keys)}"
+                        )
+                    for shard, new_shard in zip(
+                        v.to_local().local_shards(), new_v.to_local().local_shards()
+                    ):
+                        shard.detach().copy_(new_shard)
+                else:
+                    assert isinstance(new_v.to_local(), torch.Tensor)
+                    v.detach().copy_(new_v)
             elif isinstance(v, torch.Tensor):
                 v.detach().copy_(new_v)
             else:
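
End to end, the point of this branch is that the optimizer state dict can now flow through torch.distributed.checkpoint (DCP). A hedged sketch, assuming an initialized process group, a recent PyTorch where dcp.save/dcp.load accept checkpoint_id, and a model/optimizer pair already sharded with output_dtensor enabled (model, opt, and the path are placeholders):

import torch.distributed.checkpoint as dcp

app_state = {"model": model.state_dict(), "optim": opt.state_dict()}
dcp.save(app_state, checkpoint_id="/tmp/ckpt")  # DTensor placements are recorded per entry

# ... later, possibly on a different world size / sharding layout ...
dcp.load(app_state, checkpoint_id="/tmp/ckpt")  # repopulates app_state in place
opt.load_state_dict(app_state["optim"])         # exercises the new DTensor branch above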
