 from torch import nn
 
 from torch.autograd.function import FunctionCtx
+from torch.distributed._tensor import DTensor
 from torch.nn.modules.module import _IncompatibleKeys
 from torchrec.distributed.batched_embedding_kernel import (
     BaseBatchedEmbedding,
@@ -69,7 +70,7 @@ def fx_wrap_tensor_view2d(x: torch.Tensor, dim0: int, dim1: int) -> torch.Tensor
 
 def _load_state_dict(
     emb_modules: "nn.ModuleList",
-    state_dict: "OrderedDict[str, Union[torch.Tensor, ShardedTensor]]",
+    state_dict: "OrderedDict[str, Union[torch.Tensor, ShardedTensor, DTensor]]",
 ) -> Tuple[List[str], List[str]]:
     missing_keys = []
     unexpected_keys = list(state_dict.keys())
@@ -95,6 +96,23 @@ def _load_state_dict(
                         )
 
                         dst_local_shard.tensor.detach().copy_(src_local_shard.tensor)
+                elif isinstance(dst_param, DTensor):
+                    assert isinstance(src_param, DTensor)
+                    # Localize once and reuse the result; calling to_local()
+                    # again on the localized value inside the loop was a bug.
+                    dst_param = dst_param.to_local()
+                    src_param = src_param.to_local()
+                    assert len(dst_param.local_chunks) == len(  # pyre-ignore[16]
+                        src_param.local_chunks
+                    )
+                    for dst_local_shard, src_local_shard in zip(
+                        dst_param.local_shards(),  # pyre-ignore[16]
+                        src_param.local_shards(),
+                    ):
+                        # Local shards are plain tensors (copied directly below),
+                        # so check that the paired shards have matching shapes.
+                        assert dst_local_shard.size() == src_local_shard.size()
+                        dst_local_shard.detach().copy_(src_local_shard)
                 else:
                     assert isinstance(src_param, torch.Tensor) and isinstance(
                         dst_param, torch.Tensor
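For context, a minimal self-contained sketch of the copy pattern the new `DTensor` branch performs. `_StubLocalShards` is a hypothetical stand-in for the shard container that `to_local()` is assumed to return here (the real wrapper lives in torchrec); the `local_chunks` and `local_shards` names mirror the attributes used in the diff, and only `torch` is required.

```python
import torch


class _StubLocalShards:
    """Hypothetical stand-in for the container to_local() returns above."""

    def __init__(self, shards):
        self._shards = shards

    @property
    def local_chunks(self):
        # One entry per shard; shard sizes stand in for real chunk metadata.
        return [s.size() for s in self._shards]

    def local_shards(self):
        return self._shards


def copy_local_shards(dst_local, src_local) -> None:
    """Shard-by-shard copy, mirroring the DTensor branch of _load_state_dict."""
    assert len(dst_local.local_chunks) == len(src_local.local_chunks)
    for dst_shard, src_shard in zip(
        dst_local.local_shards(), src_local.local_shards()
    ):
        assert dst_shard.size() == src_shard.size()
        # detach() keeps the in-place copy out of the autograd graph.
        dst_shard.detach().copy_(src_shard)


dst = _StubLocalShards([torch.zeros(2, 4), torch.zeros(3, 4)])
src = _StubLocalShards([torch.ones(2, 4), torch.full((3, 4), 2.0)])
copy_local_shards(dst, src)
assert torch.equal(dst.local_shards()[0], torch.ones(2, 4))
assert torch.equal(dst.local_shards()[1], torch.full((3, 4), 2.0))
```

In the actual code path, `dst_local` and `src_local` come from calling `to_local()` on a pair of `DTensor` parameters with matching shardings, so the shard lists line up one-to-one.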