Skip to content

Commit dd84afc

Browse files
committed
Add channelwise flag to perceptual loss
1 parent 0250284 commit dd84afc

File tree

2 files changed

+60
-9
lines changed

2 files changed

+60
-9
lines changed

monai/losses/perceptual.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def __init__(
7474
pretrained: bool = True,
7575
pretrained_path: str | None = None,
7676
pretrained_state_dict_key: str | None = None,
77+
channelwise: bool = False,
7778
):
7879
super().__init__()
7980

@@ -102,15 +103,18 @@ def __init__(
102103
self.spatial_dims = spatial_dims
103104
self.perceptual_function: nn.Module
104105
if spatial_dims == 3 and is_fake_3d is False:
105-
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
106+
self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False,
107+
channelwise=channelwise)
106108
elif "radimagenet_" in network_type:
107-
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
109+
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False,
110+
channelwise=channelwise)
108111
elif network_type == "resnet50":
109112
self.perceptual_function = TorchvisionModelPerceptualSimilarity(
110113
net=network_type,
111114
pretrained=pretrained,
112115
pretrained_path=pretrained_path,
113116
pretrained_state_dict_key=pretrained_state_dict_key,
117+
channelwise=channelwise,
114118
)
115119
else:
116120
self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
@@ -185,14 +189,21 @@ class MedicalNetPerceptualSimilarity(nn.Module):
185189
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
186190
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
187191
verbose: if false, mute messages from torch Hub load function.
192+
channelwise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
193+
Defaults to ``False``.
188194
"""
189195

190-
def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
196+
def __init__(self,
197+
net: str = "medicalnet_resnet10_23datasets",
198+
verbose: bool = False,
199+
channelwise: bool = False) -> None:
191200
super().__init__()
192201
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
193202
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
194203
self.eval()
195204

205+
self.channelwise = channelwise
206+
196207
for param in self.parameters():
197208
param.requires_grad = False
198209

@@ -206,6 +217,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
206217
Args:
207218
input: 3D input tensor with shape BCDHW.
208219
target: 3D target tensor with shape BCDHW.
220+
209221
"""
210222
input = medicalnet_intensity_normalisation(input)
211223
target = medicalnet_intensity_normalisation(target)
@@ -227,7 +239,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
227239
feats_target = normalize_tensor(outs_target)
228240

229241
results: torch.Tensor = (feats_input - feats_target) ** 2
230-
results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
242+
243+
if self.channelwise:
244+
results = results.sum(dim=1, keepdim=True)
245+
results = spatial_average_3d(results, keepdim=True)
231246

232247
return results
233248

@@ -260,11 +275,13 @@ class RadImageNetPerceptualSimilarity(nn.Module):
260275
verbose: if false, mute messages from torch Hub load function.
261276
"""
262277

263-
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None:
278+
def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, channelwise: bool = False) -> None:
264279
super().__init__()
265280
self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose)
266281
self.eval()
267282

283+
self.channelwise = channelwise
284+
268285
for param in self.parameters():
269286
param.requires_grad = False
270287

@@ -297,7 +314,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
297314
feats_target = normalize_tensor(outs_target)
298315

299316
results: torch.Tensor = (feats_input - feats_target) ** 2
300-
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
317+
318+
if self.channelwise:
319+
results = results.sum(dim=1, keepdim=True)
320+
results = spatial_average(results, keepdim=True)
301321

302322
return results
303323

@@ -324,6 +344,7 @@ def __init__(
324344
pretrained: bool = True,
325345
pretrained_path: str | None = None,
326346
pretrained_state_dict_key: str | None = None,
347+
channelwise: bool = False,
327348
) -> None:
328349
super().__init__()
329350
supported_networks = ["resnet50"]
@@ -347,6 +368,8 @@ def __init__(
347368
self.model = torchvision.models.feature_extraction.create_feature_extractor(network, [self.final_layer])
348369
self.eval()
349370

371+
self.channelwise = channelwise
372+
350373
for param in self.parameters():
351374
param.requires_grad = False
352375

@@ -376,7 +399,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
376399
feats_target = normalize_tensor(outs_target)
377400

378401
results: torch.Tensor = (feats_input - feats_target) ** 2
379-
results = spatial_average(results.sum(dim=1, keepdim=True), keepdim=True)
402+
403+
if self.channelwise:
404+
results = results.sum(dim=1, keepdim=True)
405+
results = spatial_average(results, keepdim=True)
380406

381407
return results
382408

tests/test_perceptual_loss.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,18 @@
3030
],
3131
[{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 1, 64, 64), (2, 1, 64, 64)],
3232
[{"spatial_dims": 2, "network_type": "radimagenet_resnet50"}, (2, 3, 64, 64), (2, 3, 64, 64)],
33+
[{"spatial_dims": 2, "network_type": "radimagenet_resnet50", "channelwise": True}, (2, 3, 64, 64), (2, 3, 64, 64)],
3334
[
3435
{"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1},
3536
(2, 1, 64, 64, 64),
3637
(2, 1, 64, 64, 64),
3738
],
39+
[
40+
{"spatial_dims": 3, "network_type": "radimagenet_resnet50", "is_fake_3d": True, "fake_3d_ratio": 0.1,
41+
'channelwise': True},
42+
(2, 1, 64, 64, 64),
43+
(2, 1, 64, 64, 64),
44+
],
3845
[
3946
{"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
4047
(2, 1, 64, 64, 64),
@@ -45,6 +52,11 @@
4552
(2, 6, 64, 64, 64),
4653
(2, 6, 64, 64, 64),
4754
],
55+
[
56+
{"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False, "channelwise": True},
57+
(2, 6, 64, 64, 64),
58+
(2, 6, 64, 64, 64),
59+
],
4860
[
4961
{"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
5062
(2, 1, 64, 64, 64),
@@ -60,6 +72,11 @@
6072
(2, 1, 64, 64, 64),
6173
(2, 1, 64, 64, 64),
6274
],
75+
[
76+
{"spatial_dims": 3, "network_type": "resnet50", "pretrained": True, "fake_3d_ratio": 0.2, "channelwise": True},
77+
(2, 3, 64, 64, 64),
78+
(2, 3, 64, 64, 64),
79+
],
6380
]
6481

6582

@@ -73,15 +90,23 @@ def test_shape(self, input_param, input_shape, target_shape):
7390
with skip_if_downloading_fails():
7491
loss = PerceptualLoss(**input_param)
7592
result = loss(torch.randn(input_shape), torch.randn(target_shape))
76-
self.assertEqual(result.shape, torch.Size([]))
93+
94+
if 'channelwise' in input_param.keys() and input_param['channelwise']:
95+
self.assertEqual(result.shape, torch.Size([input_shape[1]]))
96+
else:
97+
self.assertEqual(result.shape, torch.Size([]))
7798

7899
@parameterized.expand(TEST_CASES)
79100
def test_identical_input(self, input_param, input_shape, target_shape):
80101
with skip_if_downloading_fails():
81102
loss = PerceptualLoss(**input_param)
82103
tensor = torch.randn(input_shape)
83104
result = loss(tensor, tensor)
84-
self.assertEqual(result, torch.Tensor([0.0]))
105+
106+
if 'channelwise' in input_param.keys() and input_param['channelwise']:
107+
self.assertEqual(result, torch.Tensor([0.0] * input_shape[1]))
108+
else:
109+
self.assertEqual(result, torch.Tensor([0.0]))
85110

86111
def test_different_shape(self):
87112
with skip_if_downloading_fails():

0 commit comments

Comments
 (0)