Skip to content

Commit 46e2b0e

Browse files
KumoLiuericspod
andauthored
Fix load pretrain weight issue in ResNet (#7924)
Fixes #7923 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 7e4f141 commit 46e2b0e

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

monai/networks/nets/resnet.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -510,16 +510,15 @@ def _resnet(
510510
# Check model bias_downsample and shortcut_type
511511
bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)
512512
if shortcut_type == kwargs.get("shortcut_type", "B") and (
513-
bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True
513+
bias_downsample == kwargs.get("bias_downsample", True)
514514
):
515515
# Download the MedicalNet pretrained model
516516
model_state_dict = get_pretrained_resnet_medicalnet(
517517
resnet_depth, device=device, datasets23=True
518518
)
519519
else:
520520
raise NotImplementedError(
521-
f"Please set shortcut_type to {shortcut_type} and bias_downsample to"
522-
f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}"
521+
f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bias_downsample} "
523522
f"when using pretrained MedicalNet resnet{resnet_depth}"
524523
)
525524
else:
@@ -681,7 +680,7 @@ def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
681680
# After testing
682681
# False: 10, 50, 101, 152, 200
683682
# Any: 18, 34
684-
bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
683+
bias_downsample = resnet_depth in (18, 34)
685684
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
686685
return bias_downsample, shortcut_type
687686

tests/test_resnet.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
266266
@parameterized.expand(PRETRAINED_TEST_CASES)
267267
@skip_if_quick
268268
@skip_if_no_cuda
269-
def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape):
269+
def test_resnet_pretrained(self, model, input_param, _input_shape, _expected_shape):
270270
net = model(**input_param).to(device)
271271
# Save ckpt
272272
torch.save(net.state_dict(), self.tmp_ckpt_filename)
@@ -290,9 +290,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape
290290
and input_param.get("n_input_channels", 3) == 1
291291
and input_param.get("feed_forward", True) is False
292292
and input_param.get("shortcut_type", "B") == shortcut_type
293-
and (
294-
input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True
295-
)
293+
and (input_param.get("bias_downsample", True) == bias_downsample)
296294
):
297295
model(**cp_input_param)
298296
else:
@@ -303,7 +301,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape
303301
cp_input_param["n_input_channels"] = 1
304302
cp_input_param["feed_forward"] = False
305303
cp_input_param["shortcut_type"] = shortcut_type
306-
cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True
304+
cp_input_param["bias_downsample"] = bias_downsample
307305
if cp_input_param.get("spatial_dims", 3) == 3:
308306
with skip_if_downloading_fails():
309307
pretrained_net = model(**cp_input_param).to(device)

0 commit comments

Comments
 (0)