@@ -13,6 +13,7 @@
 from typing import TYPE_CHECKING, Dict, Optional
 
 import torch
+import torch.nn as nn
 
 from monai.utils import exact_version, optional_import
 
@@ -44,8 +45,12 @@ class CheckpointLoader:
             first load the module to CPU and then copy each parameter to where it was
             saved, which would result in all processes on the same machine using the
             same set of devices.
-        strict: whether to strictly enforce that the keys in :attr:`state_dict` match the keys
-            returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+        strict: whether to strictly enforce that the keys in `state_dict` match the keys
+            returned by `torch.nn.Module.state_dict`. Defaults to `True`.
+        strict_shape: whether to enforce the data shape of the matched layers in the checkpoint.
+            If `False`, layers whose shapes do not match the checkpoint content are skipped.
+            This is an advanced feature intended for transfer learning; users should fully
+            understand which layers will have a different shape. Defaults to `True`.
 
     """
 
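Usage note (not part of this diff): a minimal sketch of how the new `strict_shape` option could be used for transfer learning. The model, checkpoint path, and trainer engine below are placeholders, and `strict=False` is assumed so that keys filtered out by the shape check do not trigger a strict key-mismatch error.

```python
from monai.handlers import CheckpointLoader

# new_model: a torch.nn.Module whose final layer was re-sized for a new task (placeholder)
# trainer: an ignite Engine to attach the handler to (placeholder)
loader = CheckpointLoader(
    load_path="/path/to/pretrained.pt",  # hypothetical checkpoint path
    load_dict={"net": new_model},
    strict=False,        # tolerate keys that go missing after filtering
    strict_shape=False,  # skip layers whose shapes no longer match the checkpoint
)
loader.attach(trainer)
```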
@@ -56,6 +61,7 @@ def __init__(
         name: Optional[str] = None,
         map_location: Optional[Dict] = None,
         strict: bool = True,
+        strict_shape: bool = True,
     ) -> None:
         if load_path is None:
             raise AssertionError("must provide clear path to load checkpoint.")
@@ -67,6 +73,7 @@ def __init__(
         self._name = name
         self.map_location = map_location
         self.strict = strict
+        self.strict_shape = strict_shape
 
     def attach(self, engine: Engine) -> None:
         """
@@ -84,6 +91,20 @@ def __call__(self, engine: Engine) -> None:
         """
         checkpoint = torch.load(self.load_path, map_location=self.map_location)
 
+        if not self.strict_shape:
+            k, _ = list(self.load_dict.items())[0]
+            # single object and checkpoint is directly a state_dict
+            if len(self.load_dict) == 1 and k not in checkpoint:
+                checkpoint = {k: checkpoint}
+
+            # skip items that don't match data shape
+            for k, obj in self.load_dict.items():
+                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+                    obj = obj.module
+                if isinstance(obj, torch.nn.Module):
+                    d = obj.state_dict()
+                    checkpoint[k] = {k: v for k, v in checkpoint[k].items() if k in d and v.shape == d[k].shape}
+
         # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
         prior_max_epochs = engine.state.max_epochs
         Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict)
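For reviewers, here is a small self-contained sketch (not from the codebase) of the shape-filtering idea applied above: only checkpoint entries whose key exists in the target `state_dict` and whose tensor shape matches are kept.

```python
import torch
import torch.nn as nn

target = nn.Linear(4, 2)  # current model layer: weight (2, 4), bias (2,)
ckpt_state = {
    "weight": torch.zeros(3, 4),  # shape differs from (2, 4) -> dropped
    "bias": torch.zeros(2),       # shape matches -> kept
    "extra": torch.zeros(5),      # key not in the target state_dict -> dropped
}
d = target.state_dict()
filtered = {k: v for k, v in ckpt_state.items() if k in d and v.shape == d[k].shape}
print(sorted(filtered))  # ['bias']
```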