-
Notifications
You must be signed in to change notification settings - Fork 24.5k
[dtensor] fix simplefsdp mixed-precision training bugs #154975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/154975
Note: Links to docs will display an error until the docs builds have been completed. ⏳ 1 Pending, 1 Unrelated FailureAs of commit c172c07 with merge base a7e496a ( UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the fix!
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
This is a follow-up on the previous dtensor redistribute PR: pytorch#150740, which enables SimpleFSDP's mixed-precision training. In the most recent integration in TorchTitan: pytorch/torchtitan#1250, we found some discrepancies between SimpleFSDP's `fully_shard` and `replicate` modes when MPT is enabled. After debugging, I found the problem is in dtensor redistribute --`local_tensor` is taken out again from the original `input`. Thus, the dtensor used for communication has its original precision instead of using `forward_dtype`. This PR fixes this issue and corrects previously added test cases. After fixing the bug, the loss curves of `fully_shard` and `replicate` mode match perfectly.  Pull Request resolved: pytorch#154975 Approved by: https://github.com/tianyu-l
This is a follow-up on the previous dtensor redistribute PR: pytorch#150740, which enables SimpleFSDP's mixed-precision training. In the most recent integration in TorchTitan: pytorch/torchtitan#1250, we found some discrepancies between SimpleFSDP's `fully_shard` and `replicate` modes when MPT is enabled. After debugging, I found the problem is in dtensor redistribute --`local_tensor` is taken out again from the original `input`. Thus, the dtensor used for communication has its original precision instead of using `forward_dtype`. This PR fixes this issue and corrects previously added test cases. After fixing the bug, the loss curves of `fully_shard` and `replicate` mode match perfectly.  Pull Request resolved: pytorch#154975 Approved by: https://github.com/tianyu-l
This is a follow-up on the previous dtensor redistribute PR: pytorch#150740, which enables SimpleFSDP's mixed-precision training. In the most recent integration in TorchTitan: pytorch/torchtitan#1250, we found some discrepancies between SimpleFSDP's `fully_shard` and `replicate` modes when MPT is enabled. After debugging, I found the problem is in dtensor redistribute --`local_tensor` is taken out again from the original `input`. Thus, the dtensor used for communication has its original precision instead of using `forward_dtype`. This PR fixes this issue and corrects previously added test cases. After fixing the bug, the loss curves of `fully_shard` and `replicate` mode match perfectly.  Pull Request resolved: pytorch#154975 Approved by: https://github.com/tianyu-l
This is a follow-up on the previous dtensor redistribute PR: #150740, which enables SimpleFSDP's mixed-precision training.
In the most recent integration in TorchTitan: pytorch/torchtitan#1250, we found some discrepancies between SimpleFSDP's
fully_shard
andreplicate
modes when MPT is enabled. After debugging, I found the problem is in dtensor redistribute --local_tensor
is taken out again from the originalinput
. Thus, the dtensor used for communication has its original precision instead of usingforward_dtype
.This PR fixes this issue and corrects previously added test cases.
After fixing the bug, the loss curves of
fully_shard
andreplicate
mode match perfectly.cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k