Skip to content

Commit 9ec1d14

Browse files
myronbhashemian
andauthored
MIL component to extract patches (#3237)
* MIL component to extract patches Signed-off-by: myron <[email protected]> * MIL component to extract patches Signed-off-by: myron <[email protected]> * random flag, minor fixes Signed-off-by: myron <[email protected]> * minor fixes for padding Signed-off-by: myron <[email protected]> * improve tests Signed-off-by: myron <[email protected]> Co-authored-by: Behrooz <[email protected]>
1 parent 4d83fc0 commit 9ec1d14

File tree

6 files changed

+488
-11
lines changed

6 files changed

+488
-11
lines changed

monai/apps/pathology/transforms/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from .spatial.array import SplitOnGrid
13-
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
12+
from .spatial.array import SplitOnGrid, TileOnGrid
13+
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict
1414
from .stain.array import ExtractHEStains, NormalizeHEStains
1515
from .stain.dictionary import (
1616
ExtractHEStainsd,

monai/apps/pathology/transforms/spatial/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from .array import SplitOnGrid
13-
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
12+
from .array import SplitOnGrid, TileOnGrid
13+
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict

monai/apps/pathology/transforms/spatial/array.py

Lines changed: 155 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Optional, Tuple, Union
12+
from typing import Optional, Sequence, Tuple, Union, cast
1313

14+
import numpy as np
1415
import torch
16+
from numpy.lib.stride_tricks import as_strided
1517

16-
from monai.transforms.transform import Transform
18+
from monai.transforms.transform import Randomizable, Transform
1719

18-
__all__ = ["SplitOnGrid"]
20+
__all__ = ["SplitOnGrid", "TileOnGrid"]
1921

2022

2123
class SplitOnGrid(Transform):
@@ -73,3 +75,153 @@ def get_params(self, image_size):
7375
)
7476

7577
return patch_size, steps
78+
79+
80+
class TileOnGrid(Randomizable, Transform):
81+
"""
82+
Tile the 2D image into patches on a grid and maintain a subset of it.
83+
This transform works only with np.ndarray inputs for 2D images.
84+
85+
Args:
86+
tile_count: number of tiles to extract, if None extracts all non-background tiles
87+
Defaults to ``None``.
88+
tile_size: size of the square tile
89+
Defaults to ``256``.
90+
step: step size
91+
Defaults to ``None`` (same as tile_size)
92+
random_offset: Randomize position of the grid, instead of starting from the top-left corner
93+
Defaults to ``False``.
94+
pad_full: pad image to the size evenly divisible by tile_size
95+
Defaults to ``False``.
96+
background_val: the background constant (e.g. 255 for white background)
97+
Defaults to ``255``.
98+
filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size,
99+
then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset
100+
Defaults to ``min`` (which assumes background is high value)
101+
102+
"""
103+
104+
def __init__(
105+
self,
106+
tile_count: Optional[int] = None,
107+
tile_size: int = 256,
108+
step: Optional[int] = None,
109+
random_offset: bool = False,
110+
pad_full: bool = False,
111+
background_val: int = 255,
112+
filter_mode: str = "min",
113+
):
114+
self.tile_count = tile_count
115+
self.tile_size = tile_size
116+
self.step = step
117+
self.random_offset = random_offset
118+
self.pad_full = pad_full
119+
self.background_val = background_val
120+
self.filter_mode = filter_mode
121+
122+
if self.step is None:
123+
self.step = self.tile_size # non-overlapping grid
124+
125+
self.offset = (0, 0)
126+
self.random_idxs = np.array((0,))
127+
128+
if self.filter_mode not in ["min", "max", "random"]:
129+
raise ValueError("Unsupported filter_mode, must be [min, max or random]: " + str(self.filter_mode))
130+
131+
def randomize(self, img_size: Sequence[int]) -> None:
132+
133+
c, h, w = img_size
134+
tile_step = cast(int, self.step)
135+
136+
self.offset = (0, 0)
137+
if self.random_offset:
138+
pad_h = h % self.tile_size
139+
pad_w = w % self.tile_size
140+
self.offset = (self.R.randint(pad_h) if pad_h > 0 else 0, self.R.randint(pad_w) if pad_w > 0 else 0)
141+
h = h - self.offset[0]
142+
w = w - self.offset[1]
143+
144+
if self.pad_full:
145+
pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
146+
pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
147+
h = h + pad_h
148+
w = w + pad_w
149+
150+
h_n = (h - self.tile_size + tile_step) // tile_step
151+
w_n = (w - self.tile_size + tile_step) // tile_step
152+
tile_total = h_n * w_n
153+
154+
if self.tile_count is not None and tile_total > self.tile_count:
155+
self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False)
156+
else:
157+
self.random_idxs = np.array((0,))
158+
159+
def __call__(self, image: np.ndarray) -> np.ndarray:
160+
161+
# add random offset
162+
self.randomize(img_size=image.shape)
163+
tile_step = cast(int, self.step)
164+
165+
if self.random_offset and (self.offset[0] > 0 or self.offset[1] > 0):
166+
image = image[:, self.offset[0] :, self.offset[1] :]
167+
168+
# pad to full size, divisible by tile_size
169+
if self.pad_full:
170+
c, h, w = image.shape
171+
pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
172+
pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
173+
image = np.pad(
174+
image,
175+
[[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
176+
constant_values=self.background_val,
177+
)
178+
179+
# extact tiles
180+
xstep, ystep = tile_step, tile_step
181+
xsize, ysize = self.tile_size, self.tile_size
182+
clen, xlen, ylen = image.shape
183+
cstride, xstride, ystride = image.strides
184+
llw = as_strided(
185+
image,
186+
shape=((xlen - xsize) // xstep + 1, (ylen - ysize) // ystep + 1, clen, xsize, ysize),
187+
strides=(xstride * xstep, ystride * ystep, cstride, xstride, ystride),
188+
writeable=False,
189+
)
190+
image = llw.reshape(-1, clen, xsize, ysize)
191+
192+
# if keeping all patches
193+
if self.tile_count is None:
194+
# retain only patches with significant foreground content to speed up inference
195+
# FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu), e.g during inference
196+
thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size
197+
if self.filter_mode == "min":
198+
# default, keep non-background tiles (small values)
199+
idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh)
200+
image = image[idxs.reshape(-1)]
201+
elif self.filter_mode == "max":
202+
idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh)
203+
image = image[idxs.reshape(-1)]
204+
205+
else:
206+
if len(image) > self.tile_count:
207+
208+
if self.filter_mode == "min":
209+
# default, keep non-background tiles (smallest values)
210+
idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count]
211+
image = image[idxs]
212+
elif self.filter_mode == "max":
213+
idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :]
214+
image = image[idxs]
215+
else:
216+
# random subset (more appropriate for WSIs without distinct background)
217+
if self.random_idxs is not None:
218+
image = image[self.random_idxs]
219+
220+
elif len(image) < self.tile_count:
221+
image = np.pad(
222+
image,
223+
[[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]],
224+
constant_values=self.background_val,
225+
)
226+
227+
return image

