
Commit f9dd63c

aporialia authored and facebook-github-bot committed
Fix all_to_all not using env pg (#2918)
Summary:
Pull Request resolved: #2918

Reviewed By: aliafzal

Differential Revision: D73678043

fbshipit-source-id: 13afe85e5537b20afe843461824be1dd52ff230c
1 parent: a5de563

File tree: 1 file changed (+1, −2 lines)


torchrec/distributed/sharding/dynamic_sharding.py

Lines changed: 1 addition & 2 deletions
@@ -73,7 +73,6 @@ def shards_all_to_all(
     sharded_t = state_dict[extend_shard_name(shard_name)]
     assert param.ranks is not None
     dst_ranks = param.ranks
-    state_dict[extend_shard_name(shard_name)]
     # pyre-ignore
     src_ranks = module.module_sharding_plan[shard_name].ranks

@@ -140,7 +139,7 @@ def shards_all_to_all(
         input=local_input_tensor,
         output_split_sizes=local_output_splits,
         input_split_sizes=local_input_splits,
-        group=dist.group.WORLD,
+        group=env.process_group,  # TODO: 2D uses env.sharding_pg
     )

     flattened_output_names_lengths = [
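
For context on why the one-line change matters: dist.all_to_all_single takes an optional group argument that defaults to the global WORLD group, so when the sharding environment owns its own process group (or a subgroup, as in the 2D case noted in the TODO), the collective must be issued on that group rather than WORLD. The following is a minimal, self-contained sketch of the pattern, not TorchRec's code; the sharding_pg name and the 2-rank setup are illustrative assumptions (run with e.g. torchrun --nproc_per_node=2, using a backend that supports all_to_all, such as NCCL or a recent Gloo).

    import torch
    import torch.distributed as dist


    def main() -> None:
        dist.init_process_group("gloo")
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # Model the environment's process group with an explicit new group.
        # Here it spans all ranks, but in general it may be a subgroup,
        # which is exactly why hardcoding dist.group.WORLD is wrong.
        sharding_pg = dist.new_group(ranks=list(range(world_size)))

        # Each rank contributes one element per destination rank.
        input_tensor = torch.arange(world_size, dtype=torch.float32) + rank * world_size
        output_tensor = torch.empty(world_size, dtype=torch.float32)

        # Passing the environment's group keeps the exchange on the ranks
        # that actually participate in this sharding environment.
        dist.all_to_all_single(output_tensor, input_tensor, group=sharding_pg)

        print(f"rank {rank}: {output_tensor.tolist()}")
        dist.destroy_process_group()


    if __name__ == "__main__":
        main()

With group=dist.group.WORLD hardcoded, the same call would always address every rank in the job, even when env.process_group covers only a subset, which is the mismatch this commit fixes.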
