import warnings
import weakref
from weakref import ReferenceType
- from typing import Any, Iterable, List, Tuple, Dict, Optional, DefaultDict
+ from typing import Any, Callable, ContextManager, Iterable, List, Tuple, Dict, Optional, DefaultDict
from collections import defaultdict
import uuid
import contextlib

__all__ = [
    "checkpoint", "checkpoint_sequential", "CheckpointFunction",
    "check_backward_validity", "detach_variable", "get_device_states",
-     "set_device_states",
+     "set_device_states", "noop_context_fn"
]

def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
@@ -165,7 +165,17 @@ def backward(ctx, *args):
        return (None, None) + grads


- def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
+ def noop_context_fn():
+     return contextlib.nullcontext(), contextlib.nullcontext()
+
+
+ def checkpoint(
+     function,
+     *args,
+     use_reentrant: bool = True,
+     context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+     **kwargs
+ ):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
@@ -239,6 +249,10 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
            keyword arguments input into the checkpointed function. Note that future
            versions of PyTorch will default to ``use_reentrant=False``.
            Default: ``True``
+         context_fn(Callable, optional): A callable returning a tuple of two
+             context managers. The function and its recomputation will be run
+             under the first and second context managers respectively.
+             This argument is only supported if ``use_reentrant=False``.
        args: tuple containing inputs to the :attr:`function`

    Returns:
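Not part of the diff, just for orientation: a minimal sketch of how the new ``context_fn`` argument could be used once this patch is applied. The ``_log_phase`` helper and ``my_context_fn`` are invented for the example; only ``checkpoint``, ``use_reentrant`` and ``context_fn`` come from this change.

```python
import contextlib
import torch
from torch.utils.checkpoint import checkpoint

@contextlib.contextmanager
def _log_phase(phase):
    # Hypothetical helper: report which phase is running.
    print(f"entering {phase}")
    try:
        yield
    finally:
        print(f"leaving {phase}")

def my_context_fn():
    # First manager wraps the original forward, the second wraps recomputation.
    return _log_phase("forward"), _log_phase("recompute")

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(layer, x, use_reentrant=False, context_fn=my_context_fn)
out.sum().backward()  # recomputation should run under the "recompute" manager
```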
@@ -250,11 +264,14 @@ def checkpoint(function, *args, use_reentrant: bool = True, **kwargs):
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    if use_reentrant:
+         if context_fn is not noop_context_fn:
+             raise ValueError("Passing context_fn is only supported when use_reentrant=False.")
        return CheckpointFunction.apply(function, preserve, *args)
    else:
        return _checkpoint_without_reentrant(
            function,
            preserve,
+             context_fn,
            *args,
            **kwargs,
        )
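With the guard added above, combining the new argument with the re-entrant path should fail fast. A short sketch, reusing ``my_context_fn`` from the earlier example:

```python
try:
    checkpoint(layer, x, use_reentrant=True, context_fn=my_context_fn)
except ValueError as err:
    print(err)  # Passing context_fn is only supported when use_reentrant=False.
```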
@@ -626,7 +643,13 @@ def unpack_hook(holder):

# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
# saving/restoring of global state is handled here.
- def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
+ def _checkpoint_without_reentrant(
+     fn,
+     preserve_rng_state=True,
+     context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
+     *args,
+     **kwargs
+ ):
    """Checkpointing without re-entrant autograd
    Args:
        function: describes what to run in the forward pass of the model or
@@ -637,9 +660,13 @@ def _checkpoint_without_reentrant(fn, preserve_rng_state=True, *args, **kwargs):
        preserve_rng_state(bool, optional): Omit stashing and restoring
            the RNG state during each checkpoint.
            Default: ``True``
+         context_fn(Callable, optional): A callable returning a tuple of two
+             context managers. The function and its recomputation will be run
+             under the first and second context managers respectively.
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.
    """
+     forward_context, recompute_context = context_fn()
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()
@@ -669,7 +696,8 @@ def recompute_fn(*inputs):
                    set_device_states(fwd_gpu_devices, fwd_gpu_states)

            with torch.cuda.amp.autocast(**gpu_autocast_kwargs), \
-                  torch.cpu.amp.autocast(**cpu_autocast_kwargs):
+                  torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
+                  recompute_context:
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(recompute_fn)
@@ -680,7 +708,8 @@ def recompute_fn(*inputs):
    if new_frame.input_saver.grad_fn is None:
        return fn(*args, **kwargs)

-     with _checkpoint_hook(new_frame):
+     with _checkpoint_hook(new_frame), \
+          forward_context:
        ret = fn(*args, **kwargs)

    if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
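To illustrate the wiring in the two hunks above, a sketch assuming the patch is applied as shown: the first manager returned by ``context_fn`` is entered around the original forward call, the second around ``recompute_fn`` when saved tensors are recomputed during backward. The counter names below are invented for the example.

```python
import contextlib
import torch
from torch.utils.checkpoint import checkpoint

calls = {"forward": 0, "recompute": 0}

@contextlib.contextmanager
def _count(key):
    # Entered lazily: the increment runs only when the manager is actually used.
    calls[key] += 1
    yield

def counting_context_fn():
    return _count("forward"), _count("recompute")

layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(layer, x, use_reentrant=False, context_fn=counting_context_fn)
print(calls)  # expected: {"forward": 1, "recompute": 0}, only the forward has run
out.sum().backward()
print(calls)  # expected: {"forward": 1, "recompute": 1}, backward triggered recomputation
```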