Skip to content

Commit 2c7ec57

Browse files
authored
[DLMED] enhance CropForegroundd transform (#2808)
Signed-off-by: Nic Ma <[email protected]>
1 parent 7659869 commit 2c7ec57

File tree

3 files changed

+22
-11
lines changed

3 files changed

+22
-11
lines changed

monai/transforms/croppad/array.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,13 @@ def compute_bounding_box(self, img: np.ndarray):
691691
box_end_ = box_start_ + spatial_size
692692
return box_start_, box_end_
693693

694-
def crop_pad(self, img: np.ndarray, box_start: np.ndarray, box_end: np.ndarray):
694+
def crop_pad(
695+
self,
696+
img: np.ndarray,
697+
box_start: np.ndarray,
698+
box_end: np.ndarray,
699+
mode: Optional[Union[NumpyPadMode, str]] = None,
700+
):
695701
"""
696702
Crop and pad based on the bounding box.
697703
@@ -700,15 +706,15 @@ def crop_pad(self, img: np.ndarray, box_start: np.ndarray, box_end: np.ndarray):
700706
pad_to_start = np.maximum(-box_start, 0)
701707
pad_to_end = np.maximum(box_end - np.asarray(img.shape[1:]), 0)
702708
pad = list(chain(*zip(pad_to_start.tolist(), pad_to_end.tolist())))
703-
return BorderPad(spatial_border=pad, mode=self.mode, **self.np_kwargs)(cropped)
709+
return BorderPad(spatial_border=pad, mode=mode or self.mode, **self.np_kwargs)(cropped)
704710

705-
def __call__(self, img: np.ndarray):
711+
def __call__(self, img: np.ndarray, mode: Optional[Union[NumpyPadMode, str]] = None):
706712
"""
707713
Apply the transform to `img`, assuming `img` is channel-first and
708714
slicing doesn't change the channel dim.
709715
"""
710716
box_start, box_end = self.compute_bounding_box(img)
711-
cropped = self.crop_pad(img, box_start, box_end)
717+
cropped = self.crop_pad(img, box_start, box_end, mode)
712718

713719
if self.return_coords:
714720
return cropped, box_start, box_end

monai/transforms/croppad/dictionary.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ def __init__(
797797
channel_indices: Optional[IndexSelection] = None,
798798
margin: Union[Sequence[int], int] = 0,
799799
k_divisible: Union[Sequence[int], int] = 1,
800-
mode: Union[NumpyPadMode, str] = NumpyPadMode.CONSTANT,
800+
mode: NumpyPadModeSequence = NumpyPadMode.CONSTANT,
801801
start_coord_key: str = "foreground_start_coord",
802802
end_coord_key: str = "foreground_end_coord",
803803
allow_missing_keys: bool = False,
@@ -818,6 +818,7 @@ def __init__(
818818
``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
819819
one of the listed string values or a user supplied function. Defaults to ``"constant"``.
820820
see also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
821+
it also can be a sequence of string, each element corresponds to a key in ``keys``.
821822
start_coord_key: key to record the start coordinate of spatial bounding box for foreground.
822823
end_coord_key: key to record the end coordinate of spatial bounding box for foreground.
823824
allow_missing_keys: don't raise exception if key is missing.
@@ -834,18 +835,18 @@ def __init__(
834835
channel_indices=channel_indices,
835836
margin=margin,
836837
k_divisible=k_divisible,
837-
mode=mode,
838838
**np_kwargs,
839839
)
840+
self.mode = ensure_tuple_rep(mode, len(self.keys))
840841

841842
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
842843
d = dict(data)
843844
box_start, box_end = self.cropper.compute_bounding_box(img=d[self.source_key])
844845
d[self.start_coord_key] = box_start
845846
d[self.end_coord_key] = box_end
846-
for key in self.key_iterator(d):
847+
for key, m in self.key_iterator(d, self.mode):
847848
self.push_transform(d, key, extra_info={"box_start": box_start, "box_end": box_end})
848-
d[key] = self.cropper.crop_pad(d[key], box_start, box_end)
849+
d[key] = self.cropper.crop_pad(img=d[key], box_start=box_start, box_end=box_end, mode=m)
849850
return d
850851

851852
def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:

tests/test_crop_foregroundd.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from parameterized import parameterized
1616

1717
from monai.transforms import CropForegroundd
18+
from monai.utils import NumpyPadMode
1819

1920
TEST_CASE_1 = [
2021
{
@@ -59,15 +60,18 @@
5960

6061
TEST_CASE_6 = [
6162
{
62-
"keys": ["img"],
63+
"keys": ["img", "seg"],
6364
"source_key": "img",
6465
"select_fn": lambda x: x > 0,
6566
"channel_indices": 0,
6667
"margin": 0,
6768
"k_divisible": [4, 6],
68-
"mode": "edge",
69+
"mode": ["edge", NumpyPadMode.CONSTANT],
70+
},
71+
{
72+
"img": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]),
73+
"seg": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]]),
6974
},
70-
{"img": np.array([[[0, 2, 1, 2, 0], [1, 1, 2, 1, 1], [2, 2, 3, 2, 2], [1, 1, 2, 1, 1], [0, 0, 0, 0, 0]]])},
7175
np.array([[[0, 2, 1, 2, 0, 0], [1, 1, 2, 1, 1, 1], [2, 2, 3, 2, 2, 2], [1, 1, 2, 1, 1, 1]]]),
7276
]
7377

0 commit comments

Comments
 (0)