Commit 452b98e

Fix dynamo tracing into AOTAutogradCache results
ghstack-source-id: c8acaf0
Pull Request resolved: #155251
1 parent cadcb5d commit 452b98e
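
In plain terms: when a torch.compile'd function computes gradients during its own forward (as in the test added below), the compiled backward can be served from AOTAutogradCache on a warm run. Before this fix, Dynamo could trace into that cached backward, because the torch._dynamo.disable wrapper is normally applied by the backward compiler, which is skipped on a cache hit. A minimal repro shape, mirroring the new test (the tensor shape and names here are illustrative, not part of the commit):

    import torch

    @torch.compile
    def fn(x):
        # Computing a gradient inside the compiled function makes the
        # compiled backward run during fn's execution.
        (x_grad,) = torch.autograd.grad(x.sum(), x)
        return x_grad

    a = torch.randn(10, 10, requires_grad=True)
    fn(a)                  # cold run: compiles and populates the cache
    torch._dynamo.reset()
    fn(a)                  # warm run: cache hit; Dynamo should not trace the cached backward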

File tree

2 files changed: +54, -1 lines changed


test/dynamo/test_aot_autograd_cache.py

Lines changed: 39 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 
+import copy
 import os
 import shutil
 import unittest
@@ -822,6 +823,44 @@ def fn(a, b):
         self.assertEqual(a.grad, a2.grad)
         self.assertEqual(b.grad, b2.grad)
 
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch({"fx_graph_cache": True})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @functorch_config.patch({"strict_autograd_cache": True})
+    def test_autograd_no_dynamo_trace_backward(self):
+        """
+        Test that dynamo does not trace into the backward compiled function,
+        even on cache hit.
+        """
+        torch._dynamo.eval_frame.clear_dynamo_tls()
+
+        @torch.compile
+        def fn(x):
+            # Calls x.sum().backward() during forward execution of fn
+            (x_grad,) = torch.autograd.grad(x.sum(), x)
+            return x_grad
+
+        a = torch.randn(10, 10, requires_grad=True, device="cpu")
+        result = fn(a)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        # Backward of `sum` will run during execution of graph break
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+        traced_frame_infos = copy.deepcopy(
+            torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos
+        )
+
+        torch._dynamo.reset()
+        torch._dynamo.eval_frame.clear_dynamo_tls()
+        result2 = fn(a)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+        new_traced_frame_infos = torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos
+        self.assertEqual(result, result2)
+        # Dynamo should trace exactly the same frames on cache hit
+        self.assertEqual(traced_frame_infos, new_traced_frame_infos)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", True)
     @functorch_config.patch({"enable_autograd_cache": True})

torch/_functorch/_aot_autograd/autograd_cache.py

Lines changed: 15 additions & 1 deletion
@@ -589,6 +589,14 @@ class CompiledBackward(GenericCompiledBackward[CompiledFxGraph], FxGraphCacheLoa
     def _is_backward(self) -> bool:
         return True
 
+    def post_compile(
+        self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
+    ) -> CompiledFxGraph:
+        compiled_bw = super().post_compile(result, fx_config)
+        # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
+        # But since on cache hit we do not call the bw_compiler, we need to reapply the disable
+        return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass")  # type: ignore[return-value]
+
 
 # Forward types don't have any extra parameters, so this is just a TypeAlias, in essence
 class BundledCompiledForward(CompiledFxGraphLoadable):
@@ -599,7 +607,13 @@ class BundledCompiledForward(CompiledFxGraphLoadable):
 class BundledCompiledBackward(
     GenericCompiledBackward[CompiledFxGraph], CompiledFxGraphLoadable
 ):
-    pass
+    def post_compile(
+        self, result: CompiledFxGraph, fx_config: _CompileFxKwargs
+    ) -> CompiledFxGraph:
+        compiled_bw = super().post_compile(result, fx_config)
+        # This is done by _wrapped_bw_compiler in torch/_dynamo/backends/common.py
+        # But since on cache hit we do not call the bw_compiler, we need to reapply the disable
+        return torch._dynamo.disable(compiled_bw, reason="do not trace generated backwards pass")  # type: ignore[return-value]
 
 
 TForward = TypeVar("TForward", bound=InductorOutput)
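
For readers following the inline comments: on the normal (cache-miss) compile path, the backward compiler handed to AOTAutograd is wrapped by _wrapped_bw_compiler in torch/_dynamo/backends/common.py, so its output already carries the disable. The sketch below is a simplified approximation of that wrapping for illustration only, not the actual helper; only torch._dynamo.disable and the reason string are taken from the diff above.

    import torch

    def wrap_bw_compiler_sketch(bw_compiler):
        # Simplified approximation: compile the backward graph, then mark the
        # result so Dynamo never traces into it.
        def wrapped(gm, example_inputs):
            compiled_bw = bw_compiler(gm, example_inputs)
            return torch._dynamo.disable(
                compiled_bw, reason="do not trace generated backwards pass"
            )
        return wrapped

    # On an AOTAutogradCache hit, bw_compiler (and hence this wrapping) never
    # runs, which is why post_compile above reapplies the disable.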
