Skip to content

Commit ac97a0c

Browse files
committed
[DLMED] update dict transform
Signed-off-by: Nic Ma <[email protected]>
1 parent 0eb5895 commit ac97a0c

File tree

3 files changed

+18
-5
lines changed

3 files changed

+18
-5
lines changed

monai/transforms/intensity/dictionary.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,7 +1431,8 @@ class RandCoarseDropoutd(RandomizableTransform, MapTransform):
14311431
if some components of the `spatial_size` are non-positive values, the transform will use the
14321432
corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted
14331433
to `(32, 64)` if the second spatial dimension size of img is `64`.
1434-
fill_value: target value to fill the dropout regions.
1434+
fill_value: target value to fill the dropout regions, if providing a tuple for the `min` and `max`,
1435+
will randomly select value for every pixel / voxel from the range `[min, max)`.
14351436
max_holes: if not None, define the maximum number to randomly select the expected number of regions.
14361437
max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.
14371438
if some components of the `max_spatial_size` are non-positive values, the transform will use the
@@ -1447,7 +1448,7 @@ def __init__(
14471448
keys: KeysCollection,
14481449
holes: int,
14491450
spatial_size: Union[Sequence[int], int],
1450-
fill_value: Union[float, int] = 0,
1451+
fill_value: Union[Tuple[Union[float, int]], Union[float, int]] = 0,
14511452
max_holes: Optional[int] = None,
14521453
max_spatial_size: Optional[Union[Sequence[int], int]] = None,
14531454
prob: float = 0.1,
@@ -1483,7 +1484,12 @@ def __call__(self, data):
14831484
if self._do_transform:
14841485
for key in self.key_iterator(d):
14851486
for h in self.hole_coords:
1486-
d[key][h] = self.fill_value
1487+
if isinstance(self.fill_value, (tuple, list)):
1488+
if len(self.fill_value) != 2:
1489+
raise ValueError("fill_value should contain 2 numbers if providing the `min` and `max`.")
1490+
d[key][h] = self.R.uniform(self.fill_value[0], self.fill_value[1], size=d[key][h].shape)
1491+
else:
1492+
d[key][h] = self.fill_value
14871493
return d
14881494

14891495

tests/test_rand_coarse_dropout.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def test_value(self, input_param, input_data):
6565
if isinstance(fill_value, (int, float)):
6666
np.testing.assert_allclose(data, fill_value)
6767
else:
68-
print("hole data:", data)
6968
min_value = data.min()
7069
max_value = data.max()
7170
self.assertGreaterEqual(max_value, min_value)

tests/test_rand_coarse_dropoutd.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,15 @@ def test_value(self, input_param, input_data):
7070

7171
for h in dropout.hole_coords:
7272
data = result[h]
73-
np.testing.assert_allclose(data, input_param.get("fill_value", 0))
73+
fill_value = input_param.get("fill_value", 0)
74+
if isinstance(fill_value, (int, float)):
75+
np.testing.assert_allclose(data, fill_value)
76+
else:
77+
min_value = data.min()
78+
max_value = data.max()
79+
self.assertGreaterEqual(max_value, min_value)
80+
self.assertGreaterEqual(min_value, fill_value[0])
81+
self.assertLess(max_value, fill_value[1])
7482
if max_spatial_size is None:
7583
self.assertTupleEqual(data.shape[1:], tuple(spatial_size))
7684
else:

0 commit comments

Comments
 (0)