Skip to content

Commit 45c6cf6

Browse files
authored
2122 - fixes subprocess transforms (#2128)
* fixes subprocess transforms Signed-off-by: Wenqi Li <[email protected]> * update based on comments Signed-off-by: Wenqi Li <[email protected]> * add SelectItemsd alias Signed-off-by: Wenqi Li <[email protected]>
1 parent 2db93d8 commit 45c6cf6

File tree

10 files changed

+109
-69
lines changed

10 files changed

+109
-69
lines changed

monai/transforms/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@
354354
RepeatChannelD,
355355
RepeatChannelDict,
356356
SelectItemsd,
357+
SelectItemsD,
358+
SelectItemsDict,
357359
SimulateDelayd,
358360
SimulateDelayD,
359361
SimulateDelayDict,
@@ -395,6 +397,7 @@
395397
img_bounds,
396398
in_bounds,
397399
is_empty,
400+
is_positive,
398401
map_binary_to_indices,
399402
map_spatial_axes,
400403
rand_choice,

monai/transforms/croppad/array.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
compute_divisible_spatial_size,
2727
generate_pos_neg_label_crop_centers,
2828
generate_spatial_bounding_box,
29+
is_positive,
2930
map_binary_to_indices,
3031
weighted_patch_samples,
3132
)
@@ -400,7 +401,14 @@ class CropForeground(Transform):
400401
[0, 1, 3, 2, 0],
401402
[0, 1, 2, 1, 0],
402403
[0, 0, 0, 0, 0]]]) # 1x5x5, single channel 5x5 image
403-
cropper = CropForeground(select_fn=lambda x: x > 1, margin=0)
404+
405+
406+
def threshold_at_one(x):
407+
# threshold at 1
408+
return x > 1
409+
410+
411+
cropper = CropForeground(select_fn=threshold_at_one, margin=0)
404412
print(cropper(image))
405413
[[[2, 1],
406414
[3, 2],
@@ -410,7 +418,7 @@ class CropForeground(Transform):
410418

411419
def __init__(
412420
self,
413-
select_fn: Callable = lambda x: x > 0,
421+
select_fn: Callable = is_positive,
414422
channel_indices: Optional[IndexSelection] = None,
415423
margin: Union[Sequence[int], int] = 0,
416424
return_coords: bool = False,
@@ -725,7 +733,7 @@ class BoundingRect(Transform):
725733
select_fn: function to select expected foreground, default is to select values > 0.
726734
"""
727735

728-
def __init__(self, select_fn: Callable = lambda x: x > 0) -> None:
736+
def __init__(self, select_fn: Callable = is_positive) -> None:
729737
self.select_fn = select_fn
730738

731739
def __call__(self, img: np.ndarray) -> np.ndarray:

monai/transforms/croppad/dictionary.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,12 @@
3737
)
3838
from monai.transforms.inverse import InvertibleTransform
3939
from monai.transforms.transform import MapTransform, Randomizable
40-
from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices, weighted_patch_samples
40+
from monai.transforms.utils import (
41+
generate_pos_neg_label_crop_centers,
42+
is_positive,
43+
map_binary_to_indices,
44+
weighted_patch_samples,
45+
)
4146
from monai.utils import ImageMetaKey as Key
4247
from monai.utils import Method, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
4348
from monai.utils.enums import InverseKeys
@@ -572,7 +577,7 @@ def __init__(
572577
self,
573578
keys: KeysCollection,
574579
source_key: str,
575-
select_fn: Callable = lambda x: x > 0,
580+
select_fn: Callable = is_positive,
576581
channel_indices: Optional[IndexSelection] = None,
577582
margin: int = 0,
578583
k_divisible: Union[Sequence[int], int] = 1,
@@ -948,7 +953,7 @@ def __init__(
948953
self,
949954
keys: KeysCollection,
950955
bbox_key_postfix: str = "bbox",
951-
select_fn: Callable = lambda x: x > 0,
956+
select_fn: Callable = is_positive,
952957
allow_missing_keys: bool = False,
953958
):
954959
super().__init__(keys, allow_missing_keys)

monai/transforms/intensity/dictionary.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@
102102
"RandGaussianSharpenDict",
103103
"RandHistogramShiftD",
104104
"RandHistogramShiftDict",
105+
"RandRicianNoiseD",
106+
"RandRicianNoiseDict",
105107
]
106108

107109

monai/transforms/post/array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
"LabelToContour",
3434
"MeanEnsemble",
3535
"VoteEnsemble",
36+
"ProbNMS",
3637
]
3738

3839

@@ -74,7 +75,7 @@ def __call__(
7475
softmax: whether to execute softmax function on model output before transform.
7576
Defaults to ``self.softmax``.
7677
other: callable function to execute other activation layers, for example:
77-
`other = lambda x: torch.tanh(x)`. Defaults to ``self.other``.
78+
`other = torch.tanh`. Defaults to ``self.other``.
7879
7980
Raises:
8081
ValueError: When ``sigmoid=True`` and ``softmax=True``. Incompatible values.

monai/transforms/post/dictionary.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@
5757
"DecollateD",
5858
"DecollateDict",
5959
"Decollated",
60+
"ProbNMSd",
61+
"ProbNMSD",
62+
"ProbNMSDict",
6063
]
6164

6265

@@ -83,7 +86,7 @@ def __init__(
8386
softmax: whether to execute softmax function on model output before transform.
8487
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
8588
other: callable function to execute other activation layers,
86-
for example: `other = lambda x: torch.tanh(x)`. it also can be a sequence of Callable, each
89+
for example: `other = torch.tanh`. it also can be a sequence of Callable, each
8790
element corresponds to a key in ``keys``.
8891
allow_missing_keys: don't raise exception if key is missing.
8992

monai/transforms/utility/array.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
"CastToType",
4242
"ToTensor",
4343
"ToNumpy",
44+
"ToPIL",
4445
"Transpose",
4546
"SqueezeDim",
4647
"DataStats",

monai/transforms/utility/dictionary.py

Lines changed: 67 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -53,84 +53,90 @@
5353
from monai.utils import ensure_tuple, ensure_tuple_rep
5454

5555
__all__ = [
56-
"Identityd",
57-
"AsChannelFirstd",
58-
"AsChannelLastd",
56+
"AddChannelD",
57+
"AddChannelDict",
5958
"AddChanneld",
60-
"EnsureChannelFirstd",
61-
"RepeatChanneld",
62-
"RemoveRepeatedChanneld",
63-
"SplitChanneld",
64-
"CastToTyped",
65-
"ToTensord",
66-
"ToNumpyd",
67-
"ToPILd",
68-
"DeleteItemsd",
69-
"SelectItemsd",
70-
"SqueezeDimd",
71-
"DataStatsd",
72-
"SimulateDelayd",
73-
"CopyItemsd",
74-
"ConcatItemsd",
75-
"Lambdad",
76-
"RandLambdad",
77-
"LabelToMaskd",
78-
"FgBgToIndicesd",
79-
"ConvertToMultiChannelBasedOnBratsClassesd",
59+
"AddExtremePointsChannelD",
60+
"AddExtremePointsChannelDict",
8061
"AddExtremePointsChanneld",
81-
"TorchVisiond",
82-
"RandTorchVisiond",
83-
"MapLabelValued",
84-
"IdentityD",
85-
"IdentityDict",
8662
"AsChannelFirstD",
8763
"AsChannelFirstDict",
64+
"AsChannelFirstd",
8865
"AsChannelLastD",
8966
"AsChannelLastDict",
90-
"AddChannelD",
91-
"AddChannelDict",
67+
"AsChannelLastd",
68+
"CastToTypeD",
69+
"CastToTypeDict",
70+
"CastToTyped",
71+
"ConcatItemsD",
72+
"ConcatItemsDict",
73+
"ConcatItemsd",
74+
"ConvertToMultiChannelBasedOnBratsClassesD",
75+
"ConvertToMultiChannelBasedOnBratsClassesDict",
76+
"ConvertToMultiChannelBasedOnBratsClassesd",
77+
"CopyItemsD",
78+
"CopyItemsDict",
79+
"CopyItemsd",
80+
"DataStatsD",
81+
"DataStatsDict",
82+
"DataStatsd",
83+
"DeleteItemsD",
84+
"DeleteItemsDict",
85+
"DeleteItemsd",
9286
"EnsureChannelFirstD",
9387
"EnsureChannelFirstDict",
88+
"EnsureChannelFirstd",
89+
"FgBgToIndicesD",
90+
"FgBgToIndicesDict",
91+
"FgBgToIndicesd",
92+
"IdentityD",
93+
"IdentityDict",
94+
"Identityd",
95+
"LabelToMaskD",
96+
"LabelToMaskDict",
97+
"LabelToMaskd",
98+
"LambdaD",
99+
"LambdaDict",
100+
"Lambdad",
101+
"MapLabelValueD",
102+
"MapLabelValueDict",
103+
"MapLabelValued",
94104
"RandLambdaD",
95105
"RandLambdaDict",
96-
"RepeatChannelD",
97-
"RepeatChannelDict",
106+
"RandLambdad",
107+
"RandTorchVisionD",
108+
"RandTorchVisionDict",
109+
"RandTorchVisiond",
98110
"RemoveRepeatedChannelD",
99111
"RemoveRepeatedChannelDict",
112+
"RemoveRepeatedChanneld",
113+
"RepeatChannelD",
114+
"RepeatChannelDict",
115+
"RepeatChanneld",
116+
"SelectItemsD",
117+
"SelectItemsDict",
118+
"SelectItemsd",
119+
"SimulateDelayD",
120+
"SimulateDelayDict",
121+
"SimulateDelayd",
100122
"SplitChannelD",
101123
"SplitChannelDict",
102-
"CastToTypeD",
103-
"CastToTypeDict",
104-
"ToTensorD",
105-
"ToTensorDict",
106-
"DeleteItemsD",
107-
"DeleteItemsDict",
124+
"SplitChanneld",
108125
"SqueezeDimD",
109126
"SqueezeDimDict",
110-
"DataStatsD",
111-
"DataStatsDict",
112-
"SimulateDelayD",
113-
"SimulateDelayDict",
114-
"CopyItemsD",
115-
"CopyItemsDict",
116-
"ConcatItemsD",
117-
"ConcatItemsDict",
118-
"LambdaD",
119-
"LambdaDict",
120-
"LabelToMaskD",
121-
"LabelToMaskDict",
122-
"FgBgToIndicesD",
123-
"FgBgToIndicesDict",
124-
"ConvertToMultiChannelBasedOnBratsClassesD",
125-
"ConvertToMultiChannelBasedOnBratsClassesDict",
126-
"AddExtremePointsChannelD",
127-
"AddExtremePointsChannelDict",
127+
"SqueezeDimd",
128+
"ToNumpyD",
129+
"ToNumpyDict",
130+
"ToNumpyd",
131+
"ToPILD",
132+
"ToPILDict",
133+
"ToPILd",
134+
"ToTensorD",
135+
"ToTensorDict",
136+
"ToTensord",
128137
"TorchVisionD",
129138
"TorchVisionDict",
130-
"RandTorchVisionD",
131-
"RandTorchVisionDict",
132-
"MapLabelValueD",
133-
"MapLabelValueDict",
139+
"TorchVisiond",
134140
]
135141

136142

@@ -1062,6 +1068,7 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
10621068
ToNumpyD = ToNumpyDict = ToNumpyd
10631069
ToPILD = ToPILDict = ToPILd
10641070
DeleteItemsD = DeleteItemsDict = DeleteItemsd
1071+
SelectItemsD = SelectItemsDict = SelectItemsd
10651072
SqueezeDimD = SqueezeDimDict = SqueezeDimd
10661073
DataStatsD = DataStatsDict = DataStatsd
10671074
SimulateDelayD = SimulateDelayDict = SimulateDelayd

monai/transforms/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,13 @@
4242
"img_bounds",
4343
"in_bounds",
4444
"is_empty",
45+
"is_positive",
4546
"zero_margins",
4647
"rescale_array",
4748
"rescale_instance_array",
4849
"rescale_array_int_max",
4950
"copypaste_arrays",
51+
"compute_divisible_spatial_size",
5052
"resize_center",
5153
"map_binary_to_indices",
5254
"weighted_patch_samples",
@@ -97,6 +99,13 @@ def is_empty(img: Union[np.ndarray, torch.Tensor]) -> bool:
9799
return not (img.max() > img.min()) # use > instead of <= so that an image full of NaNs will result in True
98100

99101

102+
def is_positive(img):
103+
"""
104+
Returns a boolean version of `img` where the positive values are converted into True, the other values are False.
105+
"""
106+
return img > 0
107+
108+
100109
def zero_margins(img: np.ndarray, margin: int) -> bool:
101110
"""
102111
Returns True if the values within `margin` indices of the edges of `img` in dimensions 1 and 2 are 0.
@@ -526,7 +535,7 @@ def create_translate(spatial_dims: int, shift: Union[Sequence[float], float]) ->
526535

527536
def generate_spatial_bounding_box(
528537
img: np.ndarray,
529-
select_fn: Callable = lambda x: x > 0,
538+
select_fn: Callable = is_positive,
530539
channel_indices: Optional[IndexSelection] = None,
531540
margin: Union[Sequence[int], int] = 0,
532541
) -> Tuple[List[int], List[int]]:

monai/utils/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"Method",
3131
"InverseKeys",
3232
"CommonKeys",
33+
"ForwardMode",
3334
]
3435

3536

0 commit comments

Comments
 (0)