@@ -28,6 +28,7 @@
 import torch
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
+from torch.distributed._tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
@@ -55,6 +56,7 @@
 from torchrec.distributed.sharding.tw_sequence_sharding import (
     TwSequenceEmbeddingSharding,
 )
+from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 from torchrec.distributed.types import (
     Awaitable,
     EmbeddingModuleShardingPlan,
@@ -601,18 +603,20 @@ def _pre_load_state_dict_hook(
     ) -> None:
         """
         Modify the destination state_dict for model parallel
-        to transform from ShardedTensors into tensors
+        to transform from ShardedTensors/DTensors into tensors
         """
-        for (
-            table_name,
-            model_shards,
-        ) in self._model_parallel_name_to_local_shards.items():
+        for table_name in self._model_parallel_name_to_local_shards.keys():
             key = f"{prefix}embeddings.{table_name}.weight"
-
+            # gather model shards from both DTensor and ShardedTensor maps
+            model_shards_sharded_tensor = self._model_parallel_name_to_local_shards[
+                table_name
+            ]
+            model_shards_dtensor = self._model_parallel_name_to_shards_wrapper[
+                table_name
+            ]
             # If state_dict[key] is already a ShardedTensor, use its local shards
             if isinstance(state_dict[key], ShardedTensor):
                 local_shards = state_dict[key].local_shards()
-                # If no local shards, create an empty tensor
                 if len(local_shards) == 0:
                     state_dict[key] = torch.empty(0)
                 else:
@@ -624,27 +628,57 @@ def _pre_load_state_dict_hook(
                         ).view(-1, dim)
                     else:
                         state_dict[key] = local_shards[0].tensor.view(-1, dim)
-            else:
+            elif isinstance(state_dict[key], DTensor):
+                shards_wrapper = state_dict[key].to_local()
+                local_shards = shards_wrapper.local_shards()
+                dim = shards_wrapper.local_sizes()[0][1]
+                if len(local_shards) == 0:
+                    state_dict[key] = torch.empty(0)
+                elif len(local_shards) > 1:
+                    # TODO - add multiple shards on rank support
+                    raise RuntimeError(
+                        f"Multiple shards on rank is not supported for DTensor yet, got {len(local_shards)}"
+                    )
+                else:
+                    state_dict[key] = local_shards[0].view(-1, dim)
+            elif isinstance(state_dict[key], torch.Tensor):
                 local_shards = []
-                for shard in model_shards:
-                    # Extract shard size and offsets for splicing
-                    shard_sizes = shard.metadata.shard_sizes
-                    shard_offsets = shard.metadata.shard_offsets
-
-                    # Prepare tensor by splicing and placing on appropriate device
-                    spliced_tensor = state_dict[key][
-                        shard_offsets[0] : shard_offsets[0] + shard_sizes[0],
-                        shard_offsets[1] : shard_offsets[1] + shard_sizes[1],
-                    ].to(shard.tensor.get_device())
-
-                    # Append spliced tensor into local shards
-                    local_shards.append(spliced_tensor)
-
+                if model_shards_sharded_tensor:
+                    # splice according to sharded tensor metadata
+                    for shard in model_shards_sharded_tensor:
+                        # Extract shard size and offsets for splicing
+                        shard_size = shard.metadata.shard_sizes
+                        shard_offset = shard.metadata.shard_offsets
+
+                        # Prepare tensor by splicing and placing on appropriate device
+                        spliced_tensor = state_dict[key][
+                            shard_offset[0] : shard_offset[0] + shard_size[0],
+                            shard_offset[1] : shard_offset[1] + shard_size[1],
+                        ]
+
+                        # Append spliced tensor into local shards
+                        local_shards.append(spliced_tensor)
+                elif model_shards_dtensor:
+                    # splice according to dtensor metadata
+                    for tensor, shard_offset in zip(
+                        model_shards_dtensor["local_tensors"],
+                        model_shards_dtensor["local_offsets"],
+                    ):
+                        shard_size = tensor.size()
+                        spliced_tensor = state_dict[key][
+                            shard_offset[0] : shard_offset[0] + shard_size[0],
+                            shard_offset[1] : shard_offset[1] + shard_size[1],
+                        ]
+                        local_shards.append(spliced_tensor)
                 state_dict[key] = (
                     torch.empty(0)
                     if not local_shards
                     else torch.cat(local_shards, dim=0)
                 )
+            else:
+                raise RuntimeError(
+                    f"Unexpected state_dict key type {type(state_dict[key])} found for {key}"
+                )
 
         for lookup in self._lookups:
             while isinstance(lookup, DistributedDataParallel):
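
Note on the plain torch.Tensor branch above: the hook slices the full (unsharded) checkpoint weight down to this rank's shards using the recorded offsets and sizes, then concatenates them row-wise. A minimal standalone sketch of that splicing, with made-up offsets and sizes standing in for the ShardedTensor metadata or the shards-wrapper offsets:

import torch

# full table as it would appear in a single-process checkpoint: 8 rows, embedding dim 4
full_weight = torch.arange(8 * 4, dtype=torch.float32).view(8, 4)

# (row_offset, col_offset), (rows, cols) for each shard owned by this rank (illustrative values)
local_shard_meta = [((0, 0), (4, 4)), ((4, 0), (4, 4))]

local_shards = []
for (row_off, col_off), (rows, cols) in local_shard_meta:
    # same splicing as state_dict[key][shard_offset[0] : ..., shard_offset[1] : ...]
    local_shards.append(full_weight[row_off : row_off + rows, col_off : col_off + cols])

# empty tensor when this rank owns no shards, otherwise row-wise concatenation
local_weight = torch.empty(0) if not local_shards else torch.cat(local_shards, dim=0)
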
@@ -661,7 +695,9 @@ def _initialize_torch_state(self) -> None:  # noqa
         for table_name in self._table_names:
             self.embeddings[table_name] = nn.Module()
         self._model_parallel_name_to_local_shards = OrderedDict()
+        self._model_parallel_name_to_shards_wrapper = OrderedDict()
         self._model_parallel_name_to_sharded_tensor = OrderedDict()
+        self._model_parallel_name_to_dtensor = OrderedDict()
         model_parallel_name_to_compute_kernel: Dict[str, str] = {}
         for (
             table_name,
@@ -670,6 +706,9 @@ def _initialize_torch_state(self) -> None:  # noqa
             if parameter_sharding.sharding_type == ShardingType.DATA_PARALLEL.value:
                 continue
             self._model_parallel_name_to_local_shards[table_name] = []
+            self._model_parallel_name_to_shards_wrapper[table_name] = OrderedDict(
+                [("local_tensors", []), ("local_offsets", [])]
+            )
             model_parallel_name_to_compute_kernel[table_name] = (
                 parameter_sharding.compute_kernel
             )
@@ -691,18 +730,29 @@ def _initialize_torch_state(self) -> None:  # noqa
             # save local_shards for transforming MP params to shardedTensor
             for key, v in lookup.state_dict().items():
                 table_name = key[: -len(".weight")]
-                self._model_parallel_name_to_local_shards[table_name].extend(
-                    v.local_shards()
-                )
+                if isinstance(v, DTensor):
+                    shards_wrapper = self._model_parallel_name_to_shards_wrapper[
+                        table_name
+                    ]
+                    local_shards_wrapper = v._local_tensor
+                    shards_wrapper["local_tensors"].extend(local_shards_wrapper.local_shards())  # pyre-ignore[16]
+                    shards_wrapper["local_offsets"].extend(local_shards_wrapper.local_offsets())  # pyre-ignore[16]
+                    shards_wrapper["global_size"] = v.size()
+                    shards_wrapper["global_stride"] = v.stride()
+                    shards_wrapper["placements"] = v.placements
+                elif isinstance(v, ShardedTensor):
+                    self._model_parallel_name_to_local_shards[table_name].extend(
+                        v.local_shards()
+                    )
             for (
                 table_name,
                 tbe_slice,
             ) in lookup.named_parameters_by_table():
                 self.embeddings[table_name].register_parameter("weight", tbe_slice)
-        for (
-            table_name,
-            local_shards,
-        ) in self._model_parallel_name_to_local_shards.items():
+        for table_name in self._model_parallel_name_to_local_shards.keys():
+            local_shards = self._model_parallel_name_to_local_shards[table_name]
+            shards_wrapper_map = self._model_parallel_name_to_shards_wrapper[table_name]
+
             # for shards that don't exist on this rank, register with empty tensor
             if not hasattr(self.embeddings[table_name], "weight"):
                 self.embeddings[table_name].register_parameter(
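
The per-table bookkeeping added above keeps the DTensor side in a small dict next to the existing ShardedTensor shard list. A tiny standalone sketch of that structure with illustrative values (the real tensors and offsets come from the lookup's DTensor state):

from collections import OrderedDict

import torch

# mirrors self._model_parallel_name_to_shards_wrapper[table_name]
shards_wrapper_map = OrderedDict([("local_tensors", []), ("local_offsets", [])])

# what the DTensor branch of the loop would record for one local shard (illustrative values)
shards_wrapper_map["local_tensors"].append(torch.zeros(4, 8))  # this rank's shard
shards_wrapper_map["local_offsets"].append((0, 0))             # its (row, col) offset in the table
shards_wrapper_map["global_size"] = torch.Size([16, 8])        # full table size across ranks
shards_wrapper_map["global_stride"] = (8, 1)                   # stride of the full table
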
@@ -715,18 +765,34 @@ def _initialize_torch_state(self) -> None:  # noqa
                     self.embeddings[table_name].weight._in_backward_optimizers = [
                         EmptyFusedOptimizer()
                     ]
+
             if model_parallel_name_to_compute_kernel[table_name] in {
                 EmbeddingComputeKernel.KEY_VALUE.value
             }:
                 continue
-            # created ShardedTensors once in init, use in post_state_dict_hook
-            self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards(
-                    local_shards,
-                    self._name_to_table_size[table_name],
-                    process_group=self._env.process_group,
+
+            if shards_wrapper_map["local_tensors"]:
+                self._model_parallel_name_to_dtensor[table_name] = DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=shards_wrapper_map["local_tensors"],
+                        local_offsets=shards_wrapper_map["local_offsets"],
+                    ),
+                    device_mesh=self._env.device_mesh,
+                    placements=shards_wrapper_map["placements"],
+                    shape=shards_wrapper_map["global_size"],
+                    stride=shards_wrapper_map["global_stride"],
+                    run_check=False,
+                )
+            else:
+                # if DTensors for table do not exist, create ShardedTensor
+                # created ShardedTensors once in init, use in post_state_dict_hook
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards(
+                        local_shards,
+                        self._name_to_table_size[table_name],
+                        process_group=self._env.process_group,
+                    )
                 )
-            )
 
         def post_state_dict_hook(
             module: ShardedEmbeddingCollection,
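
The DTensor.from_local call above wraps this rank's shards in torchrec's LocalShardsWrapper before placing them on the environment's device mesh. As a hedged, self-contained illustration of the same from_local pattern (a plain local tensor instead of LocalShardsWrapper, a single-process gloo group, a CPU mesh; assumes a PyTorch recent enough to ship DeviceMesh/DTensor, and the address/port values are arbitrary):

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

# single-process "distributed" setup so the example runs standalone
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))  # 1-rank mesh, for illustration only
local_rows = torch.randn(4, 8)        # this rank's rows of a row-wise sharded table
dt = DTensor.from_local(
    local_rows,
    device_mesh=mesh,
    placements=[Shard(0)],            # sharded along dim 0, like the embedding tables here
    run_check=False,
)
assert dt.to_local().shape == (4, 8)  # to_local() returns what this rank holds

dist.destroy_process_group()
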
@@ -741,6 +807,12 @@ def post_state_dict_hook(
             ) in module._model_parallel_name_to_sharded_tensor.items():
                 destination_key = f"{prefix}embeddings.{table_name}.weight"
                 destination[destination_key] = sharded_t
+            for (
+                table_name,
+                d_tensor,
+            ) in module._model_parallel_name_to_dtensor.items():
+                destination_key = f"{prefix}embeddings.{table_name}.weight"
+                destination[destination_key] = d_tensor
 
         self.register_state_dict_pre_hook(self._pre_state_dict_hook)
         self._register_state_dict_hook(post_state_dict_hook)
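
Hedged usage note (not part of the diff): after these hooks run, each model-parallel table appears in state_dict() under "embeddings.<table_name>.weight" as either a DTensor or a ShardedTensor, so checkpoint consumers can branch on the type. extract_local below is a hypothetical helper used only for illustration:

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor


def extract_local(value: object) -> object:
    """Return whatever this rank actually holds for one state_dict entry (sketch)."""
    if isinstance(value, DTensor):
        return value.to_local()      # torchrec's LocalShardsWrapper for DTensor-backed tables
    if isinstance(value, ShardedTensor):
        return value.local_shards()  # list of Shard(tensor, metadata)
    return value                     # plain tensor, e.g. a data-parallel table


# e.g. {key: extract_local(val) for key, val in sharded_ec.state_dict().items()}
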