[dtensor] fix simplefsdp mixed-precision training bugs #154975


Closed · wants to merge 1 commit

Conversation

@ruisizhang123 (Contributor) commented on Jun 3, 2025

This is a follow-up to the previous dtensor redistribute PR (#150740), which enabled SimpleFSDP's mixed-precision training.

In the most recent TorchTitan integration (pytorch/torchtitan#1250), we found discrepancies between SimpleFSDP's `fully_shard` and `replicate` modes when mixed-precision training (MPT) is enabled. After debugging, I found the problem in dtensor redistribute: `local_tensor` is extracted again from the original `input`, so the dtensor used for communication keeps its original precision instead of being cast to `forward_dtype`.
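To illustrate the failure mode, here is a minimal sketch (hypothetical helper name, not the actual DTensor redistribute internals): the tensor handed to the collective must be the already-cast local shard; pulling `local_tensor` out of the original `input` again discards the cast.

```python
import torch


def cast_local_for_comm(local_tensor: torch.Tensor,
                        forward_dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical sketch: return the shard that should be communicated."""
    # Cast the local shard to the reduced precision used for the collective.
    casted = local_tensor.to(forward_dtype)

    # Bug pattern: re-extracting the local tensor from the original dtensor
    # `input` at this point would discard the cast, so the collective would
    # run in the original precision instead of `forward_dtype`.

    # Correct pattern: keep using the casted tensor for communication.
    return casted


if __name__ == "__main__":
    shard = torch.randn(4, 4, dtype=torch.float32)
    assert cast_local_for_comm(shard, torch.bfloat16).dtype == torch.bfloat16
```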

This PR fixes this issue and corrects previously added test cases.

After fixing the bug, the loss curves of the `fully_shard` and `replicate` modes match perfectly.

![loss](https://github.com/user-attachments/assets/a8faddae-a476-48c0-a411-3fe04d2233bd)

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

@pytorch-bot (bot) commented on Jun 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154975

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 1 Unrelated Failure

As of commit c172c07 with merge base a7e496a:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot added the `oncall: distributed` label on Jun 3, 2025
@tianyu-l (Contributor) left a comment


thanks for the fix!

@tianyu-l added the `release notes: distributed (dtensor)` and `ciflow/trunk` labels on Jun 3, 2025
@tianyu-l (Contributor) commented on Jun 3, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


qingyi-yan pushed a commit to qingyi-yan/pytorch that referenced this pull request on Jun 3, 2025

iupaikov-amd pushed a commit to ROCm/pytorch that referenced this pull request on Jun 4, 2025

angelayi pushed a commit to angelayi/pytorch that referenced this pull request on Jun 5, 2025
Labels: ciflow/trunk, Merged, oncall: distributed, open source, release notes: distributed (dtensor)