Skip to content

Commit 0f47476

Browse files
committed
Fix test formatting with autofix
1 parent f7286e2 commit 0f47476

File tree

1 file changed

+15
-17
lines changed

1 file changed

+15
-17
lines changed

tests/test_dice_focal_loss.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,28 +91,25 @@ def test_script(self):
9191
test_input = torch.ones(2, 1, 8, 8)
9292
test_script_save(loss, test_input, test_input)
9393

94-
@parameterized.expand([
95-
("sum_None_0.5_0.25", "sum", None, 0.5, 0.25),
96-
("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
97-
("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25),
98-
("mean_None_0.5_0.25", "mean", None, 0.5, 0.25),
99-
("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
100-
("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25),
101-
("none_None_0.5_0.25", "none", None, 0.5, 0.25),
102-
("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
103-
("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25),
104-
])
94+
@parameterized.expand(
95+
[
96+
("sum_None_0.5_0.25", "sum", None, 0.5, 0.25),
97+
("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
98+
("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25),
99+
("mean_None_0.5_0.25", "mean", None, 0.5, 0.25),
100+
("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
101+
("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25),
102+
("none_None_0.5_0.25", "none", None, 0.5, 0.25),
103+
("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
104+
("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25),
105+
]
106+
)
105107
def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):
106108
size = [3, 3, 5, 5]
107109
label = torch.randint(low=0, high=2, size=size)
108110
pred = torch.randn(size)
109111

110-
common_params = {
111-
"include_background": True,
112-
"to_onehot_y": False,
113-
"reduction": reduction,
114-
"weight": weight,
115-
}
112+
common_params = {"include_background": True, "to_onehot_y": False, "reduction": reduction, "weight": weight}
116113

117114
dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params)
118115
dice = DiceLoss(**common_params)
@@ -123,5 +120,6 @@ def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):
123120

124121
np.testing.assert_allclose(result, expected_val, err_msg=f"Failed on case: {name}")
125122

123+
126124
if __name__ == "__main__":
127125
unittest.main()

0 commit comments

Comments
 (0)