Fix: no_grad with AMP bug #20921

Open · baskrahmer wants to merge 9 commits into master from fix/no-grad-amp-bug

Conversation

@baskrahmer (Contributor) commented on Jun 20, 2025

Fixes #20644

Note, however, that this would affect performance for other users, so the question is whether it is worth optimizing for this edge case, which is fundamentally a torch bug.

cc @Borda


📚 Documentation preview 📚: https://pytorch-lightning--20921.org.readthedocs.build/en/20921/
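
For reference, a minimal sketch of the torch behaviour behind #20644 (this assumes a CUDA device and a torch version still affected by the autocast-cache issue; it is illustration only, not part of the changes in this PR):

import torch

# Sketch of the underlying issue: a forward pass run under no_grad() inside an
# autocast region fills the autocast weight cache with casts that carry no
# autograd history; a later forward pass in the same region reuses those cached
# casts, so the loss ends up detached from the parameters.
model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(4, 8, device="cuda")

with torch.autocast("cuda", dtype=torch.half):  # cache_enabled defaults to True
    with torch.no_grad():
        model(x)             # fills the cast cache without grad tracking
    loss = model(x).sum()
print(loss.requires_grad)    # False on affected versions: no graph was built

# The workaround this PR applies: disable the cast cache.
with torch.autocast("cuda", dtype=torch.half, cache_enabled=False):
    with torch.no_grad():
        model(x)
    loss = model(x).sum()
print(loss.requires_grad)    # True: the graph is built as expected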

@github-actions bot added the pl (Generic label for PyTorch Lightning package) label on Jun 20, 2025
@baskrahmer force-pushed the fix/no-grad-amp-bug branch from 08508b6 to d18fb08 on June 20, 2025, 13:46
@baskrahmer marked this pull request as ready for review on June 20, 2025, 15:33
Comment on lines +115 to +117
return torch.autocast(
self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
)
Member

Suggested change
- return torch.autocast(
-     self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
- )
+ dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
+ return torch.autocast(self.device, dtype=dtype, cache_enabled=False)

@Borda (Member) commented on Jun 23, 2025

> Note, however, that this would affect performance for other users, so the question is whether it is worth optimizing for this edge case, which is fundamentally a torch bug.

Then we should report it and offer a fix in torch.
Then, if it is accepted and released, we should have a version switch in our codebase, so newer torch versions won't need this workaround while older ones keep it. Does it...
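
A sketch of what such a version switch could look like (the 2.9.0 cut-off and the helper name are invented for illustration, since no upstream fix exists yet):

from packaging.version import Version

import torch

# Hypothetical version switch: keep the cache_enabled=False workaround only on
# torch versions that still ship the autocast-cache bug (the cut-off is made up).
_NEEDS_AUTOCAST_CACHE_WORKAROUND = Version(torch.__version__.split("+")[0]) < Version("2.9.0")

def autocast_context(device: str, dtype: torch.dtype) -> torch.autocast:
    if _NEEDS_AUTOCAST_CACHE_WORKAROUND:
        return torch.autocast(device, dtype=dtype, cache_enabled=False)
    return torch.autocast(device, dtype=dtype)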

BTW, have you measured the performance drop?
cc: @lantiga

@baskrahmer (Contributor, Author) commented

@Borda it is a long-standing issue in torch. I can try to make a fix if I have some time, but I think it could be complex.

But I agree with you that ideally it should be fixed in torch. I just wanted to open this PR to show what a workaround on our end would look like. Shall I close it?

I haven't measured the performance drop since it will vary strongly across architectures and probably also hardware setups.
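
For anyone who wants a rough number on their own setup, a hypothetical micro-benchmark (assuming a CUDA device; nothing here is part of the PR) that times training steps with the autocast weight cache enabled vs disabled:

import time

import torch

# Hypothetical micro-benchmark: the model applies the same Linear layer several
# times per forward pass so the autocast cast cache actually gets reused.
def time_autocast(cache_enabled: bool, steps: int = 100) -> float:
    layer = torch.nn.Linear(2048, 2048).cuda()
    x = torch.randn(64, 2048, device="cuda")

    def step() -> None:
        with torch.autocast("cuda", dtype=torch.half, cache_enabled=cache_enabled):
            out = x
            for _ in range(8):
                out = layer(out)
        out.sum().backward()
        layer.zero_grad(set_to_none=True)

    for _ in range(10):  # warm-up
        step()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(steps):
        step()
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("cache_enabled=True :", time_autocast(True))
print("cache_enabled=False:", time_autocast(False))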

Labels: pl (Generic label for PyTorch Lightning package)
Projects: None yet
Development: successfully merging this pull request may close the issue "Computation graph not being built" (#20644).
2 participants