Commit a1a268a

ruisizhang123 authored and pytorchmergebot committed
[dtensor] fix simplefsdp mixed-precision training bugs (#154975)
This is a follow-up to the previous DTensor redistribute PR (#150740), which enabled SimpleFSDP's mixed-precision training. In the most recent TorchTitan integration (pytorch/torchtitan#1250), we found discrepancies between SimpleFSDP's `fully_shard` and `replicate` modes when mixed-precision training is enabled.

After debugging, I found the problem is in DTensor redistribute: `local_tensor` is taken out of the original `input` a second time, so the tensor used for communication keeps its original precision instead of being cast to `forward_dtype`. This PR fixes the issue and corrects the previously added test cases.

After the fix, the loss curves of `fully_shard` and `replicate` mode match perfectly.

![loss](https://github.com/user-attachments/assets/a8faddae-a476-48c0-a411-3fe04d2233bd)

Pull Request resolved: #154975
Approved by: https://github.com/tianyu-l
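As a hedged illustration of the behavior this fix restores (not code from the commit): assuming the `forward_dtype`/`backward_dtype` keyword arguments that #150740 added to `DTensor.redistribute`, the local tensor feeding the all-gather should already carry `forward_dtype` when it reaches the collective.

```python
# Minimal sketch; assumes the forward_dtype kwarg on DTensor.redistribute
# from #150740, and a process group already initialized (e.g. via torchrun).
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, Shard

def shard_to_replicate_bf16() -> DTensor:
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))
    local = torch.randn(4, 8, dtype=torch.float32)  # original precision
    dt = DTensor.from_local(local, mesh, [Shard(0)])
    # The cast to forward_dtype must happen *before* the all-gather; the bug
    # re-read the uncast local tensor, so communication ran in fp32.
    out = dt.redistribute(mesh, [Replicate()], forward_dtype=torch.bfloat16)
    assert out.to_local().dtype == torch.bfloat16  # holds after this fix
    return out
```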
1 parent 2608927 · commit a1a268a

2 files changed: +3 −2 lines changed

test/distributed/tensor/test_redistribute.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -312,7 +312,9 @@ def test_shard_to_replicate_forward_backward_datatype_conversion(self):
             backward_dtype=backward_dtype,
         )
         self.assertEqual(reshard_dtensor.size(), torch.Size(input_size))
-        self.assertEqual(expected_tensor, reshard_dtensor.to_local())
+        self.assertEqual(
+            expected_tensor.to(forward_dtype), reshard_dtensor.to_local()
+        )
         self.assertEqual(
             comm_mode.get_comm_counts()[funcol.all_gather_into_tensor], 1
         )
```

torch/distributed/tensor/_redistribute.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -318,7 +318,6 @@ def forward(  # type: ignore[override]
             device_mesh, placements, tensor_meta=current_spec.tensor_meta
         )
 
-        local_tensor = input._local_tensor
         output = redistribute_local_tensor(
             local_tensor, current_spec, target_spec, async_op=async_op
         )
```
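To see why deleting that single line fixes mixed precision, here is a toy single-process sketch (hypothetical names, not the commit's code): re-reading the source tensor after the `forward_dtype` cast silently discards the conversion, so the collective would run in the original precision.

```python
import torch

# Toy stand-in for redistribute_local_tensor; hypothetical name and logic.
def fake_collective(t: torch.Tensor) -> torch.Tensor:
    return t.clone()

def redistribute_with_cast(src: torch.Tensor, forward_dtype: torch.dtype,
                           reintroduce_bug: bool = False) -> torch.Tensor:
    local_tensor = src.to(forward_dtype)  # cast intended for communication
    if reintroduce_bug:
        # The removed line: re-reading the source clobbers the cast.
        local_tensor = src
    return fake_collective(local_tensor)

x = torch.randn(4, dtype=torch.float32)
assert redistribute_with_cast(x, torch.bfloat16).dtype == torch.bfloat16
# With the bug, the original fp32 precision leaks through:
assert redistribute_with_cast(x, torch.bfloat16, reintroduce_bug=True).dtype == torch.float32
```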
