commit f9dd63c (1 parent: a5de563)
torchrec/distributed/sharding/dynamic_sharding.py
@@ -73,7 +73,6 @@ def shards_all_to_all(
         sharded_t = state_dict[extend_shard_name(shard_name)]
         assert param.ranks is not None
         dst_ranks = param.ranks
-        state_dict[extend_shard_name(shard_name)]
         # pyre-ignore
         src_ranks = module.module_sharding_plan[shard_name].ranks
 
@@ -140,7 +139,7 @@ def shards_all_to_all(
        input=local_input_tensor,
        output_split_sizes=local_output_splits,
        input_split_sizes=local_input_splits,
-       group=dist.group.WORLD,
+       group=env.process_group,  # TODO: 2D uses env.sharding_pg
    )
 
    flattened_output_names_lengths = [
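
The substantive change is in the second hunk: the all-to-all exchange now runs on the sharding environment's process group (env.process_group) instead of the global dist.group.WORLD, with a TODO noting that 2D sharding would use env.sharding_pg. Below is a minimal sketch of the pattern, assuming torch.distributed is already initialized; the helper name redistribute_shards and its parameter names are illustrative and not part of TorchRec, but the dist.all_to_all_single call mirrors the one in the diff.

import torch
import torch.distributed as dist


def redistribute_shards(
    local_input_tensor: torch.Tensor,
    input_splits,
    output_splits,
    pg: dist.ProcessGroup,
) -> torch.Tensor:
    # Receive buffer sized by how many rows each rank in `pg` will send us.
    local_output_tensor = torch.empty(
        (sum(output_splits), *local_input_tensor.shape[1:]),
        dtype=local_input_tensor.dtype,
        device=local_input_tensor.device,
    )
    dist.all_to_all_single(
        output=local_output_tensor,
        input=local_input_tensor,
        output_split_sizes=output_splits,
        input_split_sizes=input_splits,
        # Run the exchange on the sharding process group; hard-coding
        # dist.group.WORLD breaks when the shards live on a subgroup of
        # ranks (e.g. one dimension of a 2D parallel setup).
        group=pg,
    )
    return local_output_tensor

Here input_splits[i] is the number of rows this rank sends to group-rank i and output_splits[i] is the number it expects back from group-rank i; both are defined relative to pg, which is why the group argument and the split bookkeeping must agree.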