Skip to content

Commit f515d15

Browse files
iamzainhudafacebook-github-bot
authored andcommitted
forward fix embedding_lookup to include DTensor (#2167)
Summary: Pull Request resolved: #2167 TSIA, jobs failing because DTensor path does not exist Reviewed By: TroyGarden Differential Revision: D58983960 fbshipit-source-id: 3a7dcf4547b869ae411f5abc3011a70d91d6554a
1 parent 6b7ce03 commit f515d15

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

torchrec/distributed/embedding_lookup.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from torch import nn
2424

2525
from torch.autograd.function import FunctionCtx
26+
from torch.distributed._tensor import DTensor
2627
from torch.nn.modules.module import _IncompatibleKeys
2728
from torchrec.distributed.batched_embedding_kernel import (
2829
BaseBatchedEmbedding,
@@ -69,7 +70,7 @@ def fx_wrap_tensor_view2d(x: torch.Tensor, dim0: int, dim1: int) -> torch.Tensor
6970

7071
def _load_state_dict(
7172
emb_modules: "nn.ModuleList",
72-
state_dict: "OrderedDict[str, Union[torch.Tensor, ShardedTensor]]",
73+
state_dict: "OrderedDict[str, Union[torch.Tensor, ShardedTensor, DTensor]]",
7374
) -> Tuple[List[str], List[str]]:
7475
missing_keys = []
7576
unexpected_keys = list(state_dict.keys())
@@ -95,6 +96,22 @@ def _load_state_dict(
9596
)
9697

9798
dst_local_shard.tensor.detach().copy_(src_local_shard.tensor)
99+
elif isinstance(dst_param, DTensor):
100+
assert isinstance(src_param, DTensor)
101+
dst_param = dst_param.to_local()
102+
src_param = src_param.to_local()
103+
assert len(dst_param.local_chunks) == len( # pyre-ignore[16]
104+
src_param.local_chunks
105+
)
106+
for dst_local_shard, src_local_shard in zip(
107+
dst_param.to_local().local_shards(), # pyre-ignore[16]
108+
src_param.to_local().local_shards(),
109+
):
110+
assert (
111+
dst_local_shard.metadata.local_chunks
112+
== src_local_shard.metadata.local_chunks
113+
)
114+
dst_local_shard.detach().copy_(src_local_shard)
98115
else:
99116
assert isinstance(src_param, torch.Tensor) and isinstance(
100117
dst_param, torch.Tensor

0 commit comments

Comments
 (0)