
Commit 03a5fa6

MedicalNetPerceptualSimilarity: Add multi-channel (#7568)
Fixes #7567.

### Description
MedicalNetPerceptualSimilarity: add multi-channel support for 3D volumes. The current version of the code in the dev branch already largely supports this, except that the `medicalnet_*` networks require inputs to have a single channel. This PR passes the multi-channel volume channel-wise to the networks and concatenates the resulting feature vectors. The existing code takes care of averaging over channels and spatially.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Fabian Klopfer <[email protected]>
Co-authored-by: YunLiu <[email protected]>
1 parent 7a6b69f commit 03a5fa6
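For orientation, a minimal usage sketch of the new option (a sketch only: the six-channel shapes mirror the new test cases, and the MedicalNet weights are fetched from torch.hub on first use):

```python
import torch

from monai.losses.perceptual import MedicalNetPerceptualSimilarity

# channel_wise=True keeps one value per input channel instead of summing the
# feature differences across all channels; weights are downloaded via torch.hub.
metric = MedicalNetPerceptualSimilarity(net="medicalnet_resnet10_23datasets", channel_wise=True)

pred = torch.randn(2, 6, 64, 64, 64)  # B, C, D, H, W -- a six-channel 3D volume
target = torch.randn(2, 6, 64, 64, 64)

result = metric(pred, target)
print(result.shape)  # torch.Size([2, 6, 1, 1, 1]): one spatially averaged value per channel
```

The same option is exposed through `PerceptualLoss(..., channel_wise=True)`, which the new test cases expect to return one loss value per input channel.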

File tree: 2 files changed (+80 -10 lines)

monai/losses/perceptual.py
tests/test_perceptual_loss.py

monai/losses/perceptual.py

Lines changed: 49 additions & 7 deletions
```diff
@@ -45,6 +45,7 @@ class PerceptualLoss(nn.Module):
 
     The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all
     three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss.
+    MedicalNet networks are only compatible with 3D inputs and support channel-wise loss.
 
     Args:
         spatial_dims: number of spatial dimensions.
@@ -62,6 +63,8 @@ class PerceptualLoss(nn.Module):
         pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
             extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50".
             Defaults to `None`.
+        channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
+            Defaults to ``False``.
     """
 
     def __init__(
@@ -74,6 +77,7 @@ def __init__(
         pretrained: bool = True,
         pretrained_path: str | None = None,
         pretrained_state_dict_key: str | None = None,
+        channel_wise: bool = False,
     ):
         super().__init__()
 
@@ -86,6 +90,9 @@ def __init__(
                 "Argument is_fake_3d must be set to False."
             )
 
+        if channel_wise and "medicalnet_" not in network_type:
+            raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.")
+
         if network_type.lower() not in list(PercetualNetworkType):
             raise ValueError(
                 "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
@@ -102,7 +109,9 @@ def __init__(
         self.spatial_dims = spatial_dims
         self.perceptual_function: nn.Module
         if spatial_dims == 3 and is_fake_3d is False:
-            self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
+            self.perceptual_function = MedicalNetPerceptualSimilarity(
+                net=network_type, verbose=False, channel_wise=channel_wise
+            )
         elif "radimagenet_" in network_type:
             self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
         elif network_type == "resnet50":
@@ -172,7 +181,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             # 2D and real 3D cases
             loss = self.perceptual_function(input, target)
 
-        return torch.mean(loss)
+        if self.channel_wise:
+            loss = torch.mean(loss.squeeze(), dim=0)
+        else:
+            loss = torch.mean(loss)
+
+        return loss
 
 
 class MedicalNetPerceptualSimilarity(nn.Module):
@@ -185,14 +199,20 @@ class MedicalNetPerceptualSimilarity(nn.Module):
         net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
             Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
         verbose: if false, mute messages from torch Hub load function.
+        channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
+            Defaults to ``False``.
     """
 
-    def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
+    def __init__(
+        self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
+    ) -> None:
         super().__init__()
         torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
         self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
         self.eval()
 
+        self.channel_wise = channel_wise
+
         for param in self.parameters():
             param.requires_grad = False
 
@@ -206,20 +226,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         Args:
             input: 3D input tensor with shape BCDHW.
             target: 3D target tensor with shape BCDHW.
+
         """
         input = medicalnet_intensity_normalisation(input)
         target = medicalnet_intensity_normalisation(target)
 
         # Get model outputs
-        outs_input = self.model.forward(input)
-        outs_target = self.model.forward(target)
+        feats_per_ch = 0
+        for ch_idx in range(input.shape[1]):
+            input_channel = input[:, ch_idx, ...].unsqueeze(1)
+            target_channel = target[:, ch_idx, ...].unsqueeze(1)
+
+            if ch_idx == 0:
+                outs_input = self.model.forward(input_channel)
+                outs_target = self.model.forward(target_channel)
+                feats_per_ch = outs_input.shape[1]
+            else:
+                outs_input = torch.cat([outs_input, self.model.forward(input_channel)], dim=1)
+                outs_target = torch.cat([outs_target, self.model.forward(target_channel)], dim=1)
 
         # Normalise through the channels
         feats_input = normalize_tensor(outs_input)
         feats_target = normalize_tensor(outs_target)
 
-        results: torch.Tensor = (feats_input - feats_target) ** 2
-        results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
+        feats_diff: torch.Tensor = (feats_input - feats_target) ** 2
+        if self.channel_wise:
+            results = torch.zeros(
+                feats_diff.shape[0], input.shape[1], feats_diff.shape[2], feats_diff.shape[3], feats_diff.shape[4]
+            )
+            for i in range(input.shape[1]):
+                l_idx = i * feats_per_ch
+                r_idx = (i + 1) * feats_per_ch
+                results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1)
+        else:
+            results = feats_diff.sum(dim=1, keepdim=True)
+
+        results = spatial_average_3d(results, keepdim=True)
 
         return results
 
```
tests/test_perceptual_loss.py

Lines changed: 31 additions & 3 deletions
```diff
@@ -18,7 +18,7 @@
 
 from monai.losses import PerceptualLoss
 from monai.utils import optional_import
-from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_quick
+from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, skip_if_downloading_fails, skip_if_quick
 
 _, has_torchvision = optional_import("torchvision")
 TEST_CASES = [
@@ -40,11 +40,31 @@
         (2, 1, 64, 64, 64),
         (2, 1, 64, 64, 64),
     ],
+    [
+        {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
+        (2, 6, 64, 64, 64),
+        (2, 6, 64, 64, 64),
+    ],
+    [
+        {
+            "spatial_dims": 3,
+            "network_type": "medicalnet_resnet10_23datasets",
+            "is_fake_3d": False,
+            "channel_wise": True,
+        },
+        (2, 6, 64, 64, 64),
+        (2, 6, 64, 64, 64),
+    ],
     [
         {"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
         (2, 1, 64, 64, 64),
         (2, 1, 64, 64, 64),
     ],
+    [
+        {"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
+        (2, 6, 64, 64, 64),
+        (2, 6, 64, 64, 64),
+    ],
     [
         {"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2},
         (2, 1, 64, 64, 64),
@@ -63,15 +83,23 @@ def test_shape(self, input_param, input_shape, target_shape):
         with skip_if_downloading_fails():
             loss = PerceptualLoss(**input_param)
             result = loss(torch.randn(input_shape), torch.randn(target_shape))
-        self.assertEqual(result.shape, torch.Size([]))
+
+        if "channel_wise" in input_param.keys() and input_param["channel_wise"]:
+            self.assertEqual(result.shape, torch.Size([input_shape[1]]))
+        else:
+            self.assertEqual(result.shape, torch.Size([]))
 
     @parameterized.expand(TEST_CASES)
     def test_identical_input(self, input_param, input_shape, target_shape):
         with skip_if_downloading_fails():
             loss = PerceptualLoss(**input_param)
             tensor = torch.randn(input_shape)
             result = loss(tensor, tensor)
-        self.assertEqual(result, torch.Tensor([0.0]))
+
+        if "channel_wise" in input_param.keys() and input_param["channel_wise"]:
+            assert_allclose(result, torch.Tensor([0.0] * input_shape[1]))
+        else:
+            self.assertEqual(result, torch.Tensor([0.0]))
 
     def test_different_shape(self):
         with skip_if_downloading_fails():
```
