@@ -74,6 +74,7 @@ def __init__(
74
74
pretrained : bool = True ,
75
75
pretrained_path : str | None = None ,
76
76
pretrained_state_dict_key : str | None = None ,
77
+ channelwise : bool = False ,
77
78
):
78
79
super ().__init__ ()
79
80
@@ -102,15 +103,18 @@ def __init__(
102
103
self .spatial_dims = spatial_dims
103
104
self .perceptual_function : nn .Module
104
105
if spatial_dims == 3 and is_fake_3d is False :
105
- self .perceptual_function = MedicalNetPerceptualSimilarity (net = network_type , verbose = False )
106
+ self .perceptual_function = MedicalNetPerceptualSimilarity (net = network_type , verbose = False ,
107
+ channelwise = channelwise )
106
108
elif "radimagenet_" in network_type :
107
- self .perceptual_function = RadImageNetPerceptualSimilarity (net = network_type , verbose = False )
109
+ self .perceptual_function = RadImageNetPerceptualSimilarity (net = network_type , verbose = False ,
110
+ channelwise = channelwise )
108
111
elif network_type == "resnet50" :
109
112
self .perceptual_function = TorchvisionModelPerceptualSimilarity (
110
113
net = network_type ,
111
114
pretrained = pretrained ,
112
115
pretrained_path = pretrained_path ,
113
116
pretrained_state_dict_key = pretrained_state_dict_key ,
117
+ channelwise = channelwise ,
114
118
)
115
119
else :
116
120
self .perceptual_function = LPIPS (pretrained = pretrained , net = network_type , verbose = False )
@@ -185,14 +189,21 @@ class MedicalNetPerceptualSimilarity(nn.Module):
185
189
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
186
190
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
187
191
verbose: if false, mute messages from torch Hub load function.
192
+ channelwise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
193
+ Defaults to ``False``.
188
194
"""
189
195
190
def __init__(
    self,
    net: str = "medicalnet_resnet10_23datasets",
    verbose: bool = False,
    channelwise: bool = False,
) -> None:
    """
    Load the requested network through torch Hub and freeze it.

    Args:
        net: name of the model entry point passed to ``torch.hub.load``.
        verbose: forwarded to ``torch.hub.load``.
        channelwise: flag stored on the instance as ``self.channelwise``.
    """
    super().__init__()
    self.channelwise = channelwise

    # Swap torch hub's private fork-validation hook for a stub that always returns True.
    # NOTE(review): presumably needed so the hub load below is not rejected - confirm.
    torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
    hub_repo = "warvito/MedicalNet-models"
    self.model = torch.hub.load(hub_repo, model=net, verbose=verbose)

    # Put the module in evaluation mode, then disable gradients on every parameter.
    self.eval()
    for weight in self.parameters():
        weight.requires_grad = False
198
209
@@ -206,6 +217,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
206
217
Args:
207
218
input: 3D input tensor with shape BCDHW.
208
219
target: 3D target tensor with shape BCDHW.
220
+
209
221
"""
210
222
input = medicalnet_intensity_normalisation (input )
211
223
target = medicalnet_intensity_normalisation (target )
@@ -227,7 +239,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
227
239
feats_target = normalize_tensor (outs_target )
228
240
229
241
results : torch .Tensor = (feats_input - feats_target ) ** 2
230
- results = spatial_average_3d (results .sum (dim = 1 , keepdim = True ), keepdim = True )
242
+
243
+ if self .channelwise :
244
+ results = results .sum (dim = 1 , keepdim = True )
245
+ results = spatial_average_3d (results , keepdim = True )
231
246
232
247
return results
233
248
@@ -260,11 +275,13 @@ class RadImageNetPerceptualSimilarity(nn.Module):
260
275
verbose: if false, mute messages from torch Hub load function.
261
276
"""
262
277
263
def __init__(
    self,
    net: str = "radimagenet_resnet50",
    verbose: bool = False,
    channelwise: bool = False,
) -> None:
    """
    Load the requested network through torch Hub and freeze it.

    Args:
        net: name of the model entry point passed to ``torch.hub.load``.
        verbose: forwarded to ``torch.hub.load``.
        channelwise: flag stored on the instance as ``self.channelwise``.
    """
    super().__init__()
    self.channelwise = channelwise

    hub_repo = "Warvito/radimagenet-models"
    self.model = torch.hub.load(hub_repo, model=net, verbose=verbose)

    # Put the module in evaluation mode, then disable gradients on every parameter.
    self.eval()
    for weight in self.parameters():
        weight.requires_grad = False
270
287
@@ -297,7 +314,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
297
314
feats_target = normalize_tensor (outs_target )
298
315
299
316
results : torch .Tensor = (feats_input - feats_target ) ** 2
300
- results = spatial_average (results .sum (dim = 1 , keepdim = True ), keepdim = True )
317
+
318
+ if self .channelwise :
319
+ results = results .sum (dim = 1 , keepdim = True )
320
+ results = spatial_average (results , keepdim = True )
301
321
302
322
return results
303
323
@@ -324,6 +344,7 @@ def __init__(
324
344
pretrained : bool = True ,
325
345
pretrained_path : str | None = None ,
326
346
pretrained_state_dict_key : str | None = None ,
347
+ channelwise : bool = False ,
327
348
) -> None :
328
349
super ().__init__ ()
329
350
supported_networks = ["resnet50" ]
@@ -347,6 +368,8 @@ def __init__(
347
368
self .model = torchvision .models .feature_extraction .create_feature_extractor (network , [self .final_layer ])
348
369
self .eval ()
349
370
371
+ self .channelwise = channelwise
372
+
350
373
for param in self .parameters ():
351
374
param .requires_grad = False
352
375
@@ -376,7 +399,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
376
399
feats_target = normalize_tensor (outs_target )
377
400
378
401
results : torch .Tensor = (feats_input - feats_target ) ** 2
379
- results = spatial_average (results .sum (dim = 1 , keepdim = True ), keepdim = True )
402
+
403
+ if self .channelwise :
404
+ results = results .sum (dim = 1 , keepdim = True )
405
+ results = spatial_average (results , keepdim = True )
380
406
381
407
return results
382
408
0 commit comments