Skip to content

Split On Grid #2879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 3, 2021
8 changes: 8 additions & 0 deletions docs/source/apps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,11 @@ Clara MMARs
:members:
.. autoclass:: NormalizeHEStainsd
:members:

.. automodule:: monai.apps.pathology.transforms.spatial.array
.. autoclass:: SplitOnGrid
:members:

.. automodule:: monai.apps.pathology.transforms.spatial.dictionary
.. autoclass:: SplitOnGridd
:members:
2 changes: 2 additions & 0 deletions monai/apps/pathology/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .spatial.array import SplitOnGrid
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
from .stain.array import ExtractHEStains, NormalizeHEStains
from .stain.dictionary import (
ExtractHEStainsd,
Expand Down
13 changes: 13 additions & 0 deletions monai/apps/pathology/transforms/spatial/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .array import SplitOnGrid
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
77 changes: 77 additions & 0 deletions monai/apps/pathology/transforms/spatial/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple, Union

import torch

from monai.transforms.transform import Transform

__all__ = ["SplitOnGrid"]


class SplitOnGrid(Transform):
"""
Split the image into patches based on the provided grid shape.
This transform works only with torch.Tensor inputs.

Args:
grid_shape: a tuple or an integer define the shape of the grid upon which to extract patches.
If it's an integer, the value will be repeated for each dimension. Default is 2x2
patch_size: a tuple or an integer that defines the output patch sizes.
If it's an integer, the value will be repeated for each dimension.
The default is (0, 0), where the patch size will be infered from the grid shape.

Note: the shape of the input image is infered based on the first image used.
"""

def __init__(
self,
grid_size: Union[int, Tuple[int, int]] = (2, 2),
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
):
# Grid size
if isinstance(grid_size, int):
self.grid_size = (grid_size, grid_size)
else:
self.grid_size = grid_size
# Patch size
self.patch_size = None
if isinstance(patch_size, int):
self.patch_size = (patch_size, patch_size)
else:
self.patch_size = patch_size

def __call__(self, image: torch.Tensor) -> torch.Tensor:
if self.grid_size == (1, 1) and self.patch_size is None:
return torch.stack([image])
patch_size, steps = self.get_params(image.shape[1:])
patches = (
image.unfold(1, patch_size[0], steps[0])
.unfold(2, patch_size[1], steps[1])
.flatten(1, 2)
.transpose(0, 1)
.contiguous()
)
return patches

def get_params(self, image_size):
if self.patch_size is None:
patch_size = tuple(image_size[i] // self.grid_size[i] for i in range(2))
else:
patch_size = self.patch_size

steps = tuple(
(image_size[i] - patch_size[i]) // (self.grid_size[i] - 1) if self.grid_size[i] > 1 else image_size[i]
for i in range(2)
)

return patch_size, steps
56 changes: 56 additions & 0 deletions monai/apps/pathology/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Hashable, Mapping, Optional, Tuple, Union

import torch

from monai.config import KeysCollection
from monai.transforms.transform import MapTransform

from .array import SplitOnGrid

__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"]


class SplitOnGridd(MapTransform):
"""
Split the image into patches based on the provided grid shape.
This transform works only with torch.Tensor inputs.

Args:
grid_shape: a tuple or an integer define the shape of the grid upon which to extract patches.
If it's an integer, the value will be repeated for each dimension. Default is 2x2
patch_size: a tuple or an integer that defines the output patch sizes.
If it's an integer, the value will be repeated for each dimension.
The default is (0, 0), where the patch size will be infered from the grid shape.

Note: the shape of the input image is infered based on the first image used.
"""

def __init__(
self,
keys: KeysCollection,
grid_size: Union[int, Tuple[int, int]] = (2, 2),
patch_size: Optional[Union[int, Tuple[int, int]]] = None,
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)
self.splitter = SplitOnGrid(grid_size=grid_size, patch_size=patch_size)

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.splitter(d[key])
return d


SplitOnGridDict = SplitOnGridD = SplitOnGridd
131 changes: 131 additions & 0 deletions tests/test_split_on_grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.apps.pathology.transforms import SplitOnGrid

A11 = torch.randn(3, 2, 2)
A12 = torch.randn(3, 2, 2)
A21 = torch.randn(3, 2, 2)
A22 = torch.randn(3, 2, 2)

A1 = torch.cat([A11, A12], 2)
A2 = torch.cat([A21, A22], 2)
A = torch.cat([A1, A2], 1)

TEST_CASE_0 = [
{"grid_size": (2, 2)},
A,
torch.stack([A11, A12, A21, A22]),
]

TEST_CASE_1 = [
{"grid_size": (2, 1)},
A,
torch.stack([A1, A2]),
]

TEST_CASE_2 = [
{"grid_size": (1, 2)},
A1,
torch.stack([A11, A12]),
]

TEST_CASE_3 = [
{"grid_size": (1, 2)},
A2,
torch.stack([A21, A22]),
]

TEST_CASE_4 = [
{"grid_size": (1, 1), "patch_size": (2, 2)},
A,
torch.stack([A11]),
]

TEST_CASE_5 = [
{"grid_size": 1, "patch_size": 4},
A,
torch.stack([A]),
]

TEST_CASE_6 = [
{"grid_size": 2, "patch_size": 2},
A,
torch.stack([A11, A12, A21, A22]),
]

TEST_CASE_7 = [
{"grid_size": 1},
A,
torch.stack([A]),
]

TEST_CASE_MC_0 = [
{"grid_size": (2, 2)},
[A, A],
[torch.stack([A11, A12, A21, A22]), torch.stack([A11, A12, A21, A22])],
]


TEST_CASE_MC_1 = [
{"grid_size": (2, 1)},
[A] * 5,
[torch.stack([A1, A2])] * 5,
]


TEST_CASE_MC_2 = [
{"grid_size": (1, 2)},
[A1, A2],
[torch.stack([A11, A12]), torch.stack([A21, A22])],
]


class TestSplitOnGrid(unittest.TestCase):
@parameterized.expand(
[
TEST_CASE_0,
TEST_CASE_1,
TEST_CASE_2,
TEST_CASE_3,
TEST_CASE_4,
TEST_CASE_5,
TEST_CASE_6,
TEST_CASE_7,
]
)
def test_split_pathce_single_call(self, input_parameters, img, expected):
splitter = SplitOnGrid(**input_parameters)
output = splitter(img)
np.testing.assert_equal(output.numpy(), expected.numpy())

@parameterized.expand(
[
TEST_CASE_MC_0,
TEST_CASE_MC_1,
TEST_CASE_MC_2,
]
)
def test_split_pathce_multiple_call(self, input_parameters, img_list, expected_list):
splitter = SplitOnGrid(**input_parameters)
for img, expected in zip(img_list, expected_list):
output = splitter(img)
np.testing.assert_equal(output.numpy(), expected.numpy())


if __name__ == "__main__":
unittest.main()
Loading