File tree Expand file tree Collapse file tree 1 file changed +9
-9
lines changed Expand file tree Collapse file tree 1 file changed +9
-9
lines changed Original file line number Diff line number Diff line change @@ -98,18 +98,18 @@ def _load_state_dict(
98
98
dst_local_shard .tensor .detach ().copy_ (src_local_shard .tensor )
99
99
elif isinstance (dst_param , DTensor ):
100
100
assert isinstance (src_param , DTensor )
101
- dst_param = dst_param .to_local ()
102
- src_param = src_param .to_local ()
103
- assert len (dst_param .local_chunks ) == len ( # pyre-ignore[16]
104
- src_param .local_chunks
101
+ assert len (dst_param .to_local ().local_chunks ) == len ( # pyre-ignore[16]
102
+ src_param .to_local ().local_chunks
105
103
)
106
- for dst_local_shard , src_local_shard in zip (
107
- dst_param .to_local ().local_shards (), # pyre-ignore[16]
108
- src_param .to_local ().local_shards (),
104
+ for i , (dst_local_shard , src_local_shard ) in enumerate (
105
+ zip (
106
+ dst_param .to_local ().local_shards (), # pyre-ignore[16]
107
+ src_param .to_local ().local_shards (),
108
+ )
109
109
):
110
110
assert (
111
- dst_local_shard . metadata .local_chunks
112
- == src_local_shard . metadata .local_chunks
111
+ dst_param . to_local () .local_chunks [ i ]
112
+ == src_param . to_local () .local_chunks [ i ]
113
113
)
114
114
dst_local_shard .detach ().copy_ (src_local_shard )
115
115
else :
You can’t perform that action at this time.
0 commit comments