diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 53cfcc432..41635b054 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -23,6 +23,7 @@
 from torch import nn
 from torch.autograd.function import FunctionCtx
+from torch.distributed._tensor import DTensor
 from torch.nn.modules.module import _IncompatibleKeys
 from torchrec.distributed.batched_embedding_kernel import (
     BaseBatchedEmbedding,
@@ -69,7 +70,7 @@ def fx_wrap_tensor_view2d(x: torch.Tensor, dim0: int, dim1: int) -> torch.Tensor

 def _load_state_dict(
     emb_modules: "nn.ModuleList",
-    state_dict: "OrderedDict[str, Union[torch.Tensor, ShardedTensor]]",
+    state_dict: "OrderedDict[str, Union[torch.Tensor, ShardedTensor, DTensor]]",
 ) -> Tuple[List[str], List[str]]:
     missing_keys = []
     unexpected_keys = list(state_dict.keys())
@@ -95,6 +96,22 @@ def _load_state_dict(
                         )

                         dst_local_shard.tensor.detach().copy_(src_local_shard.tensor)
+                elif isinstance(dst_param, DTensor):
+                    assert isinstance(src_param, DTensor)
+                    dst_param = dst_param.to_local()
+                    src_param = src_param.to_local()
+                    assert len(dst_param.local_chunks) == len(  # pyre-ignore[16]
+                        src_param.local_chunks
+                    )
+                    for dst_local_shard, src_local_shard in zip(
+                        dst_param.local_shards(),  # pyre-ignore[16]
+                        src_param.local_shards(),
+                    ):
+                        assert (
+                            dst_local_shard.size()
+                            == src_local_shard.size()
+                        )
+                        dst_local_shard.detach().copy_(src_local_shard)
                 else:
                     assert isinstance(src_param, torch.Tensor) and isinstance(
                         dst_param, torch.Tensor
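
For context on the new branch: it mirrors the `ShardedTensor` path above, copying each local shard of the source parameter into the matching local shard of the destination in place on every rank. The sketch below is not part of the patch; it is a minimal single-rank illustration of that copy pattern using plain DTensors, where `to_local()` returns a single tensor per rank. Everything in it (the gloo/localhost setup, the one-rank mesh, the 8x4 toy tensors, the variable names) is an example choice, not torchrec code; torchrec's table parameters come back from `to_local()` as a shard-list wrapper, which is why the patched branch additionally loops over `local_shards()` and checks the chunk layout first.

```python
# Minimal single-rank sketch (NOT part of the patch): illustrates the in-place
# per-rank shard copy that the new DTensor branch of _load_state_dict performs.
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])  # one-rank mesh; a real job would span all ranks
dst = distribute_tensor(torch.zeros(8, 4), mesh, [Shard(0)])  # stand-in for dst_param
src = distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)])  # stand-in for src_param

# For a plain DTensor, to_local() hands back this rank's shard as a regular
# torch.Tensor, so loading reduces to one in-place copy_ per rank after a
# shape sanity check, exactly like the per-shard copy in the patch.
assert dst.to_local().size() == src.to_local().size()
dst.to_local().detach().copy_(src.to_local())

assert torch.equal(dst.to_local(), src.to_local())
dist.destroy_process_group()
```

The `detach()` before `copy_` matches the existing `ShardedTensor` branch: it keeps the in-place write out of autograd when the destination parameter requires grad.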