Skip to content

Commit 945e21c

Browse files
SpenhouetSebastian Penhouet
andauthored
[2678] Add transform to fill holes and to filter (#2692)
* Add transform to fill holes and to filter (#2678) Signed-off-by: Sebastian Penhouet <[email protected]> * Change name of label filter class (#2678) Signed-off-by: Sebastian Penhouet <[email protected]> * Change fill holes to growing logic (#2678) Signed-off-by: Sebastian Penhouet <[email protected]> * Fix missing entry in min_tests (#2678) Signed-off-by: Sebastian Penhouet <[email protected]> * Fix review comments (#2678) Signed-off-by: Sebastian Penhouet <[email protected]> * Remove batch dim and add one-hot handling (#2678) Signed-off-by: Sebastian Penhouet <[email protected]> * [MONAI] python code formatting Signed-off-by: monai-bot <[email protected]> Co-authored-by: Sebastian Penhouet <[email protected]>
1 parent 62425d7 commit 945e21c

File tree

10 files changed

+864
-98
lines changed

10 files changed

+864
-98
lines changed

docs/source/transforms.rst

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,18 @@ Post-processing
356356
:members:
357357
:special-members: __call__
358358

359+
`LabelFilter`
360+
"""""""""""""
361+
.. autoclass:: LabelFilter
362+
:members:
363+
:special-members: __call__
364+
365+
`FillHoles`
366+
"""""""""""
367+
.. autoclass:: FillHoles
368+
:members:
369+
:special-members: __call__
370+
359371
`LabelToContour`
360372
""""""""""""""""
361373
.. autoclass:: LabelToContour
@@ -955,6 +967,18 @@ Post-processing (Dict)
955967
:members:
956968
:special-members: __call__
957969

970+
`LabelFilterd`
971+
""""""""""""""
972+
.. autoclass:: LabelFilterd
973+
:members:
974+
:special-members: __call__
975+
976+
`FillHolesd`
977+
""""""""""""
978+
.. autoclass:: FillHolesd
979+
:members:
980+
:special-members: __call__
981+
958982
`LabelToContourd`
959983
"""""""""""""""""
960984
.. autoclass:: LabelToContourd

monai/transforms/__init__.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -194,40 +194,48 @@
194194
from .post.array import (
195195
Activations,
196196
AsDiscrete,
197+
FillHoles,
197198
KeepLargestConnectedComponent,
199+
LabelFilter,
198200
LabelToContour,
199201
MeanEnsemble,
200202
ProbNMS,
201203
VoteEnsemble,
202204
)
203205
from .post.dictionary import (
204-
Activationsd,
205206
ActivationsD,
207+
Activationsd,
206208
ActivationsDict,
207-
AsDiscreted,
208209
AsDiscreteD,
210+
AsDiscreted,
209211
AsDiscreteDict,
210212
Ensembled,
211-
Invertd,
213+
FillHolesD,
214+
FillHolesd,
215+
FillHolesDict,
212216
InvertD,
217+
Invertd,
213218
InvertDict,
214-
KeepLargestConnectedComponentd,
215219
KeepLargestConnectedComponentD,
220+
KeepLargestConnectedComponentd,
216221
KeepLargestConnectedComponentDict,
217-
LabelToContourd,
222+
LabelFilterD,
223+
LabelFilterd,
224+
LabelFilterDict,
218225
LabelToContourD,
226+
LabelToContourd,
219227
LabelToContourDict,
220-
MeanEnsembled,
221228
MeanEnsembleD,
229+
MeanEnsembled,
222230
MeanEnsembleDict,
223-
ProbNMSd,
224231
ProbNMSD,
232+
ProbNMSd,
225233
ProbNMSDict,
226-
SaveClassificationd,
227234
SaveClassificationD,
235+
SaveClassificationd,
228236
SaveClassificationDict,
229-
VoteEnsembled,
230237
VoteEnsembleD,
238+
VoteEnsembled,
231239
VoteEnsembleDict,
232240
)
233241
from .spatial.array import (

monai/transforms/post/array.py

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,29 @@
1414
"""
1515

1616
import warnings
17-
from typing import Callable, Optional, Sequence, Union
17+
from typing import Callable, Iterable, Optional, Sequence, Union
1818

1919
import numpy as np
2020
import torch
2121
import torch.nn.functional as F
2222

23+
from monai.config import NdarrayTensor
2324
from monai.networks import one_hot
2425
from monai.networks.layers import GaussianFilter
2526
from monai.transforms.transform import Transform
26-
from monai.transforms.utils import get_largest_connected_component_mask
27+
from monai.transforms.utils import fill_holes, get_largest_connected_component_mask
2728
from monai.utils import ensure_tuple
2829

2930
__all__ = [
3031
"Activations",
3132
"AsDiscrete",
33+
"FillHoles",
3234
"KeepLargestConnectedComponent",
35+
"LabelFilter",
3336
"LabelToContour",
3437
"MeanEnsemble",
35-
"VoteEnsemble",
3638
"ProbNMS",
39+
"VoteEnsemble",
3740
]
3841

3942

@@ -289,6 +292,137 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
289292
return output
290293

291294

295+
class LabelFilter:
296+
"""
297+
This transform filters out labels and can be used as a processing step to view only certain labels.
298+
299+
The list of applied labels defines which labels will be kept.
300+
301+
Note:
302+
All labels which do not match the `applied_labels` are set to the background label (0).
303+
304+
For example:
305+
306+
Use LabelFilter with applied_labels=[1, 5, 9]::
307+
308+
[1, 2, 3] [1, 0, 0]
309+
[4, 5, 6] => [0, 5 ,0]
310+
[7, 8, 9] [0, 0, 9]
311+
"""
312+
313+
def __init__(self, applied_labels: Union[Iterable[int], int]) -> None:
314+
"""
315+
Initialize the LabelFilter class with the labels to filter on.
316+
317+
Args:
318+
applied_labels: Label(s) to filter on.
319+
"""
320+
self.applied_labels = ensure_tuple(applied_labels)
321+
322+
def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
323+
"""
324+
Filter the image on the `applied_labels`.
325+
326+
Args:
327+
img: Pytorch tensor or numpy array of any shape.
328+
329+
Raises:
330+
NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.
331+
332+
Returns:
333+
Pytorch tensor or numpy array of the same shape as the input.
334+
"""
335+
if isinstance(img, np.ndarray):
336+
return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0))
337+
elif isinstance(img, torch.Tensor):
338+
img_arr = img.detach().cpu().numpy()
339+
img_arr = self(img_arr)
340+
return torch.as_tensor(img_arr, device=img.device)
341+
else:
342+
raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.")
343+
344+
345+
class FillHoles(Transform):
346+
r"""
347+
This transform fills holes in the image and can be used to remove artifacts inside segments.
348+
349+
An enclosed hole is defined as a background pixel/voxel which is only enclosed by a single class.
350+
The definition of enclosed can be defined with the connectivity parameter::
351+
352+
1-connectivity 2-connectivity diagonal connection close-up
353+
354+
[ ] [ ] [ ] [ ] [ ]
355+
| \ | / | <- hop 2
356+
[ ]--[x]--[ ] [ ]--[x]--[ ] [x]--[ ]
357+
| / | \ hop 1
358+
[ ] [ ] [ ] [ ]
359+
360+
It is possible to define for which labels the hole filling should be applied.
361+
The input image is assumed to be a PyTorch Tensor or numpy array with shape [C, spatial_dim1[, spatial_dim2, ...]].
362+
If C = 1, then the values correspond to expected labels.
363+
If C > 1, then a one-hot-encoding is expected where the index of C matches the label indexing.
364+
365+
Note:
366+
367+
The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label.
368+
369+
The performance of this method heavily depends on the number of labels.
370+
It is a bit faster if the list of `applied_labels` is provided.
371+
Limiting the number of `applied_labels` results in a big decrease in processing time.
372+
373+
For example:
374+
375+
Use FillHoles with default parameters::
376+
377+
[1, 1, 1, 2, 2, 2, 3, 3] [1, 1, 1, 2, 2, 2, 3, 3]
378+
[1, 0, 1, 2, 0, 0, 3, 0] => [1, 1 ,1, 2, 0, 0, 3, 0]
379+
[1, 1, 1, 2, 2, 2, 3, 3] [1, 1, 1, 2, 2, 2, 3, 3]
380+
381+
The hole in label 1 is fully enclosed and therefore filled with label 1.
382+
The background label near label 2 and 3 is not fully enclosed and therefore not filled.
383+
"""
384+
385+
def __init__(
386+
self, applied_labels: Optional[Union[Iterable[int], int]] = None, connectivity: Optional[int] = None
387+
) -> None:
388+
"""
389+
Initialize the connectivity and limit the labels for which holes are filled.
390+
391+
Args:
392+
applied_labels: Labels for which to fill holes. Defaults to None, that is filling holes for all labels.
393+
connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
394+
Accepted values are ranging from 1 to input.ndim. Defaults to a full connectivity of ``input.ndim``.
395+
"""
396+
super().__init__()
397+
self.applied_labels = ensure_tuple(applied_labels) if applied_labels else None
398+
self.connectivity = connectivity
399+
400+
def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
401+
"""
402+
Fill the holes in the provided image.
403+
404+
Note:
405+
The value 0 is assumed as background label.
406+
407+
Args:
408+
img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
409+
410+
Raises:
411+
NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.
412+
413+
Returns:
414+
Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
415+
"""
416+
if isinstance(img, np.ndarray):
417+
return fill_holes(img, self.applied_labels, self.connectivity)
418+
elif isinstance(img, torch.Tensor):
419+
img_arr = img.detach().cpu().numpy()
420+
img_arr = self(img_arr)
421+
return torch.as_tensor(img_arr, device=img.device)
422+
else:
423+
raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.")
424+
425+
292426
class LabelToContour(Transform):
293427
"""
294428
Return the contour of binary input images that only compose of 0 and 1, with Laplace kernel

0 commit comments

Comments
 (0)