@@ -91,28 +91,25 @@ def test_script(self):
91
91
test_input = torch .ones (2 , 1 , 8 , 8 )
92
92
test_script_save (loss , test_input , test_input )
93
93
94
- @parameterized .expand ([
95
- ("sum_None_0.5_0.25" , "sum" , None , 0.5 , 0.25 ),
96
- ("sum_weight_0.5_0.25" , "sum" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
97
- ("sum_weight_tuple_0.5_0.25" , "sum" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
98
- ("mean_None_0.5_0.25" , "mean" , None , 0.5 , 0.25 ),
99
- ("mean_weight_0.5_0.25" , "mean" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
100
- ("mean_weight_tuple_0.5_0.25" , "mean" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
101
- ("none_None_0.5_0.25" , "none" , None , 0.5 , 0.25 ),
102
- ("none_weight_0.5_0.25" , "none" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
103
- ("none_weight_tuple_0.5_0.25" , "none" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
104
- ])
94
+ @parameterized .expand (
95
+ [
96
+ ("sum_None_0.5_0.25" , "sum" , None , 0.5 , 0.25 ),
97
+ ("sum_weight_0.5_0.25" , "sum" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
98
+ ("sum_weight_tuple_0.5_0.25" , "sum" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
99
+ ("mean_None_0.5_0.25" , "mean" , None , 0.5 , 0.25 ),
100
+ ("mean_weight_0.5_0.25" , "mean" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
101
+ ("mean_weight_tuple_0.5_0.25" , "mean" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
102
+ ("none_None_0.5_0.25" , "none" , None , 0.5 , 0.25 ),
103
+ ("none_weight_0.5_0.25" , "none" , torch .tensor ([1.0 , 1.0 , 2.0 ]), 0.5 , 0.25 ),
104
+ ("none_weight_tuple_0.5_0.25" , "none" , (3 , 2.0 , 1 ), 0.5 , 0.25 ),
105
+ ]
106
+ )
105
107
def test_with_alpha (self , name , reduction , weight , lambda_focal , alpha ):
106
108
size = [3 , 3 , 5 , 5 ]
107
109
label = torch .randint (low = 0 , high = 2 , size = size )
108
110
pred = torch .randn (size )
109
111
110
- common_params = {
111
- "include_background" : True ,
112
- "to_onehot_y" : False ,
113
- "reduction" : reduction ,
114
- "weight" : weight ,
115
- }
112
+ common_params = {"include_background" : True , "to_onehot_y" : False , "reduction" : reduction , "weight" : weight }
116
113
117
114
dice_focal = DiceFocalLoss (gamma = 1.0 , lambda_focal = lambda_focal , alpha = alpha , ** common_params )
118
115
dice = DiceLoss (** common_params )
@@ -123,5 +120,6 @@ def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):
123
120
124
121
np .testing .assert_allclose (result , expected_val , err_msg = f"Failed on case: { name } " )
125
122
123
+
126
124
if __name__ == "__main__" :
127
125
unittest .main ()
0 commit comments