     PartiallyMaterializedTensor,
 )
 from torch import nn
+from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
 from torchrec.distributed.comm import get_local_rank, get_node_group_size
 from torchrec.distributed.composable.table_batched_embedding_slice import (
     TableBatchedEmbeddingSlice,
 )
 from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
 from torchrec.distributed.embedding_types import (
     compute_kernel_to_embedding_location,
+    DTensorMetadata,
     GroupedEmbeddingConfig,
 )
+from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 from torchrec.distributed.types import (
     Shard,
     ShardedTensor,
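
A brief aside on the new imports before the next hunk: a DTensor pairs each rank's local tensor with a DeviceMesh and a tuple of placements that describe how the global tensor is laid out across that mesh. A toy sketch, not from this PR (it assumes a 2-rank job launched via torchrun with a gloo backend; the import path for init_device_mesh varies across PyTorch releases, and this PR itself uses the older torch.distributed._tensor alias):

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

def demo() -> None:
    # Two ranks each contribute a (2, 4) local tensor; Shard(0) declares the
    # global tensor to be their row-wise concatenation, i.e. shape (4, 4).
    dist.init_process_group("gloo")
    mesh = init_device_mesh("cpu", (2,))
    local = torch.full((2, 4), float(dist.get_rank()))
    dt = DTensor.from_local(local, device_mesh=mesh, placements=(Shard(0),))
    assert dt.shape == torch.Size([4, 4])
```
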
@@ -213,6 +216,7 @@ class ShardParams:
     optimizer_states: List[Optional[Tuple[torch.Tensor]]]
     local_metadata: List[ShardMetadata]
     embedding_weights: List[torch.Tensor]
+    dtensor_metadata: List[DTensorMetadata]
 
 def get_optimizer_single_value_shard_metadata_and_global_metadata(
     table_global_metadata: ShardedTensorMetadata,
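
For context, ShardParams reads as follows after this hunk. This is a hedged reconstruction: the @dataclass decorator and the exact import locations sit outside the diff and are assumed here.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch.distributed._shard.sharded_tensor import ShardMetadata
from torchrec.distributed.embedding_types import DTensorMetadata

@dataclass
class ShardParams:
    optimizer_states: List[Optional[Tuple[torch.Tensor]]]
    local_metadata: List[ShardMetadata]  # per-shard ShardedTensor metadata
    embedding_weights: List[torch.Tensor]
    dtensor_metadata: List[DTensorMetadata]  # new: per-shard DTensor metadata
```
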
@@ -389,7 +393,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
             continue
         if table_config.name not in table_to_shard_params:
             table_to_shard_params[table_config.name] = ShardParams(
-                optimizer_states=[], local_metadata=[], embedding_weights=[]
+                optimizer_states=[],
+                local_metadata=[],
+                embedding_weights=[],
+                dtensor_metadata=[],
             )
         optimizer_state_values = None
         if optimizer_states:
@@ -410,6 +417,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
         table_to_shard_params[table_config.name].local_metadata.append(
             local_metadata
         )
+        table_to_shard_params[table_config.name].dtensor_metadata.append(
+            table_config.dtensor_metadata
+        )
         table_to_shard_params[table_config.name].embedding_weights.append(weight)
 
     seen_tables = set()
@@ -474,7 +484,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
     # pyre-ignore
     def get_sharded_optim_state(
         momentum_idx: int, state_key: str
-    ) -> ShardedTensor:
+    ) -> Union[ShardedTensor, DTensor]:
         assert momentum_idx > 0
         momentum_local_shards: List[Shard] = []
         optimizer_sharded_tensor_metadata: ShardedTensorMetadata
@@ -528,12 +538,41 @@ def get_sharded_optim_state(
                 )
             )
 
-        # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
-        return ShardedTensor._init_from_local_shards_and_global_metadata(
-            local_shards=momentum_local_shards,
-            sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
-            process_group=self._pg,
-        )
+        # Convert the optimizer state to a DTensor if DTensor metadata is populated.
+        if table_config.dtensor_metadata:
+            # A row-wise (1-D) state is always Shard(0), regardless of how
+            # the table itself is sharded.
+            if optim_state.dim() == 1:
+                stride = (1,)
+                placements = (
+                    (Replicate(), DTensorShard(0))
+                    if table_config.dtensor_metadata.mesh.ndim == 2
+                    else (DTensorShard(0),)
+                )
+            else:
+                stride = table_config.dtensor_metadata.stride
+                placements = table_config.dtensor_metadata.placements
+
+            return DTensor.from_local(
+                local_tensor=LocalShardsWrapper(
+                    local_shards=[x.tensor for x in momentum_local_shards],
+                    local_offsets=[  # pyre-ignore[6]
+                        x.metadata.shard_offsets
+                        for x in momentum_local_shards
+                    ],
+                ),
+                device_mesh=table_config.dtensor_metadata.mesh,
+                placements=placements,
+                shape=optimizer_sharded_tensor_metadata.size,
+                stride=stride,
+                run_check=False,
+            )
+        else:
+            # TODO: create this in SPMD fashion (e.g. init_from_local_shards) and let it derive global metadata.
+            return ShardedTensor._init_from_local_shards_and_global_metadata(
+                local_shards=momentum_local_shards,
+                sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
+                process_group=self._pg,
+            )
 
     num_states: int = min(
         # pyre-ignore
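
Condensing the final hunk: the new branch wraps the per-rank momentum shards in a LocalShardsWrapper and stitches them into a DTensor via DTensor.from_local. The sketch below mirrors those calls for the row-wise (1-D) case only; the helper names and the global_rows parameter are illustrative stand-ins for values the real code reads from table_config.dtensor_metadata and optimizer_sharded_tensor_metadata.

```python
from typing import List, Tuple

import torch
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh
from torchrec.distributed.shards_wrapper import LocalShardsWrapper

def rowwise_placements(mesh: DeviceMesh):
    # 2-D mesh (e.g. grid sharding): replicate over the first mesh dim and
    # row-shard over the second; 1-D mesh: simply row-shard.
    return (Replicate(), Shard(0)) if mesh.ndim == 2 else (Shard(0),)

def wrap_rowwise_state(
    mesh: DeviceMesh,
    local_shards: List[torch.Tensor],
    local_offsets: List[Tuple[int, ...]],
    global_rows: int,  # illustrative: taken from the sharded-tensor metadata
) -> DTensor:
    # Bundle this rank's shards plus their global offsets into one local view.
    wrapper = LocalShardsWrapper(local_shards=local_shards, local_offsets=local_offsets)
    return DTensor.from_local(
        local_tensor=wrapper,
        device_mesh=mesh,
        placements=rowwise_placements(mesh),
        shape=torch.Size([global_rows]),
        stride=(1,),
        run_check=False,  # metadata is already consistent on every rank
    )
```

Passing run_check=False skips DTensor's collective verification, which is safe here because each rank derives identical sharding metadata locally rather than needing a cross-rank check.
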