Skip to content

Commit c82b5cf

Browse files
Nic-Ma authored and wyli committed
2648 Add LongestRescale transform (Project-MONAI#2662)
* [DLMED] init the transform Signed-off-by: Nic Ma <[email protected]>
* [DLMED] update doc-string Signed-off-by: Nic Ma <[email protected]>
* [DLMED] complete array transform Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add unit tests Signed-off-by: Nic Ma <[email protected]>
* [DLMED] add dict transform and inverse tests Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix mypy type Signed-off-by: Nic Ma <[email protected]>
* [DLMED] change to enhance Resize transform Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix CI tests Signed-off-by: Nic Ma <[email protected]>
* [DLMED] update according to comments Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix TTA Signed-off-by: Nic Ma <[email protected]>
* [DLMED] remove tests Signed-off-by: Nic Ma <[email protected]>
* [DLMED] fix mypy error Signed-off-by: Nic Ma <[email protected]>
1 parent 5d9f7c3 commit c82b5cf

File tree

5 files changed

+81
-25
lines changed

5 files changed

+81
-25
lines changed

monai/transforms/spatial/array.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515

1616
import warnings
17+
from math import ceil
1718
from typing import Any, List, Optional, Sequence, Tuple, Union
1819

1920
import numpy as np
@@ -340,6 +341,11 @@ class Resize(Transform):
340341
if some components of the `spatial_size` are non-positive values, the transform will use the
341342
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
342343
to `(32, 64)` if the second spatial dimension size of img is `64`.
344+
size_mode: should be "all" or "longest", if "all", will use `spatial_size` for all the spatial dims,
345+
if "longest", rescale the image so that only the longest side is equal to specified `spatial_size`,
346+
which must be an int number in this case, keeping the aspect ratio of the initial image, refer to:
347+
https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/
348+
#albumentations.augmentations.geometric.resize.LongestMaxSize.
343349
mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
344350
The interpolation mode. Defaults to ``"area"``.
345351
See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate
@@ -351,10 +357,12 @@ class Resize(Transform):
351357
def __init__(
352358
self,
353359
spatial_size: Union[Sequence[int], int],
360+
size_mode: str = "all",
354361
mode: Union[InterpolateMode, str] = InterpolateMode.AREA,
355362
align_corners: Optional[bool] = None,
356363
) -> None:
357-
self.spatial_size = ensure_tuple(spatial_size)
364+
self.size_mode = look_up_option(size_mode, ["all", "longest"])
365+
self.spatial_size = spatial_size
358366
self.mode: InterpolateMode = look_up_option(mode, InterpolateMode)
359367
self.align_corners = align_corners
360368

@@ -378,20 +386,27 @@ def __call__(
378386
ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions.
379387
380388
"""
381-
input_ndim = img.ndim - 1 # spatial ndim
382-
output_ndim = len(self.spatial_size)
383-
if output_ndim > input_ndim:
384-
input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
385-
img = img.reshape(input_shape)
386-
elif output_ndim < input_ndim:
387-
raise ValueError(
388-
"len(spatial_size) must be greater or equal to img spatial dimensions, "
389-
f"got spatial_size={output_ndim} img={input_ndim}."
390-
)
391-
spatial_size = fall_back_tuple(self.spatial_size, img.shape[1:])
389+
if self.size_mode == "all":
390+
input_ndim = img.ndim - 1 # spatial ndim
391+
output_ndim = len(ensure_tuple(self.spatial_size))
392+
if output_ndim > input_ndim:
393+
input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
394+
img = img.reshape(input_shape)
395+
elif output_ndim < input_ndim:
396+
raise ValueError(
397+
"len(spatial_size) must be greater or equal to img spatial dimensions, "
398+
f"got spatial_size={output_ndim} img={input_ndim}."
399+
)
400+
spatial_size_ = fall_back_tuple(self.spatial_size, img.shape[1:])
401+
else: # for the "longest" mode
402+
img_size = img.shape[1:]
403+
if not isinstance(self.spatial_size, int):
404+
raise ValueError("spatial_size must be an int number if size_mode is 'longest'.")
405+
scale = self.spatial_size / max(img_size)
406+
spatial_size_ = tuple(ceil(s * scale) for s in img_size)
392407
resized = torch.nn.functional.interpolate( # type: ignore
393408
input=torch.as_tensor(np.ascontiguousarray(img), dtype=torch.float).unsqueeze(0),
394-
size=spatial_size,
409+
size=spatial_size_,
395410
mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value,
396411
align_corners=self.align_corners if align_corners is None else align_corners,
397412
)

monai/transforms/spatial/dictionary.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,11 @@ class Resized(MapTransform, InvertibleTransform):
503503
if some components of the `spatial_size` are non-positive values, the transform will use the
504504
corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
505505
to `(32, 64)` if the second spatial dimension size of img is `64`.
506+
size_mode: should be "all" or "longest", if "all", will use `spatial_size` for all the spatial dims,
507+
if "longest", rescale the image so that only the longest side is equal to specified `spatial_size`,
508+
which must be an int number in this case, keeping the aspect ratio of the initial image, refer to:
509+
https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/
510+
#albumentations.augmentations.geometric.resize.LongestMaxSize.
506511
mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
507512
The interpolation mode. Defaults to ``"area"``.
508513
See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate
@@ -518,14 +523,15 @@ def __init__(
518523
self,
519524
keys: KeysCollection,
520525
spatial_size: Union[Sequence[int], int],
526+
size_mode: str = "all",
521527
mode: InterpolateModeSequence = InterpolateMode.AREA,
522528
align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
523529
allow_missing_keys: bool = False,
524530
) -> None:
525531
super().__init__(keys, allow_missing_keys)
526532
self.mode = ensure_tuple_rep(mode, len(self.keys))
527533
self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
528-
self.resizer = Resize(spatial_size=spatial_size)
534+
self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode)
529535

530536
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
531537
d = dict(data)
@@ -549,7 +555,11 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
549555
mode = transform[InverseKeys.EXTRA_INFO]["mode"]
550556
align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"]
551557
# Create inverse transform
552-
inverse_transform = Resize(orig_size, mode, None if align_corners == "none" else align_corners)
558+
inverse_transform = Resize(
559+
spatial_size=orig_size,
560+
mode=mode,
561+
align_corners=None if align_corners == "none" else align_corners,
562+
)
553563
# Apply inverse transform
554564
d[key] = inverse_transform(d[key])
555565
# Remove the applied transform

tests/test_inverse.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -249,15 +249,6 @@
249249
)
250250
)
251251

252-
TESTS.append(
253-
(
254-
"Flipd 3d",
255-
"3D",
256-
0,
257-
Flipd(KEYS, [1, 2]),
258-
)
259-
)
260-
261252
TESTS.append(
262253
(
263254
"RandFlipd 3d",
@@ -319,6 +310,10 @@
319310

320311
TESTS.append(("Resized 3d", "3D", 5e-2, Resized(KEYS, [201, 150, 78])))
321312

313+
TESTS.append(("Resized longest 2d", "2D", 2e-1, Resized(KEYS, 47, "longest", "area")))
314+
315+
TESTS.append(("Resized longest 3d", "3D", 5e-2, Resized(KEYS, 201, "longest", "trilinear", True)))
316+
322317

323318
TESTS.append(
324319
(

tests/test_resize.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818
from monai.transforms import Resize
1919
from tests.utils import NumpyImageTestCase2D
2020

21+
TEST_CASE_0 = [{"spatial_size": 15}, (6, 11, 15)]
22+
23+
TEST_CASE_1 = [{"spatial_size": 15, "mode": "area"}, (6, 11, 15)]
24+
25+
TEST_CASE_2 = [{"spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)]
26+
2127

2228
class TestResize(NumpyImageTestCase2D):
2329
def test_invalid_inputs(self):
@@ -50,6 +56,13 @@ def test_correct_results(self, spatial_size, mode):
5056
out = resize(self.imt[0])
5157
np.testing.assert_allclose(out, expected, atol=0.9)
5258

59+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2])
60+
def test_longest_shape(self, input_param, expected_shape):
61+
input_data = np.random.randint(0, 2, size=[3, 4, 7, 10])
62+
input_param["size_mode"] = "longest"
63+
result = Resize(**input_param)(input_data)
64+
np.testing.assert_allclose(result.shape[1:], expected_shape)
65+
5366

5467
if __name__ == "__main__":
5568
unittest.main()

tests/test_resized.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@
1818
from monai.transforms import Resized
1919
from tests.utils import NumpyImageTestCase2D
2020

21+
TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 11, 15)]
22+
23+
TEST_CASE_1 = [{"keys": "img", "spatial_size": 15, "mode": "area"}, (6, 11, 15)]
24+
25+
TEST_CASE_2 = [{"keys": "img", "spatial_size": 6, "mode": "trilinear", "align_corners": True}, (3, 5, 6)]
26+
27+
TEST_CASE_3 = [
28+
{"keys": ["img", "label"], "spatial_size": 6, "mode": ["trilinear", "nearest"], "align_corners": [True, None]},
29+
(3, 5, 6),
30+
]
31+
2132

2233
class TestResized(NumpyImageTestCase2D):
2334
def test_invalid_inputs(self):
@@ -31,7 +42,7 @@ def test_invalid_inputs(self):
3142

3243
@parameterized.expand([((32, -1), "area"), ((64, 64), "area"), ((32, 32, 32), "area"), ((256, 256), "bilinear")])
3344
def test_correct_results(self, spatial_size, mode):
34-
resize = Resized("img", spatial_size, mode)
45+
resize = Resized("img", spatial_size, mode=mode)
3546
_order = 0
3647
if mode.endswith("linear"):
3748
_order = 1
@@ -48,6 +59,18 @@ def test_correct_results(self, spatial_size, mode):
4859
out = resize({"img": self.imt[0]})["img"]
4960
np.testing.assert_allclose(out, expected, atol=0.9)
5061

62+
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
63+
def test_longest_shape(self, input_param, expected_shape):
64+
input_data = {
65+
"img": np.random.randint(0, 2, size=[3, 4, 7, 10]),
66+
"label": np.random.randint(0, 2, size=[3, 4, 7, 10]),
67+
}
68+
input_param["size_mode"] = "longest"
69+
rescaler = Resized(**input_param)
70+
result = rescaler(input_data)
71+
for k in rescaler.keys:
72+
np.testing.assert_allclose(result[k].shape[1:], expected_shape)
73+
5174

5275
if __name__ == "__main__":
5376
unittest.main()

0 commit comments

Comments
 (0)