monai/apps/pathology/transforms/spatial/dictionary.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Dict, Hashable, Mapping, Optional, Tuple, Union
12+
import copy
13+
from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union
1314

15+
import numpy as np
1416
import torch
1517

1618
from monai.config import KeysCollection
17-
from monai.transforms.transform import MapTransform
19+
from monai.transforms.transform import MapTransform, Randomizable
1820

19-
from .array import SplitOnGrid
21+
from .array import SplitOnGrid, TileOnGrid
2022

21-
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"]
23+
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"]
2224

2325

2426
class SplitOnGridd(MapTransform):
@@ -53,4 +55,78 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
5355
return d
5456

5557

58+
class TileOnGridd(Randomizable, MapTransform):
59+
"""
60+
Tile the 2D image into patches on a grid and maintain a subset of it.
61+
This transform works only with np.ndarray inputs for 2D images.
62+
63+
Args:
64+
tile_count: number of tiles to extract, if None extracts all non-background tiles
65+
Defaults to ``None``.
66+
tile_size: size of the square tile
67+
Defaults to ``256``.
68+
step: step size
69+
Defaults to ``None`` (same as tile_size)
70+
random_offset: Randomize position of the grid, instead of starting from the top-left corner
71+
Defaults to ``False``.
72+
pad_full: pad image to the size evenly divisible by tile_size
73+
Defaults to ``False``.
74+
background_val: the background constant (e.g. 255 for white background)
75+
Defaults to ``255``.
76+
filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is more than tile_size,
77+
then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for random) subset
78+
Defaults to ``min`` (which assumes background is high value)
79+
80+
"""
81+
82+
def __init__(
83+
self,
84+
keys: KeysCollection,
85+
tile_count: Optional[int] = None,
86+
tile_size: int = 256,
87+
step: Optional[int] = None,
88+
random_offset: bool = False,
89+
pad_full: bool = False,
90+
background_val: int = 255,
91+
filter_mode: str = "min",
92+
allow_missing_keys: bool = False,
93+
return_list_of_dicts: bool = False,
94+
):
95+
super().__init__(keys, allow_missing_keys)
96+
97+
self.return_list_of_dicts = return_list_of_dicts
98+
self.seed = None
99+
100+
self.splitter = TileOnGrid(
101+
tile_count=tile_count,
102+
tile_size=tile_size,
103+
step=step,
104+
random_offset=random_offset,
105+
pad_full=pad_full,
106+
background_val=background_val,
107+
filter_mode=filter_mode,
108+
)
109+
110+
def randomize(self, data: Any = None) -> None:
111+
self.seed = self.R.randint(10000) # type: ignore
112+
113+
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]:
114+
115+
self.randomize()
116+
117+
d = dict(data)
118+
for key in self.key_iterator(d):
119+
self.splitter.set_random_state(seed=self.seed) # same random seed for all keys
120+
d[key] = self.splitter(d[key])
121+
122+
if self.return_list_of_dicts:
123+
d_list = []
124+
for i in range(len(d[self.keys[0]])):
125+
d_list.append({k: d[k][i] if k in self.keys else copy.deepcopy(d[k]) for k in d.keys()})
126+
d = d_list # type: ignore
127+
128+
return d
129+
130+
56131
SplitOnGridDict = SplitOnGridD = SplitOnGridd
132+
TileOnGridDict = TileOnGridD = TileOnGridd

0 commit comments

Comments
 (0)