Skip to content

Commit c4833f5

Browse files
Nic-Ma and wyli authored
2648 add RandCoarseDropout transform (#2658)
* [DLMED] add RandCoarseDropout Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix typo Signed-off-by: Nic Ma <[email protected]> * [DLMED] add dict version transform Signed-off-by: Nic Ma <[email protected]> * [DLMED] updated according to comments Signed-off-by: Nic Ma <[email protected]> Co-authored-by: Wenqi Li <[email protected]>
1 parent aaeebd6 commit c4833f5

File tree

8 files changed

+342
-19
lines changed

8 files changed

+342
-19
lines changed

docs/source/transforms.rst

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,41 +274,47 @@ Intensity
274274
:special-members: __call__
275275

276276
`RandHistogramShift`
277-
"""""""""""""""""""""
277+
""""""""""""""""""""
278278
.. autoclass:: RandHistogramShift
279279
:members:
280280
:special-members: __call__
281281

282282
`DetectEnvelope`
283-
"""""""""""""""""""""
283+
""""""""""""""""
284284
.. autoclass:: DetectEnvelope
285285
:members:
286286
:special-members: __call__
287287

288288
`GibbsNoise`
289-
""""""""""""""
289+
""""""""""""
290290
.. autoclass:: GibbsNoise
291291
:members:
292292
:special-members: __call__
293293

294294
`RandGibbsNoise`
295-
"""""""""""""""""
295+
""""""""""""""""
296296
.. autoclass:: RandGibbsNoise
297297
:members:
298298
:special-members: __call__
299299

300300
`KSpaceSpikeNoise`
301-
""""""""""""""""""""
301+
""""""""""""""""""
302302
.. autoclass:: KSpaceSpikeNoise
303303
:members:
304304
:special-members: __call__
305305

306306
`RandKSpaceSpikeNoise`
307-
""""""""""""""""""""""""
307+
""""""""""""""""""""""
308308
.. autoclass:: RandKSpaceSpikeNoise
309309
:members:
310310
:special-members: __call__
311311

312+
`RandCoarseDropout`
313+
"""""""""""""""""""
314+
.. autoclass:: RandCoarseDropout
315+
:members:
316+
:special-members: __call__
317+
312318

313319
IO
314320
^^
@@ -889,6 +895,12 @@ Intensity (Dict)
889895
:members:
890896
:special-members: __call__
891897

898+
`RandCoarseDropoutd`
899+
""""""""""""""""""""
900+
.. autoclass:: RandCoarseDropoutd
901+
:members:
902+
:special-members: __call__
903+
892904
IO (Dict)
893905
^^^^^^^^^
894906

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
NormalizeIntensity,
8989
RandAdjustContrast,
9090
RandBiasField,
91+
RandCoarseDropout,
9192
RandGaussianNoise,
9293
RandGaussianSharpen,
9394
RandGaussianSmooth,
@@ -134,6 +135,9 @@
134135
RandBiasFieldd,
135136
RandBiasFieldD,
136137
RandBiasFieldDict,
138+
RandCoarseDropoutd,
139+
RandCoarseDropoutD,
140+
RandCoarseDropoutDict,
137141
RandGaussianNoised,
138142
RandGaussianNoiseD,
139143
RandGaussianNoiseDict,

monai/transforms/intensity/array.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222

2323
from monai.config import DtypeLike
24+
from monai.data.utils import get_random_patch, get_valid_patch_size
2425
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
2526
from monai.transforms.transform import RandomizableTransform, Transform
2627
from monai.transforms.utils import rescale_array
@@ -31,6 +32,7 @@
3132
ensure_tuple,
3233
ensure_tuple_rep,
3334
ensure_tuple_size,
35+
fall_back_tuple,
3436
)
3537

3638
__all__ = [
@@ -61,6 +63,7 @@
6163
"RandGibbsNoise",
6264
"KSpaceSpikeNoise",
6365
"RandKSpaceSpikeNoise",
66+
"RandCoarseDropout",
6467
]
6568

6669

