Commit 48bd6d7

Expose API to specify custom context manager for checkpoint

ghstack-source-id: 7bf8d4b
Pull Request resolved: #96783

1 parent 4c39e8e
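
In short, with use_reentrant=False the checkpoint API now accepts a context_fn keyword: a callable returning a pair of context managers, the first entered around the original forward pass and the second around its recomputation during backward. Below is a minimal sketch of the call shape, distilled from the diff below; the announce() helper is hypothetical and exists only to show when each context is entered:

    import contextlib

    import torch
    from torch.utils.checkpoint import checkpoint

    @contextlib.contextmanager
    def announce(label):
        # Hypothetical helper: prints when the wrapped region runs.
        print(f"enter {label}")
        try:
            yield
        finally:
            print(f"exit {label}")

    def context_fn():
        # First manager wraps the forward pass, second wraps the recomputation.
        return announce("forward"), announce("recompute")

    x = torch.randn(4, requires_grad=True)
    out = checkpoint(torch.sin, x, use_reentrant=False, context_fn=context_fn)
    out.sum().backward()  # "enter recompute" prints here, when activations are rebuilt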

File tree

2 files changed: +68 -6 lines

  test/test_autograd.py
  torch/utils/checkpoint.py

test/test_autograd.py

Lines changed: 33 additions & 0 deletions
@@ -5608,6 +5608,39 @@ def foo(x, y, z):
         out = checkpoint(foo, x, y, z, use_reentrant=False)
         out.sum().backward()

+    def test_checkpointing_without_reentrant_with_context_fn(self):
+        class VerboseTorchDispatchMode(TorchDispatchMode):
+            def __init__(self):
+                self.operators = []
+
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                self.operators.append(func.__name__)
+                return func(*args, **kwargs)
+
+        x = torch.tensor(1., requires_grad=True)
+        verbose_mode = VerboseTorchDispatchMode()
+
+        def context_fn():
+            return verbose_mode, contextlib.nullcontext()
+        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
+        self.assertEqual(verbose_mode.operators, ['sin.default'])
+
+        verbose_mode.operators = []
+
+        def context_fn():
+            return contextlib.nullcontext(), verbose_mode
+        out = checkpoint(lambda x: x.sin(), x, use_reentrant=False, context_fn=context_fn)
+        out.backward()
+        self.assertEqual(
+            verbose_mode.operators,
+            ['detach.default', 'detach.default', 'detach.default', 'detach.default', 'sin.default']
+        )
+
+        with self.assertRaisesRegex(Exception, "only supported when use_reentrant=False"):
+            out = checkpoint(lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn)
+
     def test_access_saved_tensor_twice_without_recomputation_works(self):
         count = [0]
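
Outside the test harness, the same behavior can be reproduced with a short standalone script. This is a sketch assuming TorchDispatchMode is imported from torch.utils._python_dispatch (as the test suite does); the expected operator lists mirror the assertions above:

    import contextlib

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode
    from torch.utils.checkpoint import checkpoint

    class VerboseTorchDispatchMode(TorchDispatchMode):
        # Records the name of every ATen op dispatched while the mode is active.
        def __init__(self):
            self.operators = []

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.operators.append(func.__name__)
            return func(*args, **(kwargs or {}))

    x = torch.tensor(1., requires_grad=True)
    verbose_mode = VerboseTorchDispatchMode()

    # Attach the mode to the forward pass only.
    out = checkpoint(lambda t: t.sin(), x, use_reentrant=False,
                     context_fn=lambda: (verbose_mode, contextlib.nullcontext()))
    print(verbose_mode.operators)  # per the test above: ['sin.default']

    # Attach the mode to the recomputation only; it is entered during backward().
    verbose_mode.operators = []
    out = checkpoint(lambda t: t.sin(), x, use_reentrant=False,
                     context_fn=lambda: (contextlib.nullcontext(), verbose_mode))
    out.backward()
    print(verbose_mode.operators)  # per the test: four detach.default entries, then 'sin.default'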

torch/utils/checkpoint.py

Lines changed: 35 additions & 6 deletions
@@ -2,15 +2,15 @@
 import warnings
 import weakref
 from weakref import ReferenceType
-from typing import Any, Iterable, List, Tuple, Dict, Optional, DefaultDict
+from typing import Any, Callable, ContextManager, Iterable, List, Tuple, Dict, Optional, DefaultDict
 from collections import defaultdict
 import uuid
 import contextlib

 __all__ = [
     "checkpoint", "checkpoint_sequential", "CheckpointFunction",
     "check_backward_validity", "detach_variable", "get_device_states",
-    "set_device_states",
+    "set_device_states", "noop_context_fn"
 ]

 def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -165,7 +165,17 @@ def backward(ctx, *args):
         return (None, None) + grads


-def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
+def noop_context_fn():
+    return contextlib.nullcontext(), contextlib.nullcontext()
+
+
+def checkpoint(
+    function,
+    *args,
+    use_reentrant: bool = True,
+    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+    **kwargs
+):
     r"""Checkpoint a model or part of the model

     Checkpointing works by trading compute for memory. Rather than storing all
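
The hunk above also defines noop_context_fn, the default value of context_fn, and exports it via __all__. A quick sketch of what that default means in practice, assuming the import path shown:

    import torch
    from torch.utils.checkpoint import checkpoint, noop_context_fn

    x = torch.randn(3, requires_grad=True)

    # Omitting context_fn and passing noop_context_fn explicitly should be equivalent:
    # both wrap the forward pass and the recomputation in contextlib.nullcontext().
    a = checkpoint(torch.sin, x, use_reentrant=False)
    b = checkpoint(torch.sin, x, use_reentrant=False, context_fn=noop_context_fn)
    print(torch.equal(a, b))  # expected: True

Note that the guard added further down compares by identity (context_fn is not noop_context_fn), so explicitly passing noop_context_fn remains allowed even with use_reentrant=True.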
@@ -239,6 +249,10 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
             keyword arguments input into the checkpointed function. Note that future
             versions of PyTorch will default to ``use_reentrant=False``.
             Default: ``True``
+        context_fn(Callable, optional): A callable returning a tuple of two
+            context managers. The function and its recomputation will be run
+            under the first and second context managers respectively.
+            This argument is only supported if ``use_reentrant=False``.
         args: tuple containing inputs to the :attr:`function`

     Returns:
@@ -250,11 +264,14 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
         raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

     if use_reentrant:
+        if context_fn is not noop_context_fn:
+            raise ValueError("Passing context_fn is only supported when use_reentrant=False.")
         return CheckpointFunction.apply(function, preserve, *args)
     else:
         return _checkpoint_without_reentrant(
             function,
             preserve,
+            context_fn,
             *args,
             **kwargs,
         )
@@ -626,7 +643,13 @@ def unpack_hook(holder):

 # NB: this helper wraps fn before calling checkpoint_impl. kwargs and
 # saving/restoring of global state is handled here.
-def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
+def _checkpoint_without_reentrant(
+    fn,
+    preserve_rng_state=True,
+    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+    *args,
+    **kwargs
+):
     """Checkpointining without re-entrant autograd
     Args:
         function: describes what to run in the forward pass of the model or
@@ -637,9 +660,13 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
         preserve_rng_state(bool, optional): Omit stashing and restoring
             the RNG state during each checkpoint.
             Default: ``True``
+        context_fn(Callable, optional): A callable returning a tuple of two
+            context managers. The function and its recomputation will be run
+            under the first and second context managers respectively.
         *args: Arguments to pass in to the given ``function``.
         **kwargs: Keyword arguments to pass into the given ``function``.
     """
+    forward_context, recompute_context = context_fn()
     # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
     gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

@@ -669,7 +696,8 @@ def recompute_fn(*inputs):
                     set_device_states(fwd_gpu_devices, fwd_gpu_states)

             with torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
-                 torch.cpu.amp.autocast(**cpu_autocast_kwargs):
+                 torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
+                 recompute_context:
                 fn(*args, **kwargs)

     new_frame = _CheckpointFrame(recompute_fn)
@@ -680,7 +708,8 @@ def recompute_fn(*inputs):
     if new_frame.input_saver.grad_fn is None:
         return fn(*args, **kwargs)

-    with _checkpoint_hook(new_frame):
+    with _checkpoint_hook(new_frame), \
+         forward_context:
         ret = fn(*args, **kwargs)

     if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
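
Taken together, forward_context wraps the original forward call and recompute_context wraps the recomputation, as the two with statements above show. One practical use is labeling the two regions for the profiler. Here is a sketch using torch.autograd.profiler.record_function, assuming a single backward pass so that each context manager returned by context_fn is entered only once:

    import torch
    from torch.autograd.profiler import record_function
    from torch.profiler import profile
    from torch.utils.checkpoint import checkpoint

    def context_fn():
        # Fresh context managers on every call: the first labels the forward pass,
        # the second labels the recomputation triggered during backward.
        return record_function("ckpt_forward"), record_function("ckpt_recompute")

    x = torch.randn(128, 128, requires_grad=True)
    with profile() as prof:
        out = checkpoint(lambda t: t.relu().sin(), x, use_reentrant=False, context_fn=context_fn)
        out.sum().backward()

    # Both labels should appear as user annotations in the collected trace.
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))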
