
1939 Add strict_shape option in CheckpointLoader #1946


Merged · 4 commits · Apr 5, 2021
4 changes: 2 additions & 2 deletions .github/workflows/pythonapp.yml
@@ -285,7 +285,7 @@ jobs:
- name: Run quick tests (GPU)
run: |
nvidia-smi
-        export LAUNCH_DELAY=$(( RANDOM % 30 * 5 ))
+        export LAUNCH_DELAY=$(python -c "import numpy; print(numpy.random.randint(30) * 5)")
echo "Sleep $LAUNCH_DELAY"
sleep $LAUNCH_DELAY
export CUDA_VISIBLE_DEVICES=$(coverage run -m tests.utils)
@@ -298,7 +298,7 @@ jobs:
python -c 'import torch; print(torch.rand(5, 3, device=torch.device("cuda:0")))'
python -c "import monai; monai.config.print_config()"
BUILD_MONAI=1 ./runtests.sh --quick --unittests
-        if [ ${{ matrix.environment }} == "PT18+CUDA112" ]; then
+        if [ ${{ matrix.environment }} = "PT18+CUDA112" ]; then
# test the clang-format tool downloading once
coverage run -m tests.clang_format_utils
fi
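A note on the workflow changes above: POSIX `test`/`[` specifies `=` for string comparison, so `=` is more portable than the bash-only `==`, and the launch delay is now computed in Python instead of the shell's `$RANDOM`. A standalone sketch of the delay computation (illustrative only, not part of the diff):

```python
import numpy

# Same distribution as the old shell expression `RANDOM % 30 * 5`:
# a multiple of 5 in the range [0, 145].
delay = numpy.random.randint(30) * 5
print(f"Sleep {delay}")
```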
25 changes: 23 additions & 2 deletions monai/handlers/checkpoint_loader.py
@@ -13,6 +13,7 @@
from typing import TYPE_CHECKING, Dict, Optional

import torch
+import torch.nn as nn

from monai.utils import exact_version, optional_import

@@ -44,8 +45,12 @@ class CheckpointLoader:
first load the module to CPU and then copy each parameter to where it was
saved, which would result in all processes on the same machine using the
same set of devices.
-        strict: whether to strictly enforce that the keys in :attr:`state_dict` match the keys
-            returned by this module's :meth:`~torch.nn.Module.state_dict` function. Default: ``True``
+        strict: whether to strictly enforce that the keys in `state_dict` match the keys
+            returned by `torch.nn.Module.state_dict` function. defaults to `True`.
+        strict_shape: whether to enforce the data shape of the matched layers in the checkpoint.
+            if `False`, it will skip the layers whose data shape differs from the checkpoint content.
+            this can be a useful advanced feature for transfer learning: users should fully
+            understand which layers will have a different shape. defaults to `True`.

"""

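As a usage sketch for the new option (not part of this diff; the model, checkpoint path, and names are hypothetical), loading a checkpoint into a model whose layers were partially re-sized might look like:

```python
import torch
from ignite.engine import Engine

from monai.handlers import CheckpointLoader

net = torch.nn.Sequential(torch.nn.PReLU())
engine = Engine(lambda e, b: None)

# Skip any checkpoint entries whose shapes don't match `net`;
# useful when fine-tuning a model whose head was re-sized.
loader = CheckpointLoader(
    load_path="/path/to/net_final.pt",  # hypothetical path
    load_dict={"net": net},
    strict=False,
    strict_shape=False,
)
loader.attach(engine)
engine.run([0], max_epochs=1)
```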
@@ -56,6 +61,7 @@ def __init__(
name: Optional[str] = None,
map_location: Optional[Dict] = None,
strict: bool = True,
+        strict_shape: bool = True,
) -> None:
if load_path is None:
raise AssertionError("must provide clear path to load checkpoint.")
@@ -67,6 +73,7 @@
self._name = name
self.map_location = map_location
self.strict = strict
+        self.strict_shape = strict_shape

def attach(self, engine: Engine) -> None:
"""
@@ -84,6 +91,20 @@ def __call__(self, engine: Engine) -> None:
"""
checkpoint = torch.load(self.load_path, map_location=self.map_location)

+        if not self.strict_shape:
+            k, _ = list(self.load_dict.items())[0]
+            # single object and checkpoint is directly a state_dict
+            if len(self.load_dict) == 1 and k not in checkpoint:
+                checkpoint = {k: checkpoint}
+
+            # skip items that don't match data shape
+            for k, obj in self.load_dict.items():
+                if isinstance(obj, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+                    obj = obj.module
+                if isinstance(obj, torch.nn.Module):
+                    d = obj.state_dict()
+                    checkpoint[k] = {k: v for k, v in checkpoint[k].items() if k in d and v.shape == d[k].shape}
+
# save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
prior_max_epochs = engine.state.max_epochs
Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict)
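The core of the new logic can be shown standalone: given a model and a checkpoint `state_dict`, keep only the entries whose key exists in the model with an identical shape. A minimal sketch, independent of the handler (the function name is ours, not MONAI's):

```python
import torch


def filter_by_shape(model: torch.nn.Module, ckpt: dict) -> dict:
    """Keep only checkpoint entries matching the model's parameter shapes."""
    model_sd = model.state_dict()
    return {k: v for k, v in ckpt.items() if k in model_sd and v.shape == model_sd[k].shape}


net = torch.nn.Linear(2, 2)
ckpt = {"weight": torch.zeros(3, 3), "bias": torch.zeros(2), "extra": torch.zeros(1)}
# "weight" is dropped (shape mismatch), "extra" is dropped (unknown key);
# only "bias" survives.
print(filter_by_shape(net, ckpt).keys())
```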
24 changes: 24 additions & 0 deletions tests/test_handler_checkpoint_loader.py
@@ -146,6 +146,30 @@ def test_partial_over_load(self):
engine.run([0] * 8, max_epochs=1)
torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.1]))

+    def test_strict_shape(self):
+        logging.basicConfig(stream=sys.stdout, level=logging.INFO)
+        net1 = torch.nn.Sequential(*[torch.nn.PReLU(num_parameters=5)])
+        data1 = net1.state_dict()
+        data1["0.weight"] = torch.tensor([1, 2, 3, 4, 5])
+        data1["new"] = torch.tensor(0.1)
+        net1.load_state_dict(data1, strict=False)
+
+        net2 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()])
+        data2 = net2.state_dict()
+        data2["0.weight"] = torch.tensor([0.2])
+        data2["1.weight"] = torch.tensor([0.3])
+        net2.load_state_dict(data2)
+
+        with tempfile.TemporaryDirectory() as tempdir:
+            engine = Engine(lambda e, b: None)
+            CheckpointSaver(save_dir=tempdir, save_dict={"net": net1}, save_final=True).attach(engine)
+            engine.run([0] * 8, max_epochs=5)
+            path = tempdir + "/net_final_iteration=40.pt"
+            engine = Engine(lambda e, b: None)
+            CheckpointLoader(load_path=path, load_dict={"net": net2}, strict=False, strict_shape=False).attach(engine)
+            engine.run([0] * 8, max_epochs=1)
+            torch.testing.assert_allclose(net2.state_dict()["0.weight"].cpu(), torch.tensor([0.2]))
+

if __name__ == "__main__":
unittest.main()
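Why the test expects `0.2` to survive: the two models disagree on the shape of `"0.weight"`, so with `strict_shape=False` that entry is skipped and `net2` keeps its own value. A quick check of the shape mismatch (standalone, mirroring the test above):

```python
import torch

net1 = torch.nn.PReLU(num_parameters=5)
net2 = torch.nn.PReLU()
# torch.Size([5]) vs torch.Size([1]): the loader drops the mismatched
# "0.weight" entry, leaving net2's original 0.2 in place.
print(net1.weight.shape, net2.weight.shape)
```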