@@ -1603,3 +1606,68 @@ def _to_numpy(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, t
16031606
return img.cpu().detach().numpy(), img.device
16041607
else:
16051608
return img, torch.device("cpu")
1609+
1610+
1611+
class RandCoarseDropout(RandomizableTransform):
    """
    Randomly coarse dropout regions in the image, then fill in the rectangular regions with specified value.
    Refer to: https://arxiv.org/abs/1708.04552 and:
    https://albumentations.ai/docs/api_reference/augmentations/transforms/
    #albumentations.augmentations.transforms.CoarseDropout.

    The first dimension of the input is kept whole (presumably the channel dim - confirm against callers),
    all the remaining dimensions are treated as spatial dimensions. The input array is modified in place
    and the same object is returned.

    Args:
        holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to
            randomly select the expected number of regions.
        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg
            as the minimum spatial size to randomly select size for every region.
            if some components of the `spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        fill_value: target value to fill the dropout regions.
        max_holes: if not None, define the maximum number to randomly select the expected number of regions.
        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.
            if some components of the `max_spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        prob: probability of applying the transform.

    Raises:
        ValueError: when `holes` is smaller than 1, or `max_holes` is smaller than `holes`.

    """

    def __init__(
        self,
        holes: int,
        spatial_size: Union[Sequence[int], int],
        fill_value: Union[float, int] = 0,
        max_holes: Optional[int] = None,
        max_spatial_size: Optional[Union[Sequence[int], int]] = None,
        prob: float = 0.1,
    ) -> None:
        RandomizableTransform.__init__(self, prob)
        if holes < 1:
            raise ValueError("number of holes must be greater than 0.")
        # fail at construction instead of on the first call, where the random range would be empty
        if max_holes is not None and max_holes < holes:
            raise ValueError("max_holes must not be smaller than holes.")
        self.holes = holes
        self.spatial_size = spatial_size
        self.fill_value = fill_value
        self.max_holes = max_holes
        self.max_spatial_size = max_spatial_size
        # slices of the regions selected by the most recent `randomize` call
        self.hole_coords: List = []

    def randomize(self, img_size: Sequence[int]) -> None:
        """
        Decide whether to apply the transform and select the regions to dropout.

        Args:
            img_size: spatial size of the image, without the leading (non-spatial) dimension.

        """
        super().randomize(None)
        # the lower / upper bounds depend only on the image size, so resolve them once
        min_size = fall_back_tuple(self.spatial_size, img_size)
        max_size = None if self.max_spatial_size is None else fall_back_tuple(self.max_spatial_size, img_size)
        self.hole_coords = []  # clear previously computed coords
        num_holes = self.holes if self.max_holes is None else self.R.randint(self.holes, self.max_holes + 1)
        for _ in range(num_holes):
            # always sample from [min_size, max_size]: the size drawn for a previous hole
            # must not become the lower bound of the following holes
            size = min_size
            if max_size is not None:
                size = tuple(self.R.randint(low=lo, high=hi + 1) for lo, hi in zip(min_size, max_size))
            valid_size = get_valid_patch_size(img_size, size)
            # `slice(None)` keeps the whole first dimension, the hole applies to every entry of it
            self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R))

    def __call__(self, img: np.ndarray):
        """
        Fill the randomly selected regions of `img` with `fill_value`.

        Args:
            img: array whose dimensions after the first one are spatial, it is modified in place.

        Returns:
            the same `img` object, with the selected regions overwritten when the transform is applied.

        """
        self.randomize(img.shape[1:])
        if self._do_transform:
            for h in self.hole_coords:
                img[h] = self.fill_value

        return img

monai/transforms/intensity/dictionary.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch
2323

2424
from monai.config import DtypeLike, KeysCollection
25+
from monai.data.utils import get_random_patch, get_valid_patch_size
2526
from monai.transforms.intensity.array import (
2627
AdjustContrast,
2728
GaussianSharpen,
@@ -41,7 +42,7 @@
4142
ThresholdIntensity,
4243
)
4344
from monai.transforms.transform import MapTransform, RandomizableTransform
44-
from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size
45+
from monai.utils import dtype_torch_to_numpy, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple
4546

4647
__all__ = [
4748
"RandGaussianNoised",
@@ -69,6 +70,7 @@
6970
"KSpaceSpikeNoised",
7071
"RandKSpaceSpikeNoised",
7172
"RandHistogramShiftd",
73+
"RandCoarseDropoutd",
7274
"RandGaussianNoiseD",
7375
"RandGaussianNoiseDict",
7476
"ShiftIntensityD",
@@ -117,13 +119,16 @@
117119
"RandHistogramShiftDict",
118120
"RandRicianNoiseD",
119121
"RandRicianNoiseDict",
122+
"RandCoarseDropoutD",
123+
"RandCoarseDropoutDict",
120124
]
121125

