
Commit bfbd95d

aporialiao authored and facebook-github-bot committed
Enable proper optimizer state storing + Test between batches (#3053)
Summary: Pull Request resolved: #3053

# Main Changes

1. Enable the unit test with an adaptive optimizer, `Adagrad`.
   - Previously the optimizer state was tested with `SGD`, whose state is static throughout training, so the test never verified that optimizer state was actually stored. Using `Adagrad` here exposed that the previous implementation did not properly store optimizer state.
2. Properly store optimizer state in `update_optimizer_state`.
   - Append optimizer tensors as inputs to the all2all call, then parse the output tensors to store the right tensors.
   - Optimizer tensors that do not need to be sent to a new rank are persisted and resaved.
   - After the new lookups are created, use `load_state_dict` to load the saved optimizer state into the current optimizers.
3. Helpers & other small changes
   - Helper to compare optimizer tensors in unit tests.
   - Update the `DMP` reshard optimizer saving to match the same FQN.

Reviewed By: aliafzal

Differential Revision: D75565054

fbshipit-source-id: 672772481b03662e661d152449028a04c7ba69c0
1 parent 515b97b commit bfbd95d
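
Note on the resharding flow described above: when optimizer state is present, each shard contributes one padded weight block and one padded optimizer block to the same all2all buffer, so `update_shards` splits the received tensor in half (the first half holds weight rows, the second half optimizer rows) before reloading state. A minimal, self-contained sketch of that split using toy tensors; the shapes and variable names below are illustrative, not the torchrec API:

import torch

# Illustrative padded shard dimensions and shard count (not taken from torchrec).
max_dim_0, max_dim_1 = 4, 8
num_shards = 3

# Toy stand-in for the all2all output: weight blocks first, optimizer blocks second.
weights = torch.randn(num_shards * max_dim_0, max_dim_1)
opt_rows = torch.randn(num_shards * max_dim_0, max_dim_1)
local_output_tensor = torch.cat([weights, opt_rows], dim=0)

has_optimizer = True
if has_optimizer:
    # Same split performed in update_shards(): weights in the first half,
    # optimizer tensors in the second half.
    split_index = len(local_output_tensor) // 2
    local_weight_tensors = local_output_tensor[:split_index]
    local_optimizer_tensors = local_output_tensor[split_index:]
else:
    local_weight_tensors = local_output_tensor

assert torch.equal(local_weight_tensors, weights)
assert torch.equal(local_optimizer_tensors, opt_rows)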

File tree: 5 files changed, +179 −24 lines


torchrec/distributed/embeddingbag.py

Lines changed: 33 additions & 16 deletions
@@ -27,9 +27,6 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
-from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
-    DenseTableBatchedEmbeddingBagsCodegen,
-)
 from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
@@ -61,6 +58,7 @@
     get_largest_dims_from_sharding_plan_updates,
     shards_all_to_all,
     update_module_sharding_plan,
+    update_optimizer_state_post_resharding,
     update_state_dict_post_resharding,
 )
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
@@ -1535,7 +1533,7 @@ def update_shards(
             return
 
         current_state = self.state_dict()
-        # TODO: Save Optimizers
+        has_optimizer = len(self._optim._optims) > 0
 
         # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
         # TODO: Ensure lookup tensors are actually being deleted
@@ -1550,6 +1548,7 @@ def update_shards(
         max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
             changed_sharding_params
         )
+        old_optimizer_state = self._optim.state_dict() if has_optimizer else None
 
         local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
             module=self,
@@ -1560,16 +1559,7 @@ def update_shards(
             extend_shard_name=self.extend_shard_name,
             max_dim_0=max_dim_0,
             max_dim_1=max_dim_1,
-        )
-
-        current_state = update_state_dict_post_resharding(
-            state_dict=current_state,
-            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_output_tensor,
-            new_sharding_params=changed_sharding_params,
-            curr_rank=dist.get_rank(),
-            extend_shard_name=self.extend_shard_name,
-            max_dim_0=max_dim_0,
+            optimizer_state=old_optimizer_state,
         )
 
         for name, param in changed_sharding_params.items():
@@ -1615,8 +1605,6 @@ def update_shards(
         if env.process_group and dist.get_backend(env.process_group) != "fake":
             self._initialize_torch_state(skip_registering=True)
 
-        self.load_state_dict(current_state)
-
         # update optimizer
         optims = []
         for lookup in self._lookups:
@@ -1635,6 +1623,35 @@ def update_shards(
 
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
 
+        if has_optimizer:
+            split_index = len(local_output_tensor) // 2
+            local_weight_tensors = local_output_tensor[:split_index]
+            local_optimizer_tensors = local_output_tensor[split_index:]
+            # Modifies new_opt_state in place and returns it
+            optimizer_state = update_optimizer_state_post_resharding(
+                old_opt_state=old_optimizer_state,  # pyre-ignore
+                new_opt_state=copy.deepcopy(self._optim.state_dict()),
+                ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
+                output_tensor=local_optimizer_tensors,
+                max_dim_0=max_dim_0,
+            )
+
+            self._optim.load_state_dict(optimizer_state)
+        else:
+            local_weight_tensors = local_output_tensor
+
+        current_state = update_state_dict_post_resharding(
+            state_dict=current_state,
+            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
+            output_tensor=local_weight_tensors,
+            new_sharding_params=changed_sharding_params,
+            curr_rank=dist.get_rank(),
+            extend_shard_name=self.extend_shard_name,
+            max_dim_0=max_dim_0,
+        )
+
+        self.load_state_dict(current_state)
+
         update_module_sharding_plan(self, changed_sharding_params)
         return

torchrec/distributed/model_parallel.py

Lines changed: 4 additions & 1 deletion
@@ -687,7 +687,10 @@ def reshard(
             self.device,
         )
 
-        self._optim: CombinedOptimizer = self._init_optim(self._dmp_wrapped_module)
+        # Need to use .module to maintain FQN consistency
+        self._optim: CombinedOptimizer = self._init_optim(
+            self._dmp_wrapped_module.module  # pyre-ignore
+        )
         self._plan.plan[sharded_module_fqn] = sharded_module.module_sharding_plan
         return sharded_module

torchrec/distributed/sharding/dynamic_sharding.py

Lines changed: 102 additions & 3 deletions
@@ -8,7 +8,7 @@
 # pyre-strict
 
 import copy
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -84,6 +84,7 @@ def shards_all_to_all(
     max_dim_0: int,
     max_dim_1: int,
     extend_shard_name: Callable[[str], str] = lambda x: x,
+    optimizer_state: Optional[Dict[str, Dict[str, Dict[str, ShardedTensor]]]] = None,
 ) -> Tuple[OrderedShardNamesWithSizes, torch.Tensor]:
     """
     Performs an all-to-all communication to redistribute shards across ranks based on new sharding parameters.
@@ -121,14 +122,18 @@ def shards_all_to_all(
     # Module sharding plan is used to get the source ranks for each shard
     assert hasattr(module, "module_sharding_plan")
 
+    has_optimizer = optimizer_state is not None
+
     world_size = env.world_size
     rank = dist.get_rank()
     input_splits_per_rank = [[0] * world_size for _ in range(world_size)]
     output_splits_per_rank = [[0] * world_size for _ in range(world_size)]
 
     output_tensor_tensor_count = 0
+    output_optimizer_tensor_count = 0
     shard_names_to_lengths_by_src_rank = [[] for _ in range(world_size)]
     local_table_to_input_tensor_by_dst_rank = [[] for _ in range(world_size)]
+    local_table_to_opt_by_dst_rank = [[] for _ in range(world_size)]
     for shard_name, param in changed_sharding_params.items():
         sharded_t = state_dict[extend_shard_name(shard_name)]
         assert param.ranks is not None
@@ -142,24 +147,47 @@ def shards_all_to_all(
         # index needed to distinguish between multiple shards
         # within the same shardedTensor for each table
         for i in range(len(src_ranks)):
+
+            # 1 to 1 mapping from src to dst
             dst_rank = dst_ranks[i]
             src_rank = src_ranks[i]
 
             shard_size = sharded_t.metadata().shards_metadata[i].shard_sizes
             input_splits_per_rank[src_rank][dst_rank] += max_dim_0
             output_splits_per_rank[dst_rank][src_rank] += max_dim_0
+            if has_optimizer:
+                input_splits_per_rank[src_rank][dst_rank] += max_dim_0
+                output_splits_per_rank[dst_rank][src_rank] += max_dim_0
+
+            # If sending from current rank
             if src_rank == rank:
+                if has_optimizer:
+                    # pyre-ignore
+                    local_optimizer = optimizer_state["state"][
+                        extend_shard_name(shard_name)
+                    ][tmp_momentum_extender(shard_name)].local_shards()
+                    assert len(local_optimizer) == 1
+                    padded_local_optimizer = pad_tensor_to_max_dims(
+                        local_optimizer[0].tensor, max_dim_0, max_dim_1
+                    )
+                    local_table_to_opt_by_dst_rank[dst_rank].append(
+                        padded_local_optimizer
+                    )
                 local_shards = sharded_t.local_shards()
                 assert len(local_shards) == 1
                 cur_t = pad_tensor_to_max_dims(
-                    sharded_t.local_shards()[0].tensor, max_dim_0, max_dim_1
+                    local_shards[0].tensor, max_dim_0, max_dim_1
                 )
                 local_table_to_input_tensor_by_dst_rank[dst_rank].append(cur_t)
+
+            # If recieving from current rank
             if dst_rank == rank:
                 shard_names_to_lengths_by_src_rank[src_rank].append(
                     (shard_name, shard_size)
                 )
                 output_tensor_tensor_count += max_dim_0
+                if has_optimizer:
+                    output_optimizer_tensor_count += max_dim_0
 
     local_input_splits = input_splits_per_rank[rank]
     local_output_splits = output_splits_per_rank[rank]
@@ -175,9 +203,23 @@ def shards_all_to_all(
                 dim=0,
             )
 
+    for sub_l in local_table_to_opt_by_dst_rank:
+        for shard_info in sub_l:
+            local_input_tensor = torch.cat(
+                (
+                    local_input_tensor,
+                    shard_info,
+                ),
+                dim=0,
+            )
+
     max_embedding_size = max_dim_1
     local_output_tensor = torch.empty(
-        [output_tensor_tensor_count, max_embedding_size], device=device
+        [
+            output_tensor_tensor_count + output_optimizer_tensor_count,
+            max_embedding_size,
+        ],
+        device=device,
     )
 
     assert sum(local_output_splits) == len(local_output_tensor)
@@ -277,6 +319,50 @@ def update_state_dict_post_resharding(
     return state_dict
 
 
+def update_optimizer_state_post_resharding(
+    old_opt_state: Dict[str, Dict[str, Dict[str, ShardedTensor]]],
+    new_opt_state: Dict[str, Dict[str, Dict[str, ShardedTensor]]],
+    ordered_shard_names_and_lengths: OrderedShardNamesWithSizes,
+    output_tensor: torch.Tensor,
+    max_dim_0: int,
+) -> Dict[str, Dict[str, Dict[str, ShardedTensor]]]:
+    new_opt_state_state = new_opt_state["state"]
+    old_opt_state_state = old_opt_state["state"]
+
+    # Remove padding and store tensors by shard name
+    slice_index = 0
+    shard_name_to_local_output_tensor: Dict[str, torch.Tensor] = {}
+    for shard_name, shard_size in ordered_shard_names_and_lengths:
+        end_slice_index = slice_index + max_dim_0
+        cur_t = output_tensor[slice_index:end_slice_index]
+        cur_t = pad_tensor_to_max_dims(
+            cur_t, shard_size[0], shard_size[1], remove_padding=True
+        )
+        shard_name_to_local_output_tensor[shard_name] = cur_t
+        slice_index = end_slice_index
+
+    for extended_shard_name, item in new_opt_state_state.items():
+        if extended_shard_name in old_opt_state_state:
+            new_opt_state_state[extended_shard_name] = old_opt_state_state[
+                extended_shard_name
+            ]
+        else:
+            shard_name = extract_shard_name(extended_shard_name)
+            momentum_name = tmp_momentum_extender(shard_name)
+            sharded_t = item[momentum_name]
+            assert len(sharded_t._local_shards) == 1
+            # TODO: support multiple shards in CW sharding
+            sharded_t._local_shards = [
+                Shard(
+                    tensor=shard_name_to_local_output_tensor[shard_name],
+                    metadata=shard.metadata,
+                )
+                for shard in sharded_t._local_shards
+            ]
+
+    return new_opt_state
+
+
 def update_module_sharding_plan(
     module: ShardedModule[Any, Any, Any, Any],  # pyre-ignore
     changed_sharding_params: Dict[str, ParameterSharding],
@@ -388,3 +474,16 @@ def output_sharding_plan_delta(
             if v.ranks != old_plan[k].ranks
         }
     )
+
+
+"""
+Utils for Optimizer State accessing
+"""
+
+
+def tmp_momentum_extender(name: str) -> str:
+    return name + ".momentum1"
+
+
+def extract_shard_name(name: str) -> str:
+    return name.split(".")[-2]
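
For reference, the naming convention these two helpers encode can be illustrated directly; the FQN below is illustrative, only the `.momentum1` suffix and the `<prefix>.<table_name>.weight` shape are assumed:

# tmp_momentum_extender appends the momentum suffix used for the Adagrad state;
# extract_shard_name recovers the table name from an extended shard FQN.
assert tmp_momentum_extender("table_0") == "table_0.momentum1"
assert extract_shard_name("embedding_bags.table_0.weight") == "table_0"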

torchrec/distributed/test_utils/test_sharding.py

Lines changed: 36 additions & 0 deletions
@@ -541,6 +541,8 @@ def dynamic_sharding_test(
     )
 
     local_m1_dmp.reshard("sparse.ebc", new_module_sharding_plan_delta)
+    # Must recreate local_m1_opt, because current local_m1_opt is a copy of underlying fused_opt
+    local_m1_opt = CombinedOptimizer([local_m1_dmp.fused_optimizer, dense_m1_optim])
 
     local_m1_pred = gen_full_pred_after_one_step(
         local_m1_dmp, local_m1_opt, local_input_1
@@ -954,7 +956,12 @@ def gen_full_pred_after_one_step(
     opt: torch.optim.Optimizer,
     input: ModelInput,
     skip_inference: bool = False,
+    skip_training: bool = False,
 ) -> torch.Tensor:
+    if skip_training:
+        model.train(False)
+        output = model(input)
+        return output
     # Run a single training step of the global model.
     opt.zero_grad()
     model.train(True)
@@ -1120,3 +1127,32 @@ def generate_rank_placements(
         placement = sorted(random.sample(range(world_size), ranks_per_table))
         placements.append(placement)
     return placements
+
+
+def compare_opt_local_t(
+    opt_1: CombinedOptimizer,
+    opt_2: CombinedOptimizer,
+    table_id: int,
+    rtol: float = 1e-4,
+    atol: float = 1e-4,
+) -> None:
+    """
+    Helper function to compare the optimizer state of two models after one training step.
+    Useful for debugging sharding tests to see which model weights are different
+    """
+    # TODO: update logic to be generic other embedding modules
+    t1 = (
+        opt_1.state_dict()["state"][
+            "sparse.ebc.embedding_bags.table_" + str(table_id) + ".weight"
+        ]["table_" + str(table_id) + ".momentum1"]
+        .local_shards()[0]
+        .tensor
+    )
+    t2 = (
+        opt_2.state_dict()["state"][
+            "sparse.ebc.embedding_bags.table_" + str(table_id) + ".weight"
+        ]["table_" + str(table_id) + ".momentum1"]
+        .local_shards()[0]
+        .tensor
+    )
+    torch.testing.assert_close(t1, t2, rtol=rtol, atol=atol)
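
Hypothetical usage of the new helper in a test, assuming two `CombinedOptimizer` instances built the way `dynamic_sharding_test` builds them (the variable names are illustrative):

# Compare the momentum1 tensor for table_0 between the resharded local optimizer
# and the global reference optimizer; raises if they diverge beyond tolerance.
compare_opt_local_t(local_m1_opt, global_opt, table_id=0)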

torchrec/distributed/tests/test_dynamic_sharding.py

Lines changed: 4 additions & 4 deletions
@@ -8,8 +8,6 @@
 # pyre-strict
 
 
-import copy
-
 import random
 import unittest
 
@@ -21,7 +19,7 @@
 
 from hypothesis import assume, given, settings, Verbosity
 
-from torch import nn
+from torch import nn, optim
 
 from torchrec import distributed as trec_dist, EmbeddingBagCollection, KeyedJaggedTensor
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
@@ -530,9 +528,11 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared):
         apply_optimizer_in_backward_config=st.sampled_from(
            [
                None,
+                {
+                    "embedding_bags": (optim.Adagrad, {"lr": 0.04}),
+                },
                {
                    "embedding_bags": (torch.optim.SGD, {"lr": 0.01}),
-                    "embeddings": (torch.optim.SGD, {"lr": 0.2}),
                },
            ]
        ),
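
Why the sampled configs now include `Adagrad`: plain `SGD` carries no per-parameter state, so a resharding test using it can pass even if optimizer state is silently dropped, while `Adagrad` maintains a per-parameter accumulator that must survive the reshard. A standalone illustration in plain PyTorch, independent of the torchrec test harness:

import torch

p = torch.nn.Parameter(torch.randn(4))
p.grad = torch.ones_like(p)

# Adagrad keeps a running sum of squared gradients ("sum") per parameter;
# this is the state that must be carried over when shards move between ranks.
ada = torch.optim.Adagrad([p], lr=0.04)
ada.step()
ada_state = next(iter(ada.state_dict()["state"].values()))
assert "sum" in ada_state

# Vanilla SGD (momentum=0) has no comparable accumulator, which is why the
# earlier SGD-only test could not catch a missing optimizer-state transfer.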
