
Commit 3e82808

jamesjwu authored and thatgeeman committed
Fix dynamo tracing into AOTAutogradCache results in cpu tensors (pytorch#155251)
On this line, we see that the bw_compiler that dynamo uses for AOTAutograd automatically disables the backward runnable: https://github.com/pytorch/pytorch/blob/05dd638ee98b36254c84095894c36fd0e7d95544/torch/_dynamo/backends/common.py#L76

This keeps dynamo out of the bw_compiler itself, but it also disables the compiled runnable the compiler returns. On an AOTAutogradCache hit, however, we never call the bw_compiler, so that disable is never applied. This only has an effect in certain cases of CPU tensors' backwards, where the backward runs in Python land and dynamo unnecessarily tries to trace through the inductor-generated code. It also only matters if the backward is invoked outside of dynamo itself (say, in a graph break in eager mode), since dynamo already disables the forward function properly.

```
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] TorchDynamo attempted to trace the following frames: [
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * fn /home/jjwu/test.py:9
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * cast /data/users/jjwu/a/pytorch-env/lib/python3.10/typing.py:1737
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] * call /tmp/torchinductor_jjwu/rq/crq327nhoyjzog5n3qlchauucdrunrtutwmmoh7ipoe2ngnson5s.py:35
I0605 09:58:40.135000 3981970 torch/_dynamo/eval_frame.py:517] ]
```

This PR fixes the issue and adds a unit test showing that, with or without a cache hit, the frames dynamo traces are identical.

Fixes pytorch#154536

Pull Request resolved: pytorch#155251
Approved by: https://github.com/bdhirsh, https://github.com/anijain2305
1 parent c2b27ab commit 3e82808
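
To make the failure mode concrete, here is a minimal repro sketch. It is hypothetical (not the author's /home/jjwu/test.py from the log above), and the two config toggles simply mirror the flags patched in the new test below.

```python
import torch
import torch._dynamo
import torch._functorch.config as functorch_config
import torch._inductor.config as inductor_config

# Enable the caches involved in the bug (mirrors the test's config patches).
functorch_config.enable_autograd_cache = True
inductor_config.fx_graph_cache = True

@torch.compile
def fn(x):
    # Runs the backward during the forward, i.e. eagerly / outside dynamo,
    # which is the path where the unnecessary tracing showed up.
    (x_grad,) = torch.autograd.grad(x.sum(), x)
    return x_grad

a = torch.randn(10, 10, requires_grad=True, device="cpu")
fn(a)                  # cache miss: bw_compiler runs and wraps the backward in disable
torch._dynamo.reset()  # fresh dynamo state; the on-disk caches survive
fn(a)                  # cache hit: before this fix the loaded backward was not disabled,
                       # so dynamo tried to trace the inductor-generated `call` frame
```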

File tree

3 files changed: +60 −2 lines

test/dynamo/test_aot_autograd_cache.py
torch/_dynamo/backends/common.py
torch/_functorch/_aot_autograd/autograd_cache.py

test/dynamo/test_aot_autograd_cache.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 
+import copy
 import os
 import shutil
 import unittest
@@ -822,6 +823,44 @@ def fn(a, b):
         self.assertEqual(a.grad, a2.grad)
         self.assertEqual(b.grad, b2.grad)
 
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch({"fx_graph_cache": True})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @functorch_config.patch({"strict_autograd_cache": True})
+    def test_autograd_no_dynamo_trace_backward(self):
+        """
+        Test that dynamo does not trace into the backward compiled function,
+        even on cache hit.
+        """
+        torch._dynamo.eval_frame.clear_dynamo_tls()
+
+        @torch.compile
+        def fn(x):
+            # Calls x.sum().backward() during forward execution of fn
+            (x_grad,) = torch.autograd.grad(x.sum(), x)
+            return x_grad
+
+        a = torch.randn(10, 10, requires_grad=True, device="cpu")
+        result = fn(a)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        # Backward of `sum` will run during execution of graph break
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+        traced_frame_infos = copy.deepcopy(
+            torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos
+        )
+
+        torch._dynamo.reset()
+        torch._dynamo.eval_frame.clear_dynamo_tls()
+        result2 = fn(a)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+        new_traced_frame_infos = torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos
+        self.assertEqual(result, result2)
+        # Dynamo should trace exactly the same frames on cache hit
+        self.assertEqual(traced_frame_infos, new_traced_frame_infos)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", True)
     @functorch_config.patch({"enable_autograd_cache": True})
```

torch/_dynamo/backends/common.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -68,7 +68,10 @@ def __call__(self, gm: torch.fx.GraphModule, example_inputs, **kwargs):
 
 
 def wrap_bw_compiler(bw_compiler_fn):
     def _wrapped_bw_compiler(*args, **kwargs):
-        # stop TorchDynamo from trying to compile our generated backwards pass
+        # Note [Wrapping bw_compiler in disable]
+        # The two disables here:
+        # - stop TorchDynamo from trying to compile the bw_compiler function itself
+        # - stop TorchDynamo from trying to compile the generated backwards pass bw_compiler produces
         return disable(
             disable(
                 bw_compiler_fn, reason="do not trace backward compiler function"
```
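
As a reading aid, here is a self-contained sketch of what the two disables buy (simplified; not the actual torch/_dynamo/backends/common.py): the inner disable keeps dynamo from tracing the compiler function itself when the autograd engine invokes it lazily, and the outer disable keeps dynamo from tracing the compiled backward callable it returns.

```python
import torch
from torch._dynamo import disable

def wrap_bw_compiler_sketch(bw_compiler_fn):
    # Simplified stand-in for _wrapped_bw_compiler above.
    def _wrapped_bw_compiler(*args, **kwargs):
        # Inner disable: do not trace bw_compiler_fn while it compiles the
        # backward graph (this happens lazily, on the first backward call).
        compiled_bw = disable(
            bw_compiler_fn, reason="do not trace backward compiler function"
        )(*args, **kwargs)
        # Outer disable: do not trace the compiled backward either, e.g. the
        # inductor-generated Python code that runs CPU backwards.
        return disable(
            compiled_bw, reason="do not trace generated backwards pass"
        )

    return _wrapped_bw_compiler
```

On a cache hit the compiler never runs, so it is exactly the outer wrapper that goes missing; reapplying it is what the autograd_cache.py change below does.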

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 17 additions & 1 deletion
```diff
@@ -589,6 +589,15 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoa
     def _is_backward(self) -> bool:
         return True
 
+    def post_compile(
+        self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
+    ) -> CompiledFxGraph:
+        compiled_bw = super().post_compile(result, fx_config)
+        # See note [Wrapping bw_compiler in disable]
+        # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
+        # But since on cache hit we do not call the bw_compiler, we need to reapply the disable
+        return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass")  # type: ignore[return-value]
+
 
 # Forward types don't have any extra parameters, so this is just a TypeAlias, in essence
 class BundledCompiledForward(CompiledFxGraphLoadable):
@@ -599,7 +608,14 @@ class BundledCompiledForward(CompiledFxGraphLoadable):
 class BundledCompiledBackward(
     GenericCompiledBackward[CompiledFxGraph], CompiledFxGraphLoadable
 ):
-    pass
+    def post_compile(
+        self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
+    ) -> CompiledFxGraph:
+        compiled_bw = super().post_compile(result, fx_config)
+        # See note [Wrapping bw_compiler in disable]
+        # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
+        # But since on cache hit we do not call the bw_compiler, we need to reapply the disable
+        return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass")  # type: ignore[return-value]
 
 
 TForward = TypeVar("TForward", bound=InductorOutput)
```
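
Finally, a small hedged sketch of what the reapplied wrapper does to the object loaded from cache (compiled_backward_stub is a made-up stand-in for the inductor-generated backward): the callable behaves identically, dynamo just skips its frame instead of attempting to trace it, which is what the log in the commit message showed it doing before the fix.

```python
import torch
import torch._dynamo

def compiled_backward_stub(grad_out):
    # Hypothetical stand-in for an inductor-generated backward `call`.
    return grad_out * 2

# What post_compile now does to the artifact it loads on a cache hit:
guarded_bw = torch._dynamo.disable(
    compiled_backward_stub, reason="do not trace generated backwards pass"
)

g = torch.ones(4)
assert torch.equal(guarded_bw(g), compiled_backward_stub(g))  # same numerics
# The only difference is that dynamo now skips this frame rather than
# tracing the generated code.
```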
