diff --git a/docs/source/apps.rst b/docs/source/apps.rst index 1a2efeff48..11d60767ec 100644 --- a/docs/source/apps.rst +++ b/docs/source/apps.rst @@ -110,3 +110,11 @@ Clara MMARs :members: .. autoclass:: NormalizeHEStainsd :members: + +.. automodule:: monai.apps.pathology.transforms.spatial.array +.. autoclass:: SplitOnGrid + :members: + +.. automodule:: monai.apps.pathology.transforms.spatial.dictionary +.. autoclass:: SplitOnGridd + :members: diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 0df016244b..1be96b8e34 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -9,6 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from .spatial.array import SplitOnGrid +from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict from .stain.array import ExtractHEStains, NormalizeHEStains from .stain.dictionary import ( ExtractHEStainsd, diff --git a/monai/apps/pathology/transforms/spatial/__init__.py b/monai/apps/pathology/transforms/spatial/__init__.py new file mode 100644 index 0000000000..07ba222ab0 --- /dev/null +++ b/monai/apps/pathology/transforms/spatial/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from .array import SplitOnGrid
+from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py
new file mode 100644
index 0000000000..53e0c63715
--- /dev/null
+++ b/monai/apps/pathology/transforms/spatial/array.py
@@ -0,0 +1,77 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+from monai.transforms.transform import Transform
+
+__all__ = ["SplitOnGrid"]
+
+
+class SplitOnGrid(Transform):
+    """
+    Split the image into patches based on the provided grid shape.
+    This transform works only with torch.Tensor inputs.
+
+    Args:
+        grid_size: a tuple or an integer that defines the shape of the grid upon which to extract patches.
+            If it's an integer, the value will be repeated for each dimension. Default is 2x2
+        patch_size: a tuple or an integer that defines the output patch sizes.
+            If it's an integer, the value will be repeated for each dimension.
+            The default is None, where the patch size will be inferred from the grid size.
+
+    Note: the shape of the input image is inferred based on the first image used.
+ """ + + def __init__( + self, + grid_size: Union[int, Tuple[int, int]] = (2, 2), + patch_size: Optional[Union[int, Tuple[int, int]]] = None, + ): + # Grid size + if isinstance(grid_size, int): + self.grid_size = (grid_size, grid_size) + else: + self.grid_size = grid_size + # Patch size + self.patch_size = None + if isinstance(patch_size, int): + self.patch_size = (patch_size, patch_size) + else: + self.patch_size = patch_size + + def __call__(self, image: torch.Tensor) -> torch.Tensor: + if self.grid_size == (1, 1) and self.patch_size is None: + return torch.stack([image]) + patch_size, steps = self.get_params(image.shape[1:]) + patches = ( + image.unfold(1, patch_size[0], steps[0]) + .unfold(2, patch_size[1], steps[1]) + .flatten(1, 2) + .transpose(0, 1) + .contiguous() + ) + return patches + + def get_params(self, image_size): + if self.patch_size is None: + patch_size = tuple(image_size[i] // self.grid_size[i] for i in range(2)) + else: + patch_size = self.patch_size + + steps = tuple( + (image_size[i] - patch_size[i]) // (self.grid_size[i] - 1) if self.grid_size[i] > 1 else image_size[i] + for i in range(2) + ) + + return patch_size, steps diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py new file mode 100644 index 0000000000..10b01a39de --- /dev/null +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -0,0 +1,56 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Hashable, Mapping, Optional, Tuple, Union
+
+import torch
+
+from monai.config import KeysCollection
+from monai.transforms.transform import MapTransform
+
+from .array import SplitOnGrid
+
+__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"]
+
+
+class SplitOnGridd(MapTransform):
+    """
+    Split the image into patches based on the provided grid shape.
+    This transform works only with torch.Tensor inputs.
+
+    Args:
+        grid_size: a tuple or an integer that defines the shape of the grid upon which to extract patches.
+            If it's an integer, the value will be repeated for each dimension. Default is 2x2
+        patch_size: a tuple or an integer that defines the output patch sizes.
+            If it's an integer, the value will be repeated for each dimension.
+            The default is None, where the patch size will be inferred from the grid size.
+
+    Note: the shape of the input image is inferred based on the first image used.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        grid_size: Union[int, Tuple[int, int]] = (2, 2),
+        patch_size: Optional[Union[int, Tuple[int, int]]] = None,
+        allow_missing_keys: bool = False,
+    ):
+        super().__init__(keys, allow_missing_keys)
+        self.splitter = SplitOnGrid(grid_size=grid_size, patch_size=patch_size)
+
+    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[key] = self.splitter(d[key])
+        return d
+
+
+SplitOnGridDict = SplitOnGridD = SplitOnGridd
diff --git a/tests/test_split_on_grid.py b/tests/test_split_on_grid.py
new file mode 100644
index 0000000000..a187835e7b
--- /dev/null
+++ b/tests/test_split_on_grid.py
@@ -0,0 +1,131 @@
+# Copyright 2020 - 2021 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps.pathology.transforms import SplitOnGrid + +A11 = torch.randn(3, 2, 2) +A12 = torch.randn(3, 2, 2) +A21 = torch.randn(3, 2, 2) +A22 = torch.randn(3, 2, 2) + +A1 = torch.cat([A11, A12], 2) +A2 = torch.cat([A21, A22], 2) +A = torch.cat([A1, A2], 1) + +TEST_CASE_0 = [ + {"grid_size": (2, 2)}, + A, + torch.stack([A11, A12, A21, A22]), +] + +TEST_CASE_1 = [ + {"grid_size": (2, 1)}, + A, + torch.stack([A1, A2]), +] + +TEST_CASE_2 = [ + {"grid_size": (1, 2)}, + A1, + torch.stack([A11, A12]), +] + +TEST_CASE_3 = [ + {"grid_size": (1, 2)}, + A2, + torch.stack([A21, A22]), +] + +TEST_CASE_4 = [ + {"grid_size": (1, 1), "patch_size": (2, 2)}, + A, + torch.stack([A11]), +] + +TEST_CASE_5 = [ + {"grid_size": 1, "patch_size": 4}, + A, + torch.stack([A]), +] + +TEST_CASE_6 = [ + {"grid_size": 2, "patch_size": 2}, + A, + torch.stack([A11, A12, A21, A22]), +] + +TEST_CASE_7 = [ + {"grid_size": 1}, + A, + torch.stack([A]), +] + +TEST_CASE_MC_0 = [ + {"grid_size": (2, 2)}, + [A, A], + [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], +] + + +TEST_CASE_MC_1 = [ + {"grid_size": (2, 1)}, + [A] * 5, + [torch.stack([A1, A2])] * 5, +] + + +TEST_CASE_MC_2 = [ + {"grid_size": (1, 2)}, + [A1, A2], + [torch.stack([A11, A12]), torch.stack([A21, A22])], +] + + +class TestSplitOnGrid(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + ] + ) + 
def test_split_pathce_single_call(self, input_parameters, img, expected): + splitter = SplitOnGrid(**input_parameters) + output = splitter(img) + np.testing.assert_equal(output.numpy(), expected.numpy()) + + @parameterized.expand( + [ + TEST_CASE_MC_0, + TEST_CASE_MC_1, + TEST_CASE_MC_2, + ] + ) + def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list): + splitter = SplitOnGrid(**input_parameters) + for img, expected in zip(img_list, expected_list): + output = splitter(img) + np.testing.assert_equal(output.numpy(), expected.numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_split_on_grid_dict.py b/tests/test_split_on_grid_dict.py new file mode 100644 index 0000000000..96ec095423 --- /dev/null +++ b/tests/test_split_on_grid_dict.py @@ -0,0 +1,131 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.apps.pathology.transforms import SplitOnGridDict + +A11 = torch.randn(3, 2, 2) +A12 = torch.randn(3, 2, 2) +A21 = torch.randn(3, 2, 2) +A22 = torch.randn(3, 2, 2) + +A1 = torch.cat([A11, A12], 2) +A2 = torch.cat([A21, A22], 2) +A = torch.cat([A1, A2], 1) + +TEST_CASE_0 = [ + {"keys": "image", "grid_size": (2, 2)}, + {"image": A}, + torch.stack([A11, A12, A21, A22]), +] + +TEST_CASE_1 = [ + {"keys": "image", "grid_size": (2, 1)}, + {"image": A}, + torch.stack([A1, A2]), +] + +TEST_CASE_2 = [ + {"keys": "image", "grid_size": (1, 2)}, + {"image": A1}, + torch.stack([A11, A12]), +] + +TEST_CASE_3 = [ + {"keys": "image", "grid_size": (1, 2)}, + {"image": A2}, + torch.stack([A21, A22]), +] + +TEST_CASE_4 = [ + {"keys": "image", "grid_size": (1, 1), "patch_size": (2, 2)}, + {"image": A}, + torch.stack([A11]), +] + +TEST_CASE_5 = [ + {"keys": "image", "grid_size": 1, "patch_size": 4}, + {"image": A}, + torch.stack([A]), +] + +TEST_CASE_6 = [ + {"keys": "image", "grid_size": 2, "patch_size": 2}, + {"image": A}, + torch.stack([A11, A12, A21, A22]), +] + +TEST_CASE_7 = [ + {"keys": "image", "grid_size": 1}, + {"image": A}, + torch.stack([A]), +] + +TEST_CASE_MC_0 = [ + {"keys": "image", "grid_size": (2, 2)}, + [{"image": A}, {"image": A}], + [torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])], +] + + +TEST_CASE_MC_1 = [ + {"keys": "image", "grid_size": (2, 1)}, + [{"image": A}] * 5, + [torch.stack([A1, A2])] * 5, +] + + +TEST_CASE_MC_2 = [ + {"keys": "image", "grid_size": (1, 2)}, + [{"image": A1}, {"image": A2}], + [torch.stack([A11, A12]), torch.stack([A21, A22])], +] + + +class TestSplitOnGridDict(unittest.TestCase): + @parameterized.expand( + [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + ] + ) + def test_split_pathce_single_call(self, input_parameters, img_dict, 
expected): + splitter = SplitOnGridDict(**input_parameters) + output = splitter(img_dict)[input_parameters["keys"]] + np.testing.assert_equal(output.numpy(), expected.numpy()) + + @parameterized.expand( + [ + TEST_CASE_MC_0, + TEST_CASE_MC_1, + TEST_CASE_MC_2, + ] + ) + def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list): + splitter = SplitOnGridDict(**input_parameters) + for img_dict, expected in zip(img_list, expected_list): + output = splitter(img_dict)[input_parameters["keys"]] + np.testing.assert_equal(output.numpy(), expected.numpy()) + + +if __name__ == "__main__": + unittest.main()