
Commit f1db865

Nic-Ma and wyli authored
1939 Add strict_shape option in CheckpointLoader (#1946)
* [DLMED] add strict_shape option (Signed-off-by: Nic Ma <[email protected]>)
* [DLMED] add unit tests (Signed-off-by: Nic Ma <[email protected]>)
* update test case (Signed-off-by: Wenqi Li <[email protected]>)
* fixes test config (Signed-off-by: Wenqi Li <[email protected]>)

Co-authored-by: Wenqi Li <[email protected]>
1 parent 713490b · commit f1db865

File tree

3 files changed: +49 -4 lines changed


.github/workflows/pythonapp.yml

Lines changed: 2 additions & 2 deletions
@@ -285,7 +285,7 @@ jobs:
       - name: Run quick tests (GPU)
         run: |
           nvidia-smi
-          export LAUNCH_DELAY=$(( RANDOM % 30 * 5 ))
+          export LAUNCH_DELAY=$(python -c "import numpy; print(numpy.random.randint(30) * 5)")
           echo "Sleep $LAUNCH_DELAY"
           sleep $LAUNCH_DELAY
           export CUDA_VISIBLE_DEVICES=$(coverage run -m tests.utils)
@@ -298,7 +298,7 @@ jobs:
           python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
           python -c "import monai; monai.config.print_config()"
           BUILD_MONAI=1 ./runtests.sh --quick --unittests
-          if [ ${{ matrix.environment }} == "PT18+CUDA112" ]; then
+          if [ ${{ matrix.environment }} = "PT18+CUDA112" ]; then
             # test the clang-format tool downloading once
             coverage run -m tests.clang_format_utils
           fi
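For reference, the new expression draws from the same range as the bash arithmetic it replaces: numpy.random.randint(30) returns an integer in [0, 30), so LAUNCH_DELAY is a multiple of 5 seconds between 0 and 145. A minimal sketch illustrating the range (assuming numpy is installed, as it is in this CI job):

    import numpy

    # randint(30) samples uniformly from [0, 30); scaling by 5 yields
    # delays of 0, 5, ..., 145 seconds, matching $(( RANDOM % 30 * 5 )).
    delays = {numpy.random.randint(30) * 5 for _ in range(10_000)}
    assert delays <= set(range(0, 150, 5))
    print(sorted(delays))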

monai/handlers/checkpoint_loader.py

Lines changed: 23 additions & 2 deletions
@@ -13,6 +13,7 @@
 from typing import TYPE_CHECKING, Dict, Optional
 
 import torch
+import torch.nn as nn
 
 from monai.utils import exact_version, optional_import
 
@@ -44,8 +45,12 @@ class CheckpointLoader:
             first load the module to CPU and then copy each parameter to where it was
             saved, which would result in all processes on the same machine using the
             same set of devices.
-        strict: whether to strictly enforce that the keys in :attr:`state_dict` match the keys
-            returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+        strict: whether to strictly enforce that the keys in `state_dict` match the keys
+            returned by the `torch.nn.Module.state_dict` function. Defaults to `True`.
+        strict_shape: whether to enforce the data shape of the matched layers in the checkpoint.
+            If `False`, it will skip the layers whose data shape differs from the checkpoint content.
+            This can be a useful advanced feature for transfer learning; users should fully
+            understand which layers will have a different shape. Defaults to `True`.
 
     """
 
@@ -56,6 +61,7 @@ def __init__(
         name: Optional[str] = None,
         map_location: Optional[Dict] = None,
         strict: bool = True,
+        strict_shape: bool = True,
     ) -> None:
         if load_path is None:
             raise AssertionError("must provide clear path to load checkpoint.")
@@ -67,6 +73,7 @@ def __init__(
         self._name = name
         self.map_location = map_location
         self.strict = strict
+        self.strict_shape = strict_shape
 
     def attach(self, engine: Engine) -> None:
         """
@@ -84,6 +91,20 @@ def __call__(self, engine: Engine) -> None:
         """
         checkpoint = torch.load(self.load_path, map_location=self.map_location)
 
+        if not self.strict_shape:
+            k, _ = list(self.load_dict.items())[0]
+            # single object and checkpoint is directly a state_dict
+            if len(self.load_dict) == 1 and k not in checkpoint:
+                checkpoint = {k: checkpoint}
+
+            # skip items that don't match data shape
+            for k, obj in self.load_dict.items():
+                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+                    obj = obj.module
+                if isinstance(obj, torch.nn.Module):
+                    d = obj.state_dict()
+                    checkpoint[k] = {k: v for k, v in checkpoint[k].items() if k in d and v.shape == d[k].shape}
+
         # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
         prior_max_epochs = engine.state.max_epochs
         Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict)
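Note that the last added line reuses the name `k`: the outer `k` (the load_dict key) selects checkpoint[k], while the dict comprehension rebinds `k` to the state_dict keys in its own scope. A standalone sketch of the same filtering idea with distinct names (simplified: it omits the DataParallel/DistributedDataParallel unwrapping the handler performs):

    import torch

    def filter_by_shape(model: torch.nn.Module, ckpt_state: dict) -> dict:
        """Keep only checkpoint entries whose name and shape match the model."""
        current = model.state_dict()
        return {
            name: tensor
            for name, tensor in ckpt_state.items()
            if name in current and tensor.shape == current[name].shape
        }

Entries dropped by the filter leave the state_dict incomplete, which is why strict_shape=False is typically paired with strict=False, as in the test below.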

tests/test_handler_checkpoint_loader.py

Lines changed: 24 additions & 0 deletions
@@ -146,6 +146,30 @@ def test_partial_over_load(self):
         engine.run([0] * 8, max_epochs=1)
         torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1]))
 
+    def test_strict_shape(self):
+        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+        net1 = torch.nn.Sequential(*[torch.nn.PReLU(num_parameters=5)])
+        data1 = net1.state_dict()
+        data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5])
+        data1["new"] = torch.tensor(0.1)
+        net1.load_state_dict(data1, strict=False)
+
+        net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()])
+        data2 = net2.state_dict()
+        data2["0.weight"] = torch.tensor([0.2])
+        data2["1.weight"] = torch.tensor([0.3])
+        net2.load_state_dict(data2)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            engine = Engine(lambda e, b: None)
+            CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine)
+            engine.run([0] * 8, max_epochs=5)
+            path = tempdir + "/net_final_iteration=40.pt"
+            engine = Engine(lambda e, b: None)
+            CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False, strict_shape=False).attach(engine)
+            engine.run([0] * 8, max_epochs=1)
+            torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2]))
+
 
 if __name__ == "__main__":
     unittest.main()
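The test exercises the flag end to end: net1's PReLU holds five parameters while net2's first PReLU holds one, so with strict_shape=False the mismatched "0.weight" is skipped and net2 keeps its own value of 0.2. A minimal transfer-learning sketch along the same lines (illustrative: the file name and layer sizes are made up):

    import torch
    from ignite.engine import Engine
    from monai.handlers import CheckpointLoader

    # a pretrained backbone with a 2-way head; the new model swaps in a 3-way head
    pretrained = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2))
    new_model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 3))
    torch.save(pretrained.state_dict(), "pretrained.pt")

    engine = Engine(lambda e, b: None)
    CheckpointLoader(
        load_path="pretrained.pt",
        load_dict={"net": new_model},
        strict=False,        # tolerate the keys dropped by the shape filter
        strict_shape=False,  # skip "1.weight" and "1.bias", whose shapes differ
    ).attach(engine)
    engine.run([0], max_epochs=1)  # the checkpoint is loaded when the run starts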
