@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 
+import copy
 import os
 import shutil
 import unittest
@@ -822,6 +823,44 @@ def fn(a, b):
         self.assertEqual(a.grad, a2.grad)
         self.assertEqual(b.grad, b2.grad)
 
+    @inductor_config.patch("fx_graph_remote_cache", False)
+    @inductor_config.patch({"fx_graph_cache": True})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @functorch_config.patch({"strict_autograd_cache": True})
+    def test_autograd_no_dynamo_trace_backward(self):
+        """
+        Test that dynamo does not trace into the backward compiled function,
+        even on cache hit.
+        """
+        torch._dynamo.eval_frame.clear_dynamo_tls()
+
+        @torch.compile
+        def fn(x):
+            # Calls x.sum().backward() during forward execution of fn
+            (x_grad,) = torch.autograd.grad(x.sum(), x)
+            return x_grad
+
+        a = torch.randn(10, 10, requires_grad=True, device="cpu")
+        result = fn(a)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
+        # Backward of `sum` will run during execution of graph break
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+        traced_frame_infos = copy.deepcopy(
+            torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos
+        )
+
+        torch._dynamo.reset()
+        torch._dynamo.eval_frame.clear_dynamo_tls()
+        result2 = fn(a)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
+        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
+        new_traced_frame_infos = torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos
+        self.assertEqual(result, result2)
+        # Dynamo should trace exactly the same frames on cache hit
+        self.assertEqual(traced_frame_infos, new_traced_frame_infos)
+
     @inductor_config.patch("fx_graph_remote_cache", False)
     @inductor_config.patch("fx_graph_cache", True)
     @functorch_config.patch({"enable_autograd_cache": True})