Skip to content

2648 Add LongestRescale transform #2662

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 28 additions & 13 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""

import warnings
from math import ceil
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -340,6 +341,11 @@ class Resize(Transform):
if some components of the `spatial_size` are non-positive values, the transform will use the
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
size_mode: should be "all" or "longest". If "all", `spatial_size` is used for all the spatial dims.
If "longest", the image is rescaled so that only the longest side equals the specified `spatial_size`
(which must be an int in this case), keeping the aspect ratio of the initial image. Refer to:
https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/
#albumentations.augmentations.geometric.resize.LongestMaxSize.
mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
The interpolation mode. Defaults to ``"area"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate
Expand All @@ -351,10 +357,12 @@ class Resize(Transform):
def __init__(
    self,
    spatial_size: Union[Sequence[int], int],
    size_mode: str = "all",
    mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
    align_corners: Optional[bool] = None,
) -> None:
    """
    Store the resize configuration.

    Args:
        spatial_size: expected spatial size after resizing; a sequence of ints or a single int.
        size_mode: "all" or "longest"; checked with `look_up_option` against these two options.
        mode: interpolation mode; checked with `look_up_option` against `InterpolateMode`.
        align_corners: stored as given and forwarded to the interpolation call in `__call__`.
    """
    self.size_mode = look_up_option(size_mode, ["all", "longest"])
    # Keep the value exactly as passed (assigned once, no tuple conversion here):
    # `__call__` wraps it with `ensure_tuple` in "all" mode and requires a plain int
    # in "longest" mode, so converting at construction time would break the latter.
    self.spatial_size = spatial_size
    self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
    self.align_corners = align_corners

