Skip to content

Commit e7a8e3e

Browse files
iamzainhudafacebook-github-bot
authored andcommitted
fix variable assignment in DTensor emb lookup (#2169)
Summary: Pull Request resolved: #2169 tsia Reviewed By: PaulZhang12 Differential Revision: D59006643
1 parent 71ca217 commit e7a8e3e

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

torchrec/distributed/embedding_lookup.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,18 +98,18 @@ def _load_state_dict(
9898
dst_local_shard.tensor.detach().copy_(src_local_shard.tensor)
9999
elif isinstance(dst_param, DTensor):
100100
assert isinstance(src_param, DTensor)
101-
dst_param = dst_param.to_local()
102-
src_param = src_param.to_local()
103-
assert len(dst_param.local_chunks) == len( # pyre-ignore[16]
104-
src_param.local_chunks
101+
assert len(dst_param.to_local().local_chunks) == len( # pyre-ignore[16]
102+
src_param.to_local().local_chunks
105103
)
106-
for dst_local_shard, src_local_shard in zip(
107-
dst_param.to_local().local_shards(), # pyre-ignore[16]
108-
src_param.to_local().local_shards(),
104+
for i, (dst_local_shard, src_local_shard) in enumerate(
105+
zip(
106+
dst_param.to_local().local_shards(), # pyre-ignore[16]
107+
src_param.to_local().local_shards(),
108+
)
109109
):
110110
assert (
111-
dst_local_shard.metadata.local_chunks
112-
== src_local_shard.metadata.local_chunks
111+
dst_param.to_local().local_chunks[i]
112+
== src_param.to_local().local_chunks[i]
113113
)
114114
dst_local_shard.detach().copy_(src_local_shard)
115115
else:

0 commit comments

Comments
 (0)