
Commit e6eab49

jon-chuang authored and pytorchmergebot committed
[dynamo] graph break on setattr requires_grad (pytorch#113163)
Main: `RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn`

This PR: graph breaks and eager applies the mutation, new tensors are tracked

Fixes pytorch#109505 (the original bug does not occur, but a new bug where the mutation isn't applied - because AOTAutograd is not `requires_grad` mutation aware - is mitigated)

Pull Request resolved: pytorch#113163
Approved by: https://github.com/bdhirsh
1 parent 8c704f7 commit e6eab49
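For context, a minimal sketch of the user-facing behavior described above, modeled on the new test added below (the zeros input and the "eager" backend are illustrative choices, not part of the commit):

import torch

def fn(x):
    z = x + 4
    x.requires_grad = True  # setattr on requires_grad inside a compiled region
    y = x * z
    return y

# On main, compiling fn could raise:
#   RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
# With this PR, Dynamo graph breaks at the setattr, the mutation is applied in eager,
# and the resulting requires_grad tensor is tracked, so backward matches eager mode.
opt_fn = torch.compile(fn, backend="eager")
out = opt_fn(torch.zeros(5))
out.sum().backward()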

File tree

2 files changed: +34 -1 lines changed


test/dynamo/test_repros.py

Lines changed: 27 additions & 0 deletions
@@ -3497,6 +3497,33 @@ def test_addr_alpha_beta_out(self):
         compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
         self.assertTrue(same(out, compile_out))
 
+    def test_setattr_requires_grad_graph_breaks(self):
+        def fn(x):
+            z = x + 4
+            x.requires_grad = True
+            y = x * z
+            return y
+
+        for backend in ["count", "eager", "aot_eager"]:
+            if backend == "count":
+                backend = CompileCounter()
+            opt_fn = torch.compile(fn, backend=backend)
+
+            eager = torch.zeros(5)
+            compiled = eager.clone()
+
+            out_eager = fn(eager)
+            out_opt = opt_fn(compiled)
+
+            self.assertEqual(out_eager, out_opt)
+
+            out_eager.sum().backward()
+            out_opt.sum().backward()
+
+            self.assertEqual(eager, compiled)
+            if isinstance(backend, CompileCounter):
+                self.assertEqual(backend.frame_count, 2)  # graph breaks
+
     def test_inductor_no_recursionerror_on_for_loops(self):
         def forward(x):
             for _ in range(1000):

torch/_dynamo/variables/builtin.py

Lines changed: 7 additions & 1 deletion
@@ -1199,7 +1199,13 @@ def call_setattr(
             tx.output.side_effects.is_attribute_mutation(obj)
             and name_var.is_python_constant()
         ):
-            tx.output.side_effects.store_attr(obj, name_var.as_python_constant(), val)
+            name = name_var.as_python_constant()
+            if name == "requires_grad" and isinstance(obj, variables.TensorVariable):
+                unimplemented(
+                    "mutating requires_grad can introduce a new leaf from non-leaf or vice versa in "
+                    "the middle of the graph, which aot_autograd does not currently know how to handle. "
+                )
+            tx.output.side_effects.store_attr(obj, name, val)
             return val
         elif isinstance(obj, variables.UserDefinedObjectVariable):
             unimplemented(
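As a usage note: the `unimplemented()` call added here raises Dynamo's Unsupported exception, which results in a graph break rather than a hard error, so the `requires_grad` mutation runs in eager. A small sketch of how the break can be observed, mirroring the new test above (the input shape is illustrative):

import torch
from torch._dynamo.testing import CompileCounter

def fn(x):
    z = x + 4
    x.requires_grad = True  # hits the unimplemented() path above -> graph break
    return x * z

counter = CompileCounter()  # counting backend also used by the new test
opt_fn = torch.compile(fn, backend=counter)
opt_fn(torch.zeros(5))
print(counter.frame_count)  # 2: one frame before the setattr, one after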
