Skip to content

Commit 410109a

Browse files
Can-ZhaoKumoLiupre-commit-ci[bot]
authored
Maisi morphological funcs (#7893)
Fixes # . ### Description Maisi morphological funcs ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Can-Zhao <[email protected]> Signed-off-by: Can Zhao <[email protected]> Co-authored-by: YunLiu <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 15d0771 commit 410109a

File tree

4 files changed

+290
-0
lines changed

4 files changed

+290
-0
lines changed

docs/source/apps.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,11 @@ FastMRIReader
261261

262262
.. autoclass:: monai.apps.nnunet.nnUNetV2Runner
263263
:members:
264+
265+
`Generative AI`
266+
---------------
267+
268+
`MAISI Utilities`
269+
~~~~~~~~~~~~~~~~~
270+
.. automodule:: monai.apps.generation.maisi.utils.morphological_ops
271+
:members:
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
from typing import Sequence
15+
16+
import torch
17+
import torch.nn.functional as F
18+
from torch import Tensor
19+
20+
from monai.config import NdarrayOrTensor
21+
from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep
22+
23+
24+
def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:
25+
"""
26+
Erode 2D/3D binary mask.
27+
28+
Args:
29+
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
30+
filter_size: erosion filter size, has to be odd numbers, default to be 3.
31+
pad_value: the filled value for padding. We need to pad the input before filtering
32+
to keep the output with the same size as input. Usually use default value
33+
and not changed.
34+
35+
Return:
36+
eroded mask, same shape and data type as input.
37+
38+
Example:
39+
40+
.. code-block:: python
41+
42+
# define a naive mask
43+
mask = torch.zeros(3,2,3,3,3)
44+
mask[:,:,1,1,1] = 1.0
45+
filter_size = 3
46+
erode_result = erode(mask, filter_size) # expect torch.zeros(3,2,3,3,3)
47+
dilate_result = dilate(mask, filter_size) # expect torch.ones(3,2,3,3,3)
48+
"""
49+
mask_t, *_ = convert_data_type(mask, torch.Tensor)
50+
res_mask_t = erode_t(mask_t, filter_size=filter_size, pad_value=pad_value)
51+
res_mask: NdarrayOrTensor
52+
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
53+
return res_mask
54+
55+
56+
def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor:
57+
"""
58+
Dilate 2D/3D binary mask.
59+
60+
Args:
61+
mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
62+
filter_size: dilation filter size, has to be odd numbers, default to be 3.
63+
pad_value: the filled value for padding. We need to pad the input before filtering
64+
to keep the output with the same size as input. Usually use default value
65+
and not changed.
66+
67+
Return:
68+
dilated mask, same shape and data type as input.
69+
70+
Example:
71+
72+
.. code-block:: python
73+
74+
# define a naive mask
75+
mask = torch.zeros(3,2,3,3,3)
76+
mask[:,:,1,1,1] = 1.0
77+
filter_size = 3
78+
erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3)
79+
dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3)
80+
"""
81+
mask_t, *_ = convert_data_type(mask, torch.Tensor)
82+
res_mask_t = dilate_t(mask_t, filter_size=filter_size, pad_value=pad_value)
83+
res_mask: NdarrayOrTensor
84+
res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
85+
return res_mask
86+
87+
88+
def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor:
89+
"""
90+
Apply a morphological filter to a 2D/3D binary mask tensor.
91+
92+
Args:
93+
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
94+
filter_size: morphological filter size, has to be odd numbers.
95+
pad_value: the filled value for padding. We need to pad the input before filtering
96+
to keep the output with the same size as input.
97+
98+
Return:
99+
Tensor: Morphological filter result mask, same shape as input.
100+
"""
101+
spatial_dims = len(mask_t.shape) - 2
102+
if spatial_dims not in [2, 3]:
103+
raise ValueError(
104+
f"spatial_dims must be either 2 or 3, "
105+
f"got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}."
106+
)
107+
108+
# Define the structuring element
109+
filter_size = ensure_tuple_rep(filter_size, spatial_dims)
110+
if any(size % 2 == 0 for size in filter_size):
111+
raise ValueError(f"All dimensions in filter_size must be odd numbers, got {filter_size}.")
112+
113+
structuring_element = torch.ones((mask_t.shape[1], mask_t.shape[1]) + filter_size).to(mask_t.device)
114+
115+
# Pad the input tensor to handle border pixels
116+
# Calculate padding size
117+
pad_size = [size // 2 for size in filter_size for _ in range(2)]
118+
119+
input_padded = F.pad(mask_t.float(), pad_size, mode="constant", value=pad_value)
120+
121+
# Apply filter operation
122+
conv_fn = F.conv2d if spatial_dims == 2 else F.conv3d
123+
output = conv_fn(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...])
124+
125+
return output
126+
127+
128+
def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor:
129+
"""
130+
Erode 2D/3D binary mask with data type as torch tensor.
131+
132+
Args:
133+
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
134+
filter_size: erosion filter size, has to be odd numbers, default to be 3.
135+
pad_value: the filled value for padding. We need to pad the input before filtering
136+
to keep the output with the same size as input. Usually use default value
137+
and not changed.
138+
139+
Return:
140+
Tensor: eroded mask, same shape as input.
141+
"""
142+
143+
output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)
144+
145+
# Set output values based on the minimum value within the structuring element
146+
output = torch.where(torch.abs(output - 1.0) < 1e-7, 1.0, 0.0)
147+
148+
return output
149+
150+
151+
def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor:
152+
"""
153+
Dilate 2D/3D binary mask with data type as torch tensor.
154+
155+
Args:
156+
mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
157+
filter_size: dilation filter size, has to be odd numbers, default to be 3.
158+
pad_value: the filled value for padding. We need to pad the input before filtering
159+
to keep the output with the same size as input. Usually use default value
160+
and not changed.
161+
162+
Return:
163+
Tensor: dilated mask, same shape as input.
164+
"""
165+
output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)
166+
167+
# Set output values based on the minimum value within the structuring element
168+
output = torch.where(output > 0, 1.0, 0.0)
169+
170+
return output

