add DTensor to optimizer state dict #2585

Closed · wants to merge 1 commit
55 changes: 47 additions & 8 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -46,15 +46,18 @@
PartiallyMaterializedTensor,
)
from torch import nn
from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
from torchrec.distributed.comm import get_local_rank, get_node_group_size
from torchrec.distributed.composable.table_batched_embedding_slice import (
TableBatchedEmbeddingSlice,
)
from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
from torchrec.distributed.embedding_types import (
compute_kernel_to_embedding_location,
DTensorMetadata,
GroupedEmbeddingConfig,
)
from torchrec.distributed.shards_wrapper import LocalShardsWrapper
from torchrec.distributed.types import (
Shard,
ShardedTensor,
@@ -213,6 +216,7 @@ class ShardParams:
optimizer_states: List[Optional[Tuple[torch.Tensor]]]
local_metadata: List[ShardMetadata]
embedding_weights: List[torch.Tensor]
dtensor_metadata: List[DTensorMetadata]

def get_optimizer_single_value_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
@@ -389,7 +393,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
continue
if table_config.name not in table_to_shard_params:
table_to_shard_params[table_config.name] = ShardParams(
optimizer_states=[], local_metadata=[], embedding_weights=[]
optimizer_states=[],
local_metadata=[],
embedding_weights=[],
dtensor_metadata=[],
)
optimizer_state_values = None
if optimizer_states:
@@ -410,6 +417,9 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
table_to_shard_params[table_config.name].local_metadata.append(
local_metadata
)
table_to_shard_params[table_config.name].dtensor_metadata.append(
table_config.dtensor_metadata
)
table_to_shard_params[table_config.name].embedding_weights.append(weight)

seen_tables = set()
@@ -474,7 +484,7 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
# pyre-ignore
def get_sharded_optim_state(
momentum_idx: int, state_key: str
) -> ShardedTensor:
) -> Union[ShardedTensor, DTensor]:
assert momentum_idx > 0
momentum_local_shards: List[Shard] = []
optimizer_sharded_tensor_metadata: ShardedTensorMetadata
@@ -528,12 +538,41 @@ def get_sharded_optim_state(
)
)

# TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=momentum_local_shards,
sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
process_group=self._pg,
)
# Convert optimizer state to DTensor if enabled
if table_config.dtensor_metadata:
# if the state is row-wise we do Shard(0), regardless of how the table is sharded
if optim_state.dim() == 1:
stride = (1,)
placements = (
(Replicate(), DTensorShard(0))
if table_config.dtensor_metadata.mesh.ndim == 2
else (DTensorShard(0),)
)
else:
stride = table_config.dtensor_metadata.stride
placements = table_config.dtensor_metadata.placements

return DTensor.from_local(
local_tensor=LocalShardsWrapper(
local_shards=[x.tensor for x in momentum_local_shards],
local_offsets=[ # pyre-ignore[6]
x.metadata.shard_offsets
for x in momentum_local_shards
],
),
device_mesh=table_config.dtensor_metadata.mesh,
placements=placements,
shape=optimizer_sharded_tensor_metadata.size,
stride=stride,
run_check=False,
)
else:
# TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
return ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=momentum_local_shards,
sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
process_group=self._pg,
)

num_states: int = min(
# pyre-ignore
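
For context, the conversion added above follows the plain `DTensor.from_local` pattern: the per-rank momentum shards are wrapped in a `LocalShardsWrapper` and lifted into a `DTensor` with a `Shard(0)` placement for 1-D (row-wise) state. Below is a minimal, hypothetical sketch of that pattern, not code from this PR — run it under `torchrun`; the mesh, shard sizes, and offsets are made up for illustration.

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

from torchrec.distributed.shards_wrapper import LocalShardsWrapper

dist.init_process_group("gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()
mesh = init_device_mesh("cpu", (world_size,))

# Pretend each rank holds one contiguous block of row-wise momentum values.
rows_per_rank = 4
local_momentum = torch.full((rows_per_rank,), float(rank))

# Wrap the rank-local shard(s) together with their global offsets ...
wrapped = LocalShardsWrapper(
    local_shards=[local_momentum],
    local_offsets=[(rank * rows_per_rank,)],
)

# ... and lift them into a DTensor sharded on dim 0, mirroring what
# get_sharded_optim_state now does for 1-D optimizer state.
momentum_dt = DTensor.from_local(
    local_tensor=wrapped,
    device_mesh=mesh,
    placements=(Shard(0),),
    shape=torch.Size([rows_per_rank * world_size]),
    stride=(1,),
    run_check=False,
)

print(rank, momentum_dt.to_local().local_shards())
```
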
52 changes: 39 additions & 13 deletions torchrec/distributed/shards_wrapper.py
@@ -68,10 +68,15 @@ def __new__(

# we calculate the total tensor size by "concat" on second tensor dimension
cat_tensor_shape = list(local_shards[0].size())
if len(local_shards) > 1: # column-wise sharding
if len(local_shards) > 1 and local_shards[0].ndim == 2: # column-wise sharding
for shard in local_shards[1:]:
cat_tensor_shape[1] += shard.size()[1]

# when the optimizer state is sharded row-wise, we calculate the total tensor size by "concat" on the first tensor dimension
if len(local_shards) > 1 and local_shards[0].ndim == 1: # row-wise sharding
for shard in local_shards[1:]:
cat_tensor_shape[0] += shard.size()[0]

wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
wrapper_shape = torch.Size(cat_tensor_shape)
chunks_meta = [
@@ -110,6 +115,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
aten.equal.default: cls.handle_equal,
aten.detach.default: cls.handle_detach,
aten.clone.default: cls.handle_clone,
aten.new_empty.default: cls.handle_new_empty,
}

if func in dispatcher:
@@ -153,18 +159,28 @@ def handle_to_copy(args, kwargs):
def handle_view(args, kwargs):
view_shape = args[1]
res_shards_list = []
if (
len(args[0].local_shards()) > 1
and args[0].storage_metadata().size[0] == view_shape[0]
and args[0].storage_metadata().size[1] == view_shape[1]
):
# This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on
# init calls view_as() on the global tensor shape
# will fail because the view shape is not applicable to individual shards.
res_shards_list = [
aten.view.default(shard, shard.shape, **kwargs)
for shard in args[0].local_shards()
]
if len(args[0].local_shards()) > 1:
if args[0].local_shards()[0].ndim == 2:
assert (
args[0].storage_metadata().size[0] == view_shape[0]
and args[0].storage_metadata().size[1] == view_shape[1]
)
# This accounts for a DTensor quirk: when multiple shards are present on a rank,
# DTensor calls view_as() on the global tensor shape during init, which would fail
# because the view shape is not applicable to the individual shards.
res_shards_list = [
aten.view.default(shard, shard.shape, **kwargs)
for shard in args[0].local_shards()
]
elif args[0].local_shards()[0].ndim == 1:
assert args[0].storage_metadata().size[0] == view_shape[0]
# This case is for optimizer sharding: regardless of the sharding type, optimizer state is row-wise sharded
res_shards_list = [
aten.view.default(shard, shard.shape, **kwargs)
for shard in args[0].local_shards()
]
else:
raise NotImplementedError("No support for view on tensors ndim > 2")
else:
# view is called per shard
res_shards_list = [
@@ -220,6 +236,16 @@ def handle_clone(args, kwargs):
]
return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())

@staticmethod
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def handle_new_empty(args, kwargs):
self_ls = args[0]
return LocalShardsWrapper(
[torch.empty_like(shard) for shard in self_ls._local_shards],
self_ls.local_offsets(),
)

@property
def device(self) -> torch._C.device: # type: ignore[override]
return (
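
The `LocalShardsWrapper` changes above are what make the kernel-side conversion work when a rank holds more than one shard of 1-D (row-wise) optimizer state: the global size is now the concatenation along dim 0, and `view` / `new_empty` dispatch per shard. A small, hypothetical local-only sketch with made-up shard sizes and offsets — before this change, constructing the wrapper from multiple 1-D shards would have indexed a non-existent second dimension.

```python
import torch

from torchrec.distributed.shards_wrapper import LocalShardsWrapper

# Two 1-D shards on the same rank, e.g. row-wise momentum of a column-wise
# sharded table.
shards = [torch.zeros(3), torch.ones(5)]
wrapped = LocalShardsWrapper(local_shards=shards, local_offsets=[(0,), (3,)])

# Global size is the concatenation of the local shards along dim 0.
print(wrapped.size())  # torch.Size([8])

# A view to the global shape is applied per shard (the "DTensor quirk" path),
# so every shard keeps its own shape.
viewed = wrapped.view(8)
print([s.shape for s in viewed.local_shards()])  # [torch.Size([3]), torch.Size([5])]

# new_empty also dispatches per shard, returning a wrapper of empty tensors
# with the same shapes and offsets.
fresh = wrapped.new_empty(wrapped.size())
print([s.shape for s in fresh.local_shards()])   # [torch.Size([3]), torch.Size([5])]
```
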
21 changes: 20 additions & 1 deletion torchrec/optim/keyed.py
@@ -27,7 +27,7 @@

from torch import optim
from torch.distributed._shard.sharded_tensor import ShardedTensor

from torch.distributed.tensor import DTensor

OptimizerFactory = Callable[[List[Union[torch.Tensor, ShardedTensor]]], optim.Optimizer]

@@ -111,6 +111,8 @@ def _update_param_state_dict_object(
param_state_dict_to_load: Dict[str, Any],
parent_keys: List[Union[str, int, float, bool, None]],
) -> None:
# Import at function level to avoid circular dependency.
from torchrec.distributed.shards_wrapper import LocalShardsWrapper

for k, v in current_param_state_dict.items():
new_v = param_state_dict_to_load[k]
@@ -134,6 +136,23 @@ def _update_param_state_dict_object(
)
for shard, new_shard in zip(v.local_shards(), new_v.local_shards()):
shard.tensor.detach().copy_(new_shard.tensor)
elif isinstance(v, DTensor):
assert isinstance(new_v, DTensor)
if isinstance(v.to_local(), LocalShardsWrapper):
assert isinstance(new_v.to_local(), LocalShardsWrapper)
num_shards = len(v.to_local().local_shards()) # pyre-ignore[16]
num_new_shards = len(new_v.to_local().local_shards())
if num_shards != num_new_shards:
raise ValueError(
f"Different number of shards {num_shards} vs {num_new_shards} for the path of {json.dumps(parent_keys)}"
)
for shard, new_shard in zip(
v.to_local().local_shards(), new_v.to_local().local_shards()
):
shard.detach().copy_(new_shard)
else:
assert isinstance(new_v.to_local(), torch.Tensor)
v.detach().copy_(new_v)
elif isinstance(v, torch.Tensor):
v.detach().copy_(new_v)
else:
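
The `keyed.py` change mirrors the existing `ShardedTensor` load path for the new DTensor states. Below is a condensed, hypothetical sketch of just that copy logic — the real code lives inside `_update_param_state_dict_object` and additionally reports the offending key path when the shard counts differ.

```python
from torch.distributed.tensor import DTensor

from torchrec.distributed.shards_wrapper import LocalShardsWrapper


def copy_dtensor_state(current: DTensor, incoming: DTensor) -> None:
    """Copy a loaded DTensor optimizer state into the current one."""
    cur_local, new_local = current.to_local(), incoming.to_local()
    if isinstance(cur_local, LocalShardsWrapper):
        # Sharded case: copy each rank-local shard into its counterpart.
        assert isinstance(new_local, LocalShardsWrapper)
        if len(cur_local.local_shards()) != len(new_local.local_shards()):
            raise ValueError("different number of local shards")
        for shard, new_shard in zip(cur_local.local_shards(), new_local.local_shards()):
            shard.detach().copy_(new_shard)
    else:
        # Plain local tensor: copy the whole DTensor in place.
        current.detach().copy_(incoming)
```
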