Skip to content

Commit 85cdf47

Browse files
williamwen42jhavukainen
authored andcommitted
[dynamo 3.11] changes to with contexts (pytorch#94101)
Pull Request resolved: pytorch#94101 Approved by: https://github.com/albanD, https://github.com/jansel
1 parent 3b733a7 commit 85cdf47

File tree

2 files changed

+47
-14
lines changed

2 files changed

+47
-14
lines changed

torch/_dynamo/resume_execution.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,7 @@ def __call__(self, code_options, cleanup):
9797
]
9898

9999
else:
100-
# NOTE: copying over for now since more changes are anticipated
101-
with_except_start = create_instruction("WITH_EXCEPT_START")
102100
pop_top_after_with_except_start = create_instruction("POP_TOP")
103-
104101
cleanup_complete_jump_target = create_instruction("NOP")
105102

106103
def create_load_none():
@@ -110,7 +107,6 @@ def create_load_none():
110107

111108
cleanup[:] = (
112109
[
113-
create_instruction("POP_BLOCK"),
114110
create_load_none(),
115111
create_load_none(),
116112
create_load_none(),
@@ -121,24 +117,27 @@ def create_load_none():
121117
create_instruction(
122118
"JUMP_FORWARD", target=cleanup_complete_jump_target
123119
),
124-
with_except_start,
120+
create_instruction("PUSH_EXC_INFO"),
121+
create_instruction("WITH_EXCEPT_START"),
125122
create_instruction(
126123
"POP_JUMP_FORWARD_IF_TRUE",
127124
target=pop_top_after_with_except_start,
128125
),
129-
create_instruction("RERAISE"),
126+
create_instruction("RERAISE", 2),
127+
create_instruction("COPY", 3),
128+
create_instruction("POP_EXCEPT"),
129+
create_instruction("RERAISE", 1),
130130
pop_top_after_with_except_start,
131-
create_instruction("POP_TOP"),
132-
create_instruction("POP_TOP"),
133131
create_instruction("POP_EXCEPT"),
134132
create_instruction("POP_TOP"),
133+
create_instruction("POP_TOP"),
135134
cleanup_complete_jump_target,
136135
]
137136
+ cleanup
138137
)
139138

140139
return create_call_function(0, False) + [
141-
create_instruction("SETUP_WITH", target=with_except_start),
140+
create_instruction("BEFORE_WITH"),
142141
create_instruction("POP_TOP"),
143142
]
144143

torch/_dynamo/symbolic_convert.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def _step_logger():
102102

103103
@dataclasses.dataclass
104104
class BlockStackEntry:
105+
id: int
105106
target: Instruction
106107
stack_index: Optional[int] = None
107108
with_context: ContextWrappingVariable = None
@@ -878,11 +879,11 @@ def jump(self, inst):
878879

879880
def SETUP_LOOP(self, inst):
880881
# only exists in python<=3.7
881-
self.block_stack.append(BlockStackEntry(inst.target))
882+
self.block_stack.append(BlockStackEntry(0, inst.target))
882883

883884
def SETUP_EXCEPT(self, inst):
884885
# only exists in python<=3.7
885-
self.block_stack.append(BlockStackEntry(inst.target))
886+
self.block_stack.append(BlockStackEntry(0, inst.target))
886887

887888
def POP_BLOCK(self, inst):
888889
self.block_stack.pop()
@@ -894,10 +895,12 @@ def SETUP_WITH(self, inst):
894895
self.output.guards.update(ctx.guards)
895896

896897
if isinstance(self, InstructionTranslator):
897-
self.block_stack.append(BlockStackEntry(inst.target, len(self.stack), ctx))
898+
self.block_stack.append(
899+
BlockStackEntry(0, inst.target, len(self.stack), ctx)
900+
)
898901
else:
899902
# can't restore this while inlining
900-
self.block_stack.append(BlockStackEntry(inst.target))
903+
self.block_stack.append(BlockStackEntry(0, inst.target))
901904
self.push(
902905
WithExitFunctionVariable(
903906
ctx,
@@ -908,7 +911,7 @@ def SETUP_WITH(self, inst):
908911
self.push(ctx.enter(self))
909912

910913
def SETUP_FINALLY(self, inst):
911-
self.block_stack.append(BlockStackEntry(inst.target))
914+
self.block_stack.append(BlockStackEntry(0, inst.target))
912915

913916
def BEGIN_FINALLY(self, inst):
914917
self.push(None)
@@ -1569,6 +1572,13 @@ def CALL(self, inst):
15691572
kwargs = {}
15701573
self.call_function(fn, args, kwargs)
15711574
self.kw_names = None
1575+
# 3.11 removed POP_BLOCK, so we manually pop the block stack here
1576+
if (
1577+
isinstance(fn, WithExitFunctionVariable)
1578+
and len(self.block_stack) > 0
1579+
and id(fn) == self.block_stack[-1].id
1580+
):
1581+
self.block_stack.pop()
15721582

15731583
def COPY(self, inst):
15741584
self.push(self.stack[-inst.arg])
@@ -1592,6 +1602,30 @@ def SWAP(self, inst):
15921602
def CACHE(self, inst):
15931603
pass
15941604

1605+
def BEFORE_WITH(self, inst):
1606+
ctx = self.pop()
1607+
if not isinstance(ctx, ContextWrappingVariable):
1608+
unimplemented(f"BEFORE_WITH {ctx}")
1609+
self.output.guards.update(ctx.guards)
1610+
1611+
exit = WithExitFunctionVariable(
1612+
ctx,
1613+
inst.target,
1614+
**VariableTracker.propagate(ctx),
1615+
)
1616+
# 3.11 no longer uses a block stack, but we still keep track of one
1617+
# so that we know which contexts are currently active.
1618+
if isinstance(self, InstructionTranslator):
1619+
self.block_stack.append(
1620+
BlockStackEntry(id(exit), inst.target, self.real_stack_len(), ctx)
1621+
)
1622+
else:
1623+
# can't restore this while inlining
1624+
self.block_stack.append(BlockStackEntry(id(exit), inst.target))
1625+
1626+
self.push(exit)
1627+
self.push(ctx.enter(self))
1628+
15951629
def copy_graphstate(self) -> InstructionTranslatorGraphState:
15961630
"""Create a checkpoint of the current state by copying everything"""
15971631
return InstructionTranslatorGraphState(

0 commit comments

Comments
 (0)