Expand All @@ -378,20 +386,27 @@ def __call__(
ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions.

"""
input_ndim = img.ndim - 1 # spatial ndim
output_ndim = len(self.spatial_size)
if output_ndim > input_ndim:
input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
img = img.reshape(input_shape)
elif output_ndim < input_ndim:
raise ValueError(
"len(spatial_size) must be greater or equal to img spatial dimensions, "
f"got spatial_size={output_ndim} img={input_ndim}."
)
spatial_size = fall_back_tuple(self.spatial_size, img.shape[1:])
if self.size_mode == "all":
input_ndim = img.ndim - 1 # spatial ndim
output_ndim = len(ensure_tuple(self.spatial_size))
if output_ndim > input_ndim:
input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
img = img.reshape(input_shape)
elif output_ndim < input_ndim:
raise ValueError(
"len(spatial_size) must be greater or equal to img spatial dimensions, "
f"got spatial_size={output_ndim} img={input_ndim}."
)
spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:])
else: # for the "longest" mode
img_size = img.shape[1:]
if not isinstance(self.spatial_size, int):
raise ValueError("spatial_size must be an int number if size_mode is 'longest'.")
scale = self.spatial_size / max(img_size)
spatial_size_ = tuple(ceil(s * scale) for s in img_size)
resized = torch.nn.functional.interpolate( # type: ignore
input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0),
size=spatial_size,
size=spatial_size_,
mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value,
align_corners=self.align_corners if align_corners is None else align_corners,
)
Expand Down
14 changes: 12 additions & 2 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,11 @@ class Resized(MapTransform, InvertibleTransform):
if some components of the `spatial_size` are non-positive values, the transform will use the
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
size_mode: should be "all" or "longest". If "all", `spatial_size` is used for all the spatial dims.
If "longest", the image is rescaled so that only the longest side equals the specified `spatial_size`
(which must be an int in this case), keeping the aspect ratio of the initial image. Refer to:
https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/
#albumentations.augmentations.geometric.resize.LongestMaxSize.
mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
The interpolation mode. Defaults to ``"area"``.
See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate
Expand All @@ -518,14 +523,15 @@ def __init__(
self,
keys: KeysCollection,
spatial_size: Union[Sequence[int], int],
size_mode: str = "all",
mode: InterpolateModeSequence = InterpolateMode.AREA,
align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.mode = ensure_tuple_rep(mode, len(self.keys))
self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
self.resizer = Resize(spatial_size=spatial_size)
self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
d = dict(data)
Expand All @@ -549,7 +555,11 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
mode = transform[InverseKeys.EXTRA_INFO]["mode"]
align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"]
# Create inverse transform
inverse_transform = Resize(orig_size, mode, None if align_corners == "none" else align_corners)
inverse_transform = Resize(
spatial_size=orig_size,
mode=mode,
align_corners=None if align_corners == "none" else align_corners,
)
# Apply inverse transform
d[key] = inverse_transform(d[key])
# Remove the applied transform
Expand Down
13 changes: 4 additions & 9 deletions tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,6 @@
)
)

TESTS.append(
(
"Flipd 3d",
"3D",
0,
Flipd(KEYS, [1, 2]),
)
)

TESTS.append(
(
"RandFlipd 3d",
Expand Down Expand Up @@ -319,6 +310,10 @@

TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78])))

TESTS.append(("Resized longest 2d", "2D", 2e-1, Resized(KEYS, 47, "longest", "area")))

TESTS.append(("Resized longest 3d", "3D", 5e-2, Resized(KEYS, 201, "longest", "trilinear", True)))


TESTS.append(
(
Expand Down
13 changes: 13 additions & 0 deletions tests/test_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
from monai.transforms import Resize
from tests.utils import NumpyImageTestCase2D

TEST_CASE_0 = [{"spatial_size": 15}, (6, 11, 15)]

TEST_CASE_1 = [{"spatial_size": 15, "mode": "area"}, (6, 11, 15)]

TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)]


class TestResize(NumpyImageTestCase2D):
def test_invalid_inputs(self):
Expand Down Expand Up @@ -50,6 +56,13 @@ def test_correct_results(self, spatial_size, mode):
out = resize(self.imt[0])
np.testing.assert_allclose(out, expected, atol=0.9)

@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
def test_longest_shape(self, input_param, expected_shape):
    """Check the output spatial shape of ``Resize`` when ``size_mode`` is "longest"."""
    input_data = np.random.randint(0, 2, size=[3, 4, 7, 10])
    # `input_param` is the very dict stored in the module-level TEST_CASE_* list, so
    # pass `size_mode` as an extra keyword instead of writing it into that shared dict.
    result = Resize(**input_param, size_mode="longest")(input_data)
    np.testing.assert_allclose(result.shape[1:], expected_shape)


if __name__ == "__main__":
unittest.main()
25 changes: 24 additions & 1 deletion tests/test_resized.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@
from monai.transforms import Resized
from tests.utils import NumpyImageTestCase2D

TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 11, 15)]

TEST_CASE_1 = [{"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 11, 15)]

TEST_CASE_2 = [{"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)]

TEST_CASE_3 = [
{"keys": ["img", "label"], "spatial_size": 6, "mode": ["trilinear", "nearest"], "align_corners": [True, None]},
(3, 5, 6),
]


class TestResized(NumpyImageTestCase2D):
def test_invalid_inputs(self):
Expand All @@ -31,7 +42,7 @@ def test_invalid_inputs(self):

@parameterized.expand([((32, -1), "area"), ((64, 64), "area"), ((32, 32, 32), "area"), ((256, 256), "bilinear")])
def test_correct_results(self, spatial_size, mode):
resize = Resized("img", spatial_size, mode)
resize = Resized("img", spatial_size, mode=mode)
_order = 0
if mode.endswith("linear"):
_order = 1
Expand All @@ -48,6 +59,18 @@ def test_correct_results(self, spatial_size, mode):
out = resize({"img": self.imt[0]})["img"]
np.testing.assert_allclose(out, expected, atol=0.9)

@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_longest_shape(self, input_param, expected_shape):
    """Check the output spatial shape of ``Resized`` for every key when ``size_mode`` is "longest"."""
    input_data = {
        "img": np.random.randint(0, 2, size=[3, 4, 7, 10]),
        "label": np.random.randint(0, 2, size=[3, 4, 7, 10]),
    }
    # `input_param` is the very dict stored in the module-level TEST_CASE_* list, so
    # pass `size_mode` as an extra keyword instead of writing it into that shared dict.
    rescaler = Resized(**input_param, size_mode="longest")
    result = rescaler(input_data)
    for k in rescaler.keys:
        np.testing.assert_allclose(result[k].shape[1:], expected_shape)


if __name__ == "__main__":
unittest.main()