Fix dynamo tracing into AOTAutogradCache results in cpu tensors #155251
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155251
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 Unrelated Failures) As of commit eacd200 with merge base d2a2bfc:
BROKEN TRUNK - The following job failed but was present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
On this line, we see that the bw_compiler that dynamo uses for AOTAutograd automatically disables the backward runnable: https://github.com/pytorch/pytorch/blob/05dd638ee98b36254c84095894c36fd0e7d95544/torch/_dynamo/backends/common.py#L76

This disables dynamo in the bw_compiler and also disables the runnable the compiler returns. On an AOTAutogradCache hit, however, we never call the bw_compiler, so we never apply the disable. This only has an effect in certain cases of CPU tensors' backwards, where the backward runs in Python land and dynamo unnecessarily tries to trace through the inductor-generated code. It also only matters when the backward is invoked outside of dynamo itself (say, after a graph break in eager mode), since dynamo already disables the forward function properly.

```
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] TorchDynamo attempted to trace the following frames: [
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] ]
```

This PR fixes the issue and adds a unit test showing that, with or without a cache hit, the frames dynamo traces are identical.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames
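For a concrete picture, here is a hypothetical repro sketch of the scenario behind the log above (the function, file names, and shapes are illustrative assumptions, not taken from the PR's actual unit test):

```python
import torch

# Hypothetical repro sketch: a compiled function over CPU tensors whose
# backward is invoked from plain eager code, outside of dynamo. On a warm
# AOTAutogradCache hit, without this fix, dynamo would attempt to trace the
# inductor-generated backward `call` frame shown in the log above.
@torch.compile(backend="inductor")
def fn(x):
    return (x.sin() * 2).sum()

x = torch.randn(8, requires_grad=True)  # CPU tensor
out = fn(x)
out.backward()  # backward runs via eager autograd, not inside dynamo
```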
@@ -589,6 +589,14 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoadable):
    def _is_backward(self) -> bool:
        return True

    def post_compile(
I wanted to put this in GenericCompiledBackward, but the post_compile it's referencing in super() needs to refer to FxGraphCacheLoadable's post_compile, so I need to copy it twice for Bundled vs. non-Bundled. Multiple inheritance is confusing 😅
        compiled_bw = super().post_compile(result, fx_config)
        # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
        # But since on cache hit we do not call the bw_compiler, we need to reapply the disable
        return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass")  # type: ignore[return-value]
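For context, a minimal sketch of the non-cached path that this post_compile mirrors, assuming only the behavior described in the PR (this is not the actual implementation, which lives in torch/_dynamo/backends/common.py):

```python
import torch

# Sketch: the backward compiler wrapper runs with dynamo disabled, and the
# runnable it returns is also wrapped in torch._dynamo.disable so dynamo
# never traces the generated backward. On an AOTAutogradCache hit this
# wrapper is skipped, which is why post_compile re-applies the disable.
def make_wrapped_bw_compiler(bw_compiler):
    @torch._dynamo.disable
    def wrapped_bw_compiler(*args, **kwargs):
        compiled_bw = bw_compiler(*args, **kwargs)
        # ensure dynamo never traces the compiled backward runnable
        return torch._dynamo.disable(
            compiled_bw, reason="do not trace generated backwards pass"
        )

    return wrapped_bw_compiler
```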
nit: one convention in the codebase I've found helpful when there are two different areas of the codebase that need to be kept in sync is something like:

<file 1>
# Note [Wrapping bw_compiler in disable]
# <comment explaining why we wrap in disable>

<file 2>
# See Note [Wrapping bw_compiler in disable]

That way if e.g. someone wants to tweak the disable behavior in the future, they can easily grep for the note name and find any other code locations that need to be kept in sync (random example where we do it here: https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L133C16-L133C49)
nice!
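For reference, a sketch of how the suggested note might look once applied; the second file path is an assumption for illustration, not taken from the PR:

```python
# torch/_dynamo/backends/common.py
# Note [Wrapping bw_compiler in disable]
# The runnable returned by the backward compiler must never be traced by
# dynamo. AOTAutogradCache re-applies this wrap in post_compile on cache
# hits, so keep the two locations in sync.

# (hypothetical path) torch/_functorch/_aot_autograd/autograd_cache.py
# See Note [Wrapping bw_compiler in disable]
```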
Rebased, let's see if I can get clean tests.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: pytorch#155251. Approved by: https://github.com/bdhirsh, https://github.com/anijain2305
Stack from ghstack (oldest at bottom):
On this line, the bw_compiler that dynamo uses for AOTAutograd wraps the compiled backward runnable in torch._dynamo.disable:

pytorch/torch/_dynamo/backends/common.py, line 76 at 05dd638

On an AOTAutogradCache hit, however, the bw_compiler is never called, so the disable is never applied, and dynamo needlessly tries to trace the inductor-generated backward when it is invoked outside of dynamo (say, after a graph break in eager mode); see the full description and trace log above. This PR fixes the issue and adds a unit test showing that, with or without a cache hit, the frames dynamo traces are identical.
Fixes #154536
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames