Fix dynamo tracing into AOTAutogradCache results in cpu tensors #155251


Closed
wants to merge 5 commits

Conversation

@jamesjwu (Contributor) commented Jun 5, 2025

Stack from ghstack (oldest at bottom):

On this line in torch/_dynamo/backends/common.py, we see that the bw_compiler that dynamo uses for AOTAutograd automatically disables the backward runnable:
https://github.com/pytorch/pytorch/blob/05dd638ee98b36254c84095894c36fd0e7d95544/torch/_dynamo/backends/common.py#L76

`reason="do not trace generated backwards pass",`

This disables dynamo inside the bw_compiler and also disables the runnable that the compiler returns.
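Roughly, that wrapper does something like the following (a simplified sketch for illustration, not the exact source of `_wrapped_bw_compiler`):

```python
import torch

def wrap_bw_compiler(bw_compiler):
    # Sketch of the idea behind _wrapped_bw_compiler in
    # torch/_dynamo/backends/common.py: run the real backward compiler with
    # dynamo disabled, and also wrap the runnable it returns in disable() so
    # dynamo never tries to trace the generated backward code.
    def wrapped(gm, example_inputs):
        compiled_bw = torch._dynamo.disable(bw_compiler)(gm, example_inputs)
        return torch._dynamo.disable(
            compiled_bw, reason="do not trace generated backwards pass"
        )

    return wrapped
```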

On an AOTAutogradCache hit, however, we never call the bw_compiler! So we don't disable dynamo properly. This only affects certain cases of CPU tensors' backward passes, where the backward runs in Python and dynamo unnecessarily tries to trace through the inductor-generated code. It also only matters when the backward is invoked outside of dynamo itself (say, in a graph break in eager mode), since dynamo already disables the forward function properly.
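A rough sketch of the kind of setup where this can show up (hypothetical; not the exact repro from #154536): a compiled function over CPU tensors with a graph break, where the backward of the compiled region runs while dynamo is still watching frames.

```python
import torch

@torch.compile(backend="inductor")
def fn(x):
    y = (x * 2).sum()
    # Graph break: the rest of the function resumes in eager mode, but dynamo
    # still tries to trace frames reached from here.
    torch._dynamo.graph_break()
    # For CPU tensors the backward of the compiled region runs in Python here.
    # On an AOTAutogradCache hit (pre-fix), the inductor-generated `call`
    # frame invoked by this backward was no longer wrapped in disable(), so
    # dynamo attempted to trace it.
    y.backward()
    return x.grad

x = torch.randn(4, requires_grad=True)  # CPU tensor
fn(x)  # cold run populates the cache; a warm run reproduces the extra frames
```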

```
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] TorchDynamo attempted to trace the following frames: [
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] ]
```

This PR fixes the issue and adds a unit test showing that, with or without a cache hit, the frames dynamo traces are identical.

Fixes #154536

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

@jamesjwu jamesjwu requested a review from bdhirsh as a code owner June 5, 2025 19:23
pytorch-bot bot commented Jun 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/155251

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit eacd200 with merge base d2a2bfc:

BROKEN TRUNK - The following jobs failed but were also present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jamesjwu added a commit that referenced this pull request Jun 5, 2025
ghstack-source-id: d08b283
Pull Request resolved: #155251
@jamesjwu jamesjwu added the `topic: not user facing` label Jun 5, 2025
@jamesjwu jamesjwu requested review from oulgen and anijain2305 June 5, 2025 19:27
@jamesjwu jamesjwu changed the title from "Fix dynamo tracing into AOTAutogradCache results" to "Fix dynamo tracing into AOTAutogradCache results in cpu tensors" Jun 5, 2025
…pu tensors"


On this line, we see that the bw_compiler that dynamo uses for AotAutograd automatically disables the backward runnable: 
https://github.com/pytorch/pytorch/blob/05dd638ee98b36254c84095894c36fd0e7d95544/torch/_dynamo/backends/common.py#L76
This disables dynamo in the bw_compiler but also disables the runnable the compiler returns.

On a AOTAutogradCache hit, however, we never call the bw_compiler! So we don't disable dynamo properly. This only has an effect on certain cases of cpu tensors' backwards, where the backward is being done in python land, and dynamo unnecessarily tries to trace through the inductor generated code. It also only matters if the backward is being accessed outside of dynamo itself (say, in a graph break in eager mode), since dynamo properly disables the forward function already. 


```
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] TorchDynamo attempted to trace the following frames: [
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] ]

```

This PR fixes the issue and adds a unit test showing that with or without cache hit, the frames dynamo is tracing is identical. 



cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jamesjwu added a commit that referenced this pull request Jun 5, 2025
ghstack-source-id: 2908aa0
Pull Request resolved: #155251
```
@@ -589,6 +589,14 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoadable):
    def _is_backward(self) -> bool:
        return True

    def post_compile(
```
jamesjwu (Contributor, Author) commented:
I wanted to put this in GenericCompiledBackward, but the post_compile it's referencing in super() needs to refer to FxGraphCacheLoadable's post_compile, so I need to copy it twice for Bundled vs. non-Bundled. Multiple inheritance is confusing 😅
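For illustration, here is a toy sketch of the super()/MRO subtlety being described (hypothetical names, not the actual PyTorch hierarchy):

```python
class LoadableA:
    def post_compile(self) -> str:
        return "LoadableA.post_compile"

class LoadableB:
    def post_compile(self) -> str:
        return "LoadableB.post_compile"

class GenericBackward:
    # If post_compile lived here and called super().post_compile(), a type
    # checker would reject it: none of GenericBackward's own bases define
    # post_compile, even though the call could resolve at runtime through the
    # MRO of a concrete subclass.
    pass

class BackwardA(GenericBackward, LoadableA):
    def post_compile(self) -> str:
        # MRO: BackwardA -> GenericBackward -> LoadableA -> object,
        # so super() reaches LoadableA.post_compile here.
        return super().post_compile()

class BackwardB(GenericBackward, LoadableB):
    def post_compile(self) -> str:
        # MRO: BackwardB -> GenericBackward -> LoadableB -> object.
        return super().post_compile()

print(BackwardA().post_compile())  # LoadableA.post_compile
print(BackwardB().post_compile())  # LoadableB.post_compile
```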

```
compiled_bw = super().post_compile(result, fx_config)
# This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
# But since on cache hit we do not call the bw_compiler, we need to reapply the disable
return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass")  # type: ignore[return-value]
```
Contributor:

nit: one convention in the codebase I've found helpful when two different areas of the codebase need to be kept in sync is something like:

```
<file 1>
# Note [Wrapping bw_compiler in disable]
# <comment explaining why we wrap in disable>

<file 2>
# See Note [Wrapping bw_compiler in disable]
```

That way, if e.g. someone wants to tweak the disable behavior in the future, they can easily grep for the note name and find any other code locations that need to be kept in sync (random example where we do it here: https://github.com/pytorch/pytorch/blob/main/c10/core/DispatchKey.h#L133C16-L133C49).

@bdhirsh bdhirsh (Contributor) left a comment:

nice!

…nsors"


On this line, we see that the bw_compiler that dynamo uses for AotAutograd automatically disables the backward runnable: 
https://github.com/pytorch/pytorch/blob/05dd638ee98b36254c84095894c36fd0e7d95544/torch/_dynamo/backends/common.py#L76
This disables dynamo in the bw_compiler but also disables the runnable the compiler returns.

On a AOTAutogradCache hit, however, we never call the bw_compiler! So we don't disable dynamo properly. This only has an effect on certain cases of cpu tensors' backwards, where the backward is being done in python land, and dynamo unnecessarily tries to trace through the inductor generated code. It also only matters if the backward is being accessed outside of dynamo itself (say, in a graph break in eager mode), since dynamo properly disables the forward function already. 


```
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] TorchDynamo attempted to trace the following frames: [
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] ]

```

This PR fixes the issue and adds a unit test showing that with or without cache hit, the frames dynamo is tracing is identical. 

Fixes #154536


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jamesjwu added a commit that referenced this pull request Jun 5, 2025
ghstack-source-id: c8acaf0
Pull Request resolved: #155251
…tensors"


On this line, we see that the bw_compiler that dynamo uses for AotAutograd automatically disables the backward runnable: 
https://github.com/pytorch/pytorch/blob/05dd638ee98b36254c84095894c36fd0e7d95544/torch/_dynamo/backends/common.py#L76
This disables dynamo in the bw_compiler but also disables the runnable the compiler returns.

On a AOTAutogradCache hit, however, we never call the bw_compiler! So we don't disable dynamo properly. This only has an effect on certain cases of cpu tensors' backwards, where the backward is being done in python land, and dynamo unnecessarily tries to trace through the inductor generated code. It also only matters if the backward is being accessed outside of dynamo itself (say, in a graph break in eager mode), since dynamo properly disables the forward function already. 


```
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] TorchDynamo attempted to trace the following frames: [
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] ]

```

This PR fixes the issue and adds a unit test showing that with or without cache hit, the frames dynamo is tracing is identical. 

Fixes #154536


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jamesjwu added a commit that referenced this pull request Jun 5, 2025
ghstack-source-id: 767701a
Pull Request resolved: #155251
@jamesjwu jamesjwu added the `ciflow/trunk` label (Trigger trunk jobs on your pull request) Jun 6, 2025
…nsors"


On this line, we see that the bw_compiler that dynamo uses for AotAutograd automatically disables the backward runnable: 
https://github.com/pytorch/pytorch/blob/05dd638ee98b36254c84095894c36fd0e7d95544/torch/_dynamo/backends/common.py#L76
This disables dynamo in the bw_compiler but also disables the runnable the compiler returns.

On a AOTAutogradCache hit, however, we never call the bw_compiler! So we don't disable dynamo properly. This only has an effect on certain cases of cpu tensors' backwards, where the backward is being done in python land, and dynamo unnecessarily tries to trace through the inductor generated code. It also only matters if the backward is being accessed outside of dynamo itself (say, in a graph break in eager mode), since dynamo properly disables the forward function already. 


```
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] TorchDynamo attempted to trace the following frames: [
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517]   * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] ]

```

This PR fixes the issue and adds a unit test showing that with or without cache hit, the frames dynamo is tracing is identical. 

Fixes #154536


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames

[ghstack-poisoned]
jamesjwu added a commit that referenced this pull request Jun 6, 2025
ghstack-source-id: 57ec57c
Pull Request resolved: #155251
@jamesjwu (Contributor, Author) commented Jun 6, 2025

Rebased, let's see if I can get clean tests

@jamesjwu (Contributor, Author) commented Jun 9, 2025

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

thatgeeman pushed a commit to thatgeeman/pytorch-docathon that referenced this pull request Jun 15, 2025
Pull Request resolved: pytorch#155251
Approved by: https://github.com/bdhirsh, https://github.com/anijain2305