diff --git a/monai/apps/pathology/transforms/__init__.py b/monai/apps/pathology/transforms/__init__.py index 1be96b8e34..b418e20279 100644 --- a/monai/apps/pathology/transforms/__init__.py +++ b/monai/apps/pathology/transforms/__init__.py @@ -9,8 +9,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .spatial.array import SplitOnGrid -from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .spatial.array import SplitOnGrid, TileOnGrid +from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict from .stain.array import ExtractHEStains, NormalizeHEStains from .stain.dictionary import ( ExtractHEStainsd, diff --git a/monai/apps/pathology/transforms/spatial/__init__.py b/monai/apps/pathology/transforms/spatial/__init__.py index 07ba222ab0..c9971254e7 100644 --- a/monai/apps/pathology/transforms/spatial/__init__.py +++ b/monai/apps/pathology/transforms/spatial/__init__.py @@ -9,5 +9,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .array import SplitOnGrid -from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict +from .array import SplitOnGrid, TileOnGrid +from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict diff --git a/monai/apps/pathology/transforms/spatial/array.py b/monai/apps/pathology/transforms/spatial/array.py index 4edf987610..e08ac7f46f 100644 --- a/monai/apps/pathology/transforms/spatial/array.py +++ b/monai/apps/pathology/transforms/spatial/array.py @@ -9,13 +9,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Optional, Sequence, Tuple, Union, cast +import numpy as np import torch +from numpy.lib.stride_tricks import as_strided -from monai.transforms.transform import Transform +from monai.transforms.transform import Randomizable, Transform -__all__ = ["SplitOnGrid"] +__all__ = ["SplitOnGrid", "TileOnGrid"] class SplitOnGrid(Transform): @@ -73,3 +75,153 @@ def get_params(self, image_size): ) return patch_size, steps + + +class TileOnGrid(Randomizable, Transform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None extracts all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to ``None`` (same as tile_size) + random_offset: Randomize position of the grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset + Defaults to ``min`` (which assumes background is high value) + + """ + + def __init__( + self, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: str = "min", + ): + self.tile_count = tile_count + self.tile_size = tile_size + self.step = step + self.random_offset = random_offset + self.pad_full = pad_full + self.background_val = background_val + self.filter_mode = filter_mode + + if self.step is None: + self.step = self.tile_size # non-overlapping grid + + self.offset = (0, 0) + self.random_idxs = np.array((0,)) + + if self.filter_mode not in ["min", "max", "random"]: + raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode)) + + def randomize(self, img_size: Sequence[int]) -> None: + + c, h, w = img_size + tile_step = cast(int, self.step) + + self.offset = (0, 0) + if self.random_offset: + pad_h = h % self.tile_size + pad_w = w % self.tile_size + self.offset = (self.R.randint(pad_h) if pad_h > 0 else 0, self.R.randint(pad_w) if pad_w > 0 else 0) + h = h - self.offset[0] + w = w - self.offset[1] + + if self.pad_full: + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + h = h + pad_h + w = w + pad_w + + h_n = (h - self.tile_size + tile_step) // tile_step + w_n = (w - self.tile_size + tile_step) // tile_step + tile_total = h_n * w_n + + if self.tile_count is not None and tile_total > self.tile_count: + self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False) + else: + self.random_idxs = np.array((0,)) + + def __call__(self, image: np.ndarray) -> np.ndarray: + + # add random offset + self.randomize(img_size=image.shape) + tile_step = cast(int, self.step) + + if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0): + image = image[:, self.offset[0] :, self.offset[1] :] + + # pad to full size, divisible by tile_size + if self.pad_full: + c, h, w = image.shape + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size + image = np.pad( + image, + [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]], + constant_values=self.background_val, + ) + + # extact tiles + xstep, ystep = tile_step, tile_step + xsize, ysize = self.tile_size, self.tile_size + clen, xlen, ylen = image.shape + cstride, xstride, ystride = image.strides + llw = as_strided( + image, + shape=((xlen - xsize) // xstep + 1, (ylen - ysize) // ystep + 1, clen, xsize, ysize), + strides=(xstride * xstep, ystride * ystep, cstride, xstride, ystride), + writeable=False, + ) + image = llw.reshape(-1, clen, xsize, ysize) + + # if keeping all patches + if self.tile_count is None: + # retain only patches with significant foreground content to speed up inference + # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference + thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size + if self.filter_mode == "min": + # default, keep non-background tiles (small values) + idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh) + image = image[idxs.reshape(-1)] + elif self.filter_mode == "max": + idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh) + image = image[idxs.reshape(-1)] + + else: + if len(image) > self.tile_count: + + if self.filter_mode == "min": + # default, keep non-background tiles (smallest values) + idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count] + image = image[idxs] + elif self.filter_mode == "max": + idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :] + image = image[idxs] + else: + # random subset (more appropriate for WSIs without distinct background) + if self.random_idxs is not None: + image = image[self.random_idxs] + + elif len(image) < self.tile_count: + image = np.pad( + image, + [[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]], + constant_values=self.background_val, + ) + + return image diff --git a/monai/apps/pathology/transforms/spatial/dictionary.py b/monai/apps/pathology/transforms/spatial/dictionary.py index 10b01a39de..0168ac3108 100644 --- a/monai/apps/pathology/transforms/spatial/dictionary.py +++ b/monai/apps/pathology/transforms/spatial/dictionary.py @@ -9,16 +9,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Hashable, Mapping, Optional, Tuple, Union +import copy +from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union +import numpy as np import torch from monai.config import KeysCollection -from monai.transforms.transform import MapTransform +from monai.transforms.transform import MapTransform, Randomizable -from .array import SplitOnGrid +from .array import SplitOnGrid, TileOnGrid -__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"] +__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"] class SplitOnGridd(MapTransform): @@ -53,4 +55,78 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class TileOnGridd(Randomizable, MapTransform): + """ + Tile the 2D image into patches on a grid and maintain a subset of it. + This transform works only with np.ndarray inputs for 2D images. + + Args: + tile_count: number of tiles to extract, if None extracts all non-background tiles + Defaults to ``None``. + tile_size: size of the square tile + Defaults to ``256``. + step: step size + Defaults to ``None`` (same as tile_size) + random_offset: Randomize position of the grid, instead of starting from the top-left corner + Defaults to ``False``. + pad_full: pad image to the size evenly divisible by tile_size + Defaults to ``False``. + background_val: the background constant (e.g. 255 for white background) + Defaults to ``255``. + filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size, + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset + Defaults to ``min`` (which assumes background is high value) + + """ + + def __init__( + self, + keys: KeysCollection, + tile_count: Optional[int] = None, + tile_size: int = 256, + step: Optional[int] = None, + random_offset: bool = False, + pad_full: bool = False, + background_val: int = 255, + filter_mode: str = "min", + allow_missing_keys: bool = False, + return_list_of_dicts: bool = False, + ): + super().__init__(keys, allow_missing_keys) + + self.return_list_of_dicts = return_list_of_dicts + self.seed = None + + self.splitter = TileOnGrid( + tile_count=tile_count, + tile_size=tile_size, + step=step, + random_offset=random_offset, + pad_full=pad_full, + background_val=background_val, + filter_mode=filter_mode, + ) + + def randomize(self, data: Any = None) -> None: + self.seed = self.R.randint(10000) # type: ignore + + def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]: + + self.randomize() + + d = dict(data) + for key in self.key_iterator(d): + self.splitter.set_random_state(seed=self.seed) # same random seed for all keys + d[key] = self.splitter(d[key]) + + if self.return_list_of_dicts: + d_list = [] + for i in range(len(d[self.keys[0]])): + d_list.append({k: d[k][i] if k in self.keys else copy.deepcopy(d[k]) for k in d.keys()}) + d = d_list # type: ignore + + return d + + SplitOnGridDict = SplitOnGridD = SplitOnGridd +TileOnGridDict = TileOnGridD = TileOnGridd diff --git a/tests/test_tile_on_grid.py b/tests/test_tile_on_grid.py new file mode 100644 index 0000000000..f8c86fa90a --- /dev/null +++ b/tests/test_tile_on_grid.py @@ -0,0 +1,112 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGrid + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + } + ] + ) + + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + } + ] + ) + + +def make_image( + tile_count: int, tile_size: int, random_offset: bool = False, filter_mode: Optional[str] = None, seed=123, **kwargs +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + image = np.random.randint(200, size=[3, tile_count * tile_size + pad, tile_count * tile_size + pad], dtype=np.uint8) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + pad_h = image.shape[1] % tile_size + pad_w = image.shape[2] % tile_size + offset = (random_state.randint(pad_h) if pad_h > 0 else 0, random_state.randint(pad_w) if pad_w > 0 else 0) + image = image[:, offset[0] :, offset[1] :] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * tile_size : (x + 1) * tile_size, y * tile_size : (y + 1) * tile_size]) + + tiles = np.stack(tiles_list, axis=0) # type: ignore + + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGrid(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_tile_patch_single_call(self, input_parameters): + + img, tiles = make_image(**input_parameters) + + tiler = TileOnGrid(**input_parameters) + output = tiler(img) + np.testing.assert_equal(output, tiles) + + @parameterized.expand(TEST_CASES2) + def test_tile_patch_random_call(self, input_parameters): + + img, tiles = make_image(**input_parameters, seed=123) + + tiler = TileOnGrid(**input_parameters) + tiler.set_random_state(seed=123) + + output = tiler(img) + np.testing.assert_equal(output, tiles) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_tile_on_grid_dict.py b/tests/test_tile_on_grid_dict.py new file mode 100644 index 0000000000..95cfa179dd --- /dev/null +++ b/tests/test_tile_on_grid_dict.py @@ -0,0 +1,137 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from typing import Optional + +import numpy as np +from parameterized import parameterized + +from monai.apps.pathology.transforms import TileOnGridDict + +TEST_CASES = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + for return_list_of_dicts in [False, True]: + TEST_CASES.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": False, + "background_val": background_val, + "return_list_of_dicts": return_list_of_dicts, + } + ] + ) + + +TEST_CASES2 = [] +for tile_count in [16, 64]: + for tile_size in [8, 32]: + for filter_mode in ["min", "max", "random"]: + for background_val in [255, 0]: + for return_list_of_dicts in [False, True]: + TEST_CASES2.append( + [ + { + "tile_count": tile_count, + "tile_size": tile_size, + "filter_mode": filter_mode, + "random_offset": True, + "background_val": background_val, + "return_list_of_dicts": return_list_of_dicts, + } + ] + ) + + +def make_image( + tile_count: int, tile_size: int, random_offset: bool = False, filter_mode: Optional[str] = None, seed=123, **kwargs +): + + tile_count = int(np.sqrt(tile_count)) + pad = 0 + if random_offset: + pad = 3 + + image = np.random.randint(200, size=[3, tile_count * tile_size + pad, tile_count * tile_size + pad], dtype=np.uint8) + imlarge = image + + random_state = np.random.RandomState(seed) + + if random_offset: + pad_h = image.shape[1] % tile_size + pad_w = image.shape[2] % tile_size + offset = (random_state.randint(pad_h) if pad_h > 0 else 0, random_state.randint(pad_w) if pad_w > 0 else 0) + image = image[:, offset[0] :, offset[1] :] + + tiles_list = [] + for x in range(tile_count): + for y in range(tile_count): + tiles_list.append(image[:, x * tile_size : (x + 1) * tile_size, y * tile_size : (y + 1) * tile_size]) + + tiles = np.stack(tiles_list, axis=0) # type: ignore + + if (filter_mode == "min" or filter_mode == "max") and len(tiles) > tile_count ** 2: + tiles = tiles[np.argsort(tiles.sum(axis=(1, 2, 3)))] + + return imlarge, tiles + + +class TestTileOnGridDict(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_tile_patch_single_call(self, input_parameters): + + key = "image" + input_parameters["keys"] = key + + img, tiles = make_image(**input_parameters) + + splitter = TileOnGridDict(**input_parameters) + + output = splitter({key: img}) + + if input_parameters.get("return_list_of_dicts", False): + output = np.stack([ix[key] for ix in output], axis=0) + else: + output = output[key] + + np.testing.assert_equal(tiles, output) + + @parameterized.expand(TEST_CASES2) + def test_tile_patch_random_call(self, input_parameters): + + key = "image" + input_parameters["keys"] = key + + random_state = np.random.RandomState(123) + seed = random_state.randint(10000) + img, tiles = make_image(**input_parameters, seed=seed) + + splitter = TileOnGridDict(**input_parameters) + splitter.set_random_state(seed=123) + + output = splitter({key: img}) + + if input_parameters.get("return_list_of_dicts", False): + output = np.stack([ix[key] for ix in output], axis=0) + else: + output = output[key] + + np.testing.assert_equal(tiles, output) + + +if __name__ == "__main__": + unittest.main()