122126

123127
class RandGaussianNoised(RandomizableTransform, MapTransform):
124128
"""
125129
Dictionary-based version :py:class:`monai.transforms.RandGaussianNoise`.
126-
Add Gaussian noise to image. This transform assumes all the expected fields have same shape.
130+
Add Gaussian noise to image. This transform assumes all the expected fields have same shape, if want to add
131+
different noise for every field, please use this transform separately.
127132
128133
Args:
129134
keys: keys of the corresponding items to be transformed.
@@ -172,7 +177,8 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
172177
class RandRicianNoised(RandomizableTransform, MapTransform):
173178
"""
174179
Dictionary-based version :py:class:`monai.transforms.RandRicianNoise`.
175-
Add Rician noise to image. This transform assumes all the expected fields have same shape.
180+
Add Rician noise to image. This transform assumes all the expected fields have same shape, if want to add
181+
different noise for every field, please use this transform separately.
176182
177183
Args:
178184
keys: Keys of the corresponding items to be transformed.
@@ -1324,6 +1330,78 @@ def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
13241330
return d_numpy
13251331

13261332

1333+
class RandCoarseDropoutd(RandomizableTransform, MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseDropout`.
    Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions
    for every key, if want to dropout differently for every key, please use this transform separately.

    The input dictionary is shallow-copied, the arrays stored under `keys` are modified in place.

    Args:
        keys: keys of the corresponding items to be transformed.
            See also: :py:class:`monai.transforms.compose.MapTransform`
        holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to
            randomly select the expected number of regions.
        spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg
            as the minimum spatial size to randomly select size for every region.
            if some components of the `spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        fill_value: target value to fill the dropout regions.
        max_holes: if not None, define the maximum number to randomly select the expected number of regions.
        max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.
            if some components of the `max_spatial_size` are non-positive values, the transform will use the
            corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        prob: probability of applying the transform.
        allow_missing_keys: don't raise exception if key is missing.

    Raises:
        ValueError: when `holes` is smaller than 1, or `max_holes` is smaller than `holes`.

    """

    def __init__(
        self,
        keys: KeysCollection,
        holes: int,
        spatial_size: Union[Sequence[int], int],
        fill_value: Union[float, int] = 0,
        max_holes: Optional[int] = None,
        max_spatial_size: Optional[Union[Sequence[int], int]] = None,
        prob: float = 0.1,
        allow_missing_keys: bool = False,
    ):
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        if holes < 1:
            raise ValueError("number of holes must be greater than 0.")
        # fail at construction instead of on the first call, where the random range would be empty
        if max_holes is not None and max_holes < holes:
            raise ValueError("max_holes must not be smaller than holes.")
        self.holes = holes
        self.spatial_size = spatial_size
        self.fill_value = fill_value
        self.max_holes = max_holes
        self.max_spatial_size = max_spatial_size
        # slices of the regions selected by the most recent `randomize` call
        self.hole_coords: List = []

    def randomize(self, img_size: Sequence[int]) -> None:
        """
        Decide whether to apply the transform and select the regions to dropout.

        Args:
            img_size: spatial size shared by the items, without the leading (non-spatial) dimension.

        """
        super().randomize(None)
        # the lower / upper bounds depend only on the image size, so resolve them once
        min_size = fall_back_tuple(self.spatial_size, img_size)
        max_size = None if self.max_spatial_size is None else fall_back_tuple(self.max_spatial_size, img_size)
        self.hole_coords = []  # clear previously computed coords
        num_holes = self.holes if self.max_holes is None else self.R.randint(self.holes, self.max_holes + 1)
        for _ in range(num_holes):
            # always sample from [min_size, max_size]: the size drawn for a previous hole
            # must not become the lower bound of the following holes
            size = min_size
            if max_size is not None:
                size = tuple(self.R.randint(low=lo, high=hi + 1) for lo, hi in zip(min_size, max_size))
            valid_size = get_valid_patch_size(img_size, size)
            # `slice(None)` keeps the whole first dimension, the hole applies to every entry of it
            self.hole_coords.append((slice(None),) + get_random_patch(img_size, valid_size, self.R))

    def __call__(self, data):
        """
        Fill the same randomly selected regions of every item specified by `keys` with `fill_value`.

        Args:
            data: mapping which contains the items to transform.

        Returns:
            a shallow copy of `data`, the arrays under `keys` are modified in place when the transform is applied.

        """
        d = dict(data)
        # take the reference shape from the first key that `key_iterator` yields rather than `self.keys[0]`,
        # so a missing first key does not fail here when `allow_missing_keys=True`
        # NOTE(review): assumes `key_iterator` skips (or raises for) missing keys - confirm in MapTransform
        first_key = next(iter(self.key_iterator(d)), None)
        if first_key is None:
            return d  # none of the expected keys is available, nothing to dropout
        # expect all the specified keys have same spatial shape
        self.randomize(d[first_key].shape[1:])
        if self._do_transform:
            for key in self.key_iterator(d):
                for h in self.hole_coords:
                    d[key][h] = self.fill_value
        return d
1403+
1404+
13271405
RandGaussianNoiseD = RandGaussianNoiseDict = RandGaussianNoised
13281406
RandRicianNoiseD = RandRicianNoiseDict = RandRicianNoised
13291407
ShiftIntensityD = ShiftIntensityDict = ShiftIntensityd
@@ -1349,3 +1427,4 @@ def _to_numpy(self, d: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
13491427
GibbsNoiseD = GibbsNoiseDict = GibbsNoised
13501428
KSpaceSpikeNoiseD = KSpaceSpikeNoiseDict = KSpaceSpikeNoised
13511429
RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised
1430+
RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd

monai/transforms/spatial/array.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ class Resize(Transform):
337337
338338
Args:
339339
spatial_size: expected shape of spatial dimensions after resize operation.
340-
if the components of the `spatial_size` are non-positive values, the transform will use the
340+
if some components of the `spatial_size` are non-positive values, the transform will use the
341341
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
342342
to `(32, 64)` if the second spatial dimension size of img is `64`.
343343
mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
@@ -1297,7 +1297,7 @@ def __init__(
12971297
spatial_size: output image spatial size.
12981298
if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
12991299
the transform will use the spatial size of `img`.
1300-
if the components of the `spatial_size` are non-positive values, the transform will use the
1300+
if some components of the `spatial_size` are non-positive values, the transform will use the
13011301
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
13021302
to `(32, 64)` if the second spatial dimension size of img is `64`.
13031303
mode: {``"bilinear"``, ``"nearest"``}
@@ -1390,7 +1390,7 @@ def __init__(
13901390
spatial_size: output image spatial size.
13911391
if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
13921392
the transform will use the spatial size of `img`.
1393-
if the components of the `spatial_size` are non-positive values, the transform will use the
1393+
if some components of the `spatial_size` are non-positive values, the transform will use the
13941394
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
13951395
to `(32, 64)` if the second spatial dimension size of img is `64`.
13961396
mode: {``"bilinear"``, ``"nearest"``}
@@ -1553,7 +1553,7 @@ def __init__(
15531553
spatial_size: specifying output image spatial size [h, w].
15541554
if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
15551555
the transform will use the spatial size of `img`.
1556-
if the components of the `spatial_size` are non-positive values, the transform will use the
1556+
if some components of the `spatial_size` are non-positive values, the transform will use the
15571557
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
15581558
to `(32, 64)` if the second spatial dimension size of img is `64`.
15591559
mode: {``"bilinear"``, ``"nearest"``}
@@ -1681,7 +1681,7 @@ def __init__(
16811681
spatial_size: specifying output image spatial size [h, w, d].
16821682
if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
16831683
the transform will use the spatial size of `img`.
1684-
if the components of the `spatial_size` are non-positive values, the transform will use the
1684+
if some components of the `spatial_size` are non-positive values, the transform will use the
16851685
corresponding components of img size. For example, `spatial_size=(32, 32, -1)` will be adapted
16861686
to `(32, 32, 64)` if the third spatial dimension size of img is `64`.
16871687
mode: {``"bilinear"``, ``"nearest"``}

0 commit comments

Comments
 (0)