tests/test_morphological_ops.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import torch
17+
from parameterized import parameterized
18+
19+
from monai.apps.generation.maisi.utils.morphological_ops import dilate, erode, get_morphological_filter_result_t
20+
from tests.utils import TEST_NDARRAYS, assert_allclose
21+
22+
TESTS_SHAPE = []
23+
for p in TEST_NDARRAYS:
24+
mask = torch.zeros(1, 1, 5, 5, 5)
25+
filter_size = 3
26+
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 5, 5, 5]])
27+
mask = torch.zeros(3, 2, 5, 5, 5)
28+
filter_size = 5
29+
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [3, 2, 5, 5, 5]])
30+
mask = torch.zeros(1, 1, 1, 1, 1)
31+
filter_size = 5
32+
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1, 1]])
33+
mask = torch.zeros(1, 1, 1, 1)
34+
filter_size = 5
35+
TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1]])
36+
37+
TESTS_VALUE_T = []
38+
filter_size = 3
39+
mask = torch.ones(3, 2, 3, 3, 3)
40+
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3, 3)])
41+
mask = torch.zeros(3, 2, 3, 3, 3)
42+
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3, 3)])
43+
mask = torch.ones(3, 2, 3, 3)
44+
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3)])
45+
mask = torch.zeros(3, 2, 3, 3)
46+
TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3)])
47+
48+
TESTS_VALUE = []
49+
for p in TEST_NDARRAYS:
50+
mask = torch.zeros(3, 2, 5, 5, 5)
51+
filter_size = 3
52+
TESTS_VALUE.append(
53+
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 5, 5, 5)), p(torch.zeros(3, 2, 5, 5, 5))]
54+
)
55+
mask = torch.ones(1, 1, 3, 3, 3)
56+
filter_size = 3
57+
TESTS_VALUE.append(
58+
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 1, 3, 3, 3)), p(torch.ones(1, 1, 3, 3, 3))]
59+
)
60+
mask = torch.ones(1, 2, 3, 3, 3)
61+
filter_size = 3
62+
TESTS_VALUE.append(
63+
[{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 2, 3, 3, 3)), p(torch.ones(1, 2, 3, 3, 3))]
64+
)
65+
mask = torch.zeros(3, 2, 3, 3, 3)
66+
mask[:, :, 1, 1, 1] = 1.0
67+
filter_size = 3
68+
TESTS_VALUE.append(
69+
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))]
70+
)
71+
mask = torch.zeros(3, 2, 3, 3)
72+
mask[:, :, 1, 1] = 1.0
73+
filter_size = 3
74+
TESTS_VALUE.append(
75+
[{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3)), p(torch.ones(3, 2, 3, 3))]
76+
)
77+
78+
79+
class TestMorph(unittest.TestCase):
80+
81+
@parameterized.expand(TESTS_SHAPE)
82+
def test_shape(self, input_data, expected_result):
83+
result1 = erode(input_data["mask"], input_data["filter_size"])
84+
assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0)
85+
86+
@parameterized.expand(TESTS_VALUE_T)
87+
def test_value_t(self, input_data, expected_result):
88+
result1 = get_morphological_filter_result_t(
89+
input_data["mask"], input_data["filter_size"], input_data["pad_value"]
90+
)
91+
assert_allclose(result1, expected_result, type_test=False, device_test=False, atol=0.0)
92+
93+
@parameterized.expand(TESTS_VALUE)
94+
def test_value(self, input_data, expected_erode_result, expected_dilate_result):
95+
result1 = erode(input_data["mask"], input_data["filter_size"])
96+
assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0)
97+
result2 = dilate(input_data["mask"], input_data["filter_size"])
98+
assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0)
99+
100+
101+
if __name__ == "__main__":
102+
unittest.main()

0 commit comments

Comments
 (0)