Skip to content

Commit 2ca9710

Browse files
authored
2746 Add round support to AsDiscrete transform (#2753)
* [DLMED] add round_values Signed-off-by: Nic Ma <[email protected]> * [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]> * [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]>
1 parent a7d4574 commit 2ca9710

File tree

4 files changed

+39
-6
lines changed

4 files changed

+39
-6
lines changed

monai/transforms/post/array.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from monai.networks.layers import GaussianFilter
2626
from monai.transforms.transform import Transform
2727
from monai.transforms.utils import fill_holes, get_largest_connected_component_mask
28-
from monai.utils import ensure_tuple
28+
from monai.utils import ensure_tuple, look_up_option
2929

3030
__all__ = [
3131
"Activations",
@@ -112,7 +112,8 @@ class AsDiscrete(Transform):
112112
113113
- execute `argmax` for input logits values.
114114
- threshold input value to 0.0 or 1.0.
115-
- convert input value to One-Hot format
115+
- convert input value to One-Hot format.
116+
- round the value to the closest integer.
116117
117118
Args:
118119
argmax: whether to execute argmax function on input data before transform.
@@ -125,6 +126,8 @@ class AsDiscrete(Transform):
125126
Defaults to ``False``.
126127
logit_thresh: the threshold value for thresholding operation..
127128
Defaults to ``0.5``.
129+
rounding: if not None, round the data according to the specified option,
130+
available options: ["torchrounding"].
128131
129132
"""
130133

@@ -135,12 +138,14 @@ def __init__(
135138
n_classes: Optional[int] = None,
136139
threshold_values: bool = False,
137140
logit_thresh: float = 0.5,
141+
rounding: Optional[str] = None,
138142
) -> None:
139143
self.argmax = argmax
140144
self.to_onehot = to_onehot
141145
self.n_classes = n_classes
142146
self.threshold_values = threshold_values
143147
self.logit_thresh = logit_thresh
148+
self.rounding = rounding
144149

145150
def __call__(
146151
self,
@@ -150,6 +155,7 @@ def __call__(
150155
n_classes: Optional[int] = None,
151156
threshold_values: Optional[bool] = None,
152157
logit_thresh: Optional[float] = None,
158+
rounding: Optional[str] = None,
153159
) -> torch.Tensor:
154160
"""
155161
Args:
@@ -165,6 +171,8 @@ def __call__(
165171
Defaults to ``self.threshold_values``.
166172
logit_thresh: the threshold value for thresholding operation..
167173
Defaults to ``self.logit_thresh``.
174+
rounding: if not None, round the data according to the specified option,
175+
available options: ["torchrounding"].
168176
169177
"""
170178
if argmax or self.argmax:
@@ -179,6 +187,11 @@ def __call__(
179187
if threshold_values or self.threshold_values:
180188
img = img >= (self.logit_thresh if logit_thresh is None else logit_thresh)
181189

190+
rounding = self.rounding if rounding is None else rounding
191+
if rounding is not None:
192+
rounding = look_up_option(rounding, ["torchrounding"])
193+
img = torch.round(img)
194+
182195
return img.float()
183196

184197

monai/transforms/post/dictionary.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def __init__(
134134
n_classes: Optional[Union[Sequence[int], int]] = None,
135135
threshold_values: Union[Sequence[bool], bool] = False,
136136
logit_thresh: Union[Sequence[float], float] = 0.5,
137+
rounding: Union[Sequence[Optional[str]], Optional[str]] = None,
137138
allow_missing_keys: bool = False,
138139
) -> None:
139140
"""
@@ -150,6 +151,9 @@ def __init__(
150151
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
151152
logit_thresh: the threshold value for thresholding operation, default is 0.5.
152153
it also can be a sequence of float, each element corresponds to a key in ``keys``.
154+
rounding: if not None, round the data according to the specified option,
155+
available options: ["torchrounding"]. it also can be a sequence of str or None,
156+
each element corresponds to a key in ``keys``.
153157
allow_missing_keys: don't raise exception if key is missing.
154158
155159
"""
@@ -159,12 +163,13 @@ def __init__(
159163
self.n_classes = ensure_tuple_rep(n_classes, len(self.keys))
160164
self.threshold_values = ensure_tuple_rep(threshold_values, len(self.keys))
161165
self.logit_thresh = ensure_tuple_rep(logit_thresh, len(self.keys))
166+
self.rounding = ensure_tuple_rep(rounding, len(self.keys))
162167
self.converter = AsDiscrete()
163168

164169
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
165170
d = dict(data)
166-
for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh in self.key_iterator(
167-
d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh
171+
for key, argmax, to_onehot, n_classes, threshold_values, logit_thresh, rounding in self.key_iterator(
172+
d, self.argmax, self.to_onehot, self.n_classes, self.threshold_values, self.logit_thresh, self.rounding
168173
):
169174
d[key] = self.converter(
170175
d[key],
@@ -173,6 +178,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
173178
n_classes,
174179
threshold_values,
175180
logit_thresh,
181+
rounding,
176182
)
177183
return d
178184

tests/test_as_discrete.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,16 @@
4444
(3,),
4545
]
4646

47+
TEST_CASE_5 = [
48+
{"rounding": "torchrounding"},
49+
torch.tensor([[[0.123, 1.345], [2.567, 3.789]]]),
50+
torch.tensor([[[0.0, 1.0], [3.0, 4.0]]]),
51+
(1, 2, 2),
52+
]
53+
4754

4855
class TestAsDiscrete(unittest.TestCase):
49-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
56+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
5057
def test_value_shape(self, input_param, img, out, expected_shape):
5158
result = AsDiscrete(**input_param)(img)
5259
torch.testing.assert_allclose(result, out)

tests/test_as_discreted.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,16 @@
5858
(2, 1, 2),
5959
]
6060

61+
TEST_CASE_4 = [
62+
{"keys": "pred", "rounding": "torchrounding"},
63+
{"pred": torch.tensor([[[0.123, 1.345], [2.567, 3.789]]])},
64+
{"pred": torch.tensor([[[0.0, 1.0], [3.0, 4.0]]])},
65+
(1, 2, 2),
66+
]
67+
6168

6269
class TestAsDiscreted(unittest.TestCase):
63-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
70+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
6471
def test_value_shape(self, input_param, test_input, output, expected_shape):
6572
result = AsDiscreted(**input_param)(test_input)
6673
torch.testing.assert_allclose(result["pred"], output["pred"])

0 commit comments

Comments
 (0)