Skip to content

Commit 463b159

Browse files
committed
MIL component to extract patches
Signed-off-by: myron <[email protected]>
1 parent cfe64aa commit 463b159

File tree

6 files changed

+498
-11
lines changed

6 files changed

+498
-11
lines changed

monai/apps/pathology/transforms/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from .spatial.array import SplitOnGrid
13-
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
12+
from .spatial.array import SplitOnGrid, TileOnGrid
13+
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict
1414
from .stain.array import ExtractHEStains, NormalizeHEStains
1515
from .stain.dictionary import (
1616
ExtractHEStainsd,

monai/apps/pathology/transforms/spatial/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from .array import SplitOnGrid
13-
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
12+
from .array import SplitOnGrid, TileOnGrid
13+
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict

monai/apps/pathology/transforms/spatial/array.py

Lines changed: 162 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Optional, Tuple, Union
12+
from typing import Optional, Sequence, Tuple, Union
1313

14+
import numpy as np
1415
import torch
16+
from numpy.lib.stride_tricks import as_strided
1517

16-
from monai.transforms.transform import Transform
18+
from monai.transforms.transform import Randomizable, Transform
1719

18-
__all__ = ["SplitOnGrid"]
20+
__all__ = ["SplitOnGrid", "TileOnGrid"]
1921

2022

2123
class SplitOnGrid(Transform):
@@ -73,3 +75,160 @@ def get_params(self, image_size):
7375
)
7476

7577
return patch_size, steps
78+
79+
80+
class TileOnGrid(Randomizable, Transform):
    """
    Tile a 2D image into square patches on a regular grid and keep a subset of them.

    This transform works only with channel-first ``np.ndarray`` inputs of shape ``(C, H, W)``.
    The result has shape ``(N, C, tile_size, tile_size)``.

    Args:
        tile_count: number of tiles to extract.
            Defaults to ``None``, in which case ``max(44 * 256 ** 2 // tile_size ** 2, 1)`` is used
            (44 tiles for a tile size of 256, scaled inversely with the tile area).
        tile_size: size of the square tile.
            Defaults to ``256``.
        tile_step: step size between neighbouring tiles.
            Defaults to ``None`` (same as ``tile_size``, i.e. a non-overlapping grid).
        tile_all: extract all non-background tiles, instead of ``tile_count``.
            Defaults to ``False``.
        tile_random_offset: randomize the position of the tile grid, instead of starting from the
            top-left corner.
            Defaults to ``False``.
        tile_pad_full: pad the image to a size evenly divisible by ``tile_size``.
            Defaults to ``False``.
        tile_background_val: the background constant (e.g. 255 for white background), used both for
            padding and for the background threshold.
            Defaults to ``255``.
        tile_filter_mode: mode must be in ["min", "max", None]. If the total number of tiles is more
            than ``tile_count``, then sort by intensity sum, and take the smallest (for min),
            largest (for max) or random (for None) subset.
            Defaults to ``min`` (which assumes background is white, high value).

    """

    def __init__(
        self,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        tile_step: Optional[int] = None,
        tile_all: bool = False,
        tile_random_offset: bool = False,
        tile_pad_full: bool = False,
        tile_background_val: int = 255,
        tile_filter_mode: Optional[str] = "min",
    ):
        self.tile_size = tile_size
        self.tile_all = tile_all
        self.tile_random_offset = tile_random_offset
        self.tile_pad_full = tile_pad_full
        self.tile_background_val = tile_background_val
        self.tile_filter_mode = tile_filter_mode

        # default tile count keeps the same pixel budget as 44 tiles of 256 x 256, but at least 1 tile
        self.tile_count: int = max((44 * 256 ** 2) // (tile_size ** 2), 1) if tile_count is None else tile_count
        # default step equals the tile size: non-overlapping grid
        self.tile_step: int = tile_size if tile_step is None else tile_step

        # state written by `randomize` and read by `__call__`
        self.random_offset = (0, 0)
        self.random_idxs = np.array((0,))

    def randomize(self, img_size: Sequence[int]) -> None:
        """
        Draw the random grid offset and the random tile subset for an image of the given size.

        Args:
            img_size: ``(C, H, W)`` shape of the image about to be tiled.
        """
        _, h, w = img_size

        self.random_offset = (0, 0)
        if self.tile_random_offset:
            rem_h = h % self.tile_size
            rem_w = w % self.tile_size
            # Each axis is shifted independently. `randint(0)` is invalid, so an axis without
            # any slack keeps offset 0 without preventing the other axis from being shifted.
            off_h = self.R.randint(rem_h) if rem_h > 0 else 0
            off_w = self.R.randint(rem_w) if rem_w > 0 else 0
            self.random_offset = (off_h, off_w)
            h -= off_h
            w -= off_w

        if self.tile_pad_full:
            # mirror the padding applied in `__call__` so that the tile total matches
            h += (self.tile_size - h % self.tile_size) % self.tile_size
            w += (self.tile_size - w % self.tile_size) % self.tile_size

        rows = (h - self.tile_size + self.tile_step) // self.tile_step
        cols = (w - self.tile_size + self.tile_step) // self.tile_step
        tile_total = rows * cols

        if tile_total > self.tile_count:
            self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False)
        else:
            # placeholder; `__call__` only uses the indices when there are more tiles than `tile_count`
            self.random_idxs = np.array((0,))

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """
        Args:
            image: channel-first 2D image of shape ``(C, H, W)``.

        Returns:
            Array of tiles with shape ``(N, C, tile_size, tile_size)``. ``N`` equals ``tile_count``
            unless ``tile_all`` is set, in which case ``N`` depends on the image content.
        """
        self.randomize(img_size=image.shape)
        tile_count = self.tile_count
        tile_size = self.tile_size
        tile_step = self.tile_step

        # shift the grid origin by the random offset
        if self.tile_random_offset:
            image = image[:, self.random_offset[0] :, self.random_offset[1] :]

        # pad to full size, divisible by tile_size (padding is split between both sides)
        if self.tile_pad_full:
            _, h, w = image.shape
            pad_h = (tile_size - h % tile_size) % tile_size
            pad_w = (tile_size - w % tile_size) % tile_size
            image = np.pad(
                image,
                [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
                constant_values=self.tile_background_val,
            )

        # extract tiles as a read-only strided view indexed by (grid row, grid col, C, y, x)
        n_channels, height, width = image.shape
        c_stride, h_stride, w_stride = image.strides
        grid = as_strided(
            image,
            shape=(
                (height - tile_size) // tile_step + 1,
                (width - tile_size) // tile_step + 1,
                n_channels,
                tile_size,
                tile_size,
            ),
            strides=(h_stride * tile_step, w_stride * tile_step, c_stride, h_stride, w_stride),
            writeable=False,
        )
        image = grid.reshape(-1, n_channels, tile_size, tile_size)

        if self.tile_all:
            # Retain only patches with significant foreground content to speed up inference.
            # This returns a variable number of tiles, so the batch size must be 1 (per GPU).
            # The threshold is 99.9% of the intensity sum of a pure-background tile; it is scaled
            # by the actual channel count rather than assuming a 3-channel image.
            thresh = 0.999 * n_channels * self.tile_background_val * tile_size * tile_size
            if self.tile_filter_mode == "min":
                # default, keep non-background tiles (small values)
                keep = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh)
                image = image[keep.reshape(-1)]
            elif self.tile_filter_mode == "max":
                keep = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh)
                image = image[keep.reshape(-1)]
        else:
            n_tiles = len(image)
            if n_tiles < tile_count:
                # not enough tiles: append background-only tiles up to tile_count
                image = np.pad(
                    image,
                    [[0, tile_count - n_tiles], [0, 0], [0, 0], [0, 0]],
                    constant_values=self.tile_background_val,
                )
            elif self.tile_filter_mode == "min":
                # default, keep non-background tiles (smallest intensity sums)
                order = np.argsort(image.sum(axis=(1, 2, 3)))[:tile_count]
                image = image[order]
            elif self.tile_filter_mode == "max":
                order = np.argsort(image.sum(axis=(1, 2, 3)))[-tile_count:]
                image = image[order]
            elif n_tiles > tile_count:
                # random subset (more appropriate for WSIs without distinct background)
                if self.random_idxs is not None:
                    image = image[self.random_idxs]

        return image

monai/apps/pathology/transforms/spatial/dictionary.py

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Dict, Hashable, Mapping, Optional, Tuple, Union
12+
import copy
13+
from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union
1314

15+
import numpy as np
1416
import torch
1517

1618
from monai.config import KeysCollection
17-
from monai.transforms.transform import MapTransform
19+
from monai.transforms.transform import MapTransform, Randomizable
1820

19-
from .array import SplitOnGrid
21+
from .array import SplitOnGrid, TileOnGrid
2022

21-
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"]
23+
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"]
2224

2325

2426
class SplitOnGridd(MapTransform):
@@ -53,4 +55,83 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
5355
return d
5456

5557

58+
class TileOnGridd(Randomizable, MapTransform):
    """
    Dictionary-based wrapper that applies :py:class:`TileOnGrid` to every selected key.

    Tile the 2D image into patches on a grid and maintain a subset of it.
    This transform works only with np.ndarray inputs for 2D images.

    Args:
        keys: keys of the corresponding items to be transformed.
        tile_count: number of tiles to extract.
            Defaults to ``None`` (the default is then chosen by ``TileOnGrid``).
        tile_size: size of the square tile.
            Defaults to ``256``.
        tile_step: step size.
            Defaults to ``None`` (same as ``tile_size``).
        tile_all: extract all non-background tiles, instead of ``tile_count``.
            Defaults to ``False``.
        tile_random_offset: randomize position of tile grid, instead of starting from the top-left corner.
            Defaults to ``False``.
        tile_pad_full: pad image to the size evenly divisible by ``tile_size``.
            Defaults to ``False``.
        tile_background_val: the background constant (e.g. 255 for white background).
            Defaults to ``255``.
        tile_filter_mode: mode must be in ["min", "max", None]. If the total number of tiles is more
            than ``tile_count``, then sort by intensity sum, and take the smallest (for min),
            largest (for max) or random (for None) subset.
            Defaults to ``min`` (which assumes background is white, high value).
        allow_missing_keys: passed through to ``MapTransform``.
            Defaults to ``False``.
        return_list_of_dicts: return each tile in a separate dictionary, as a list of dicts.
            Defaults to ``False``.

    """

    def __init__(
        self,
        keys: KeysCollection,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        tile_step: Optional[int] = None,
        tile_all: bool = False,
        tile_random_offset: bool = False,
        tile_pad_full: bool = False,
        tile_background_val: int = 255,
        tile_filter_mode: Optional[str] = "min",
        allow_missing_keys: bool = False,
        return_list_of_dicts: bool = False,
    ):
        super().__init__(keys, allow_missing_keys)

        self.return_list_of_dicts = return_list_of_dicts
        self.seed = None

        self.splitter = TileOnGrid(
            tile_count=tile_count,
            tile_size=tile_size,
            tile_step=tile_step,
            tile_all=tile_all,
            tile_random_offset=tile_random_offset,
            tile_pad_full=tile_pad_full,
            tile_background_val=tile_background_val,
            tile_filter_mode=tile_filter_mode,
        )

    def randomize(self, data: Any = None) -> None:
        """Draw the seed that is handed to the wrapped ``TileOnGrid`` for the next call."""
        self.seed = self.R.randint(10000)  # type: ignore

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]:
        """
        Args:
            data: mapping whose selected entries are tiled.

        Returns:
            A shallow copy of ``data`` with the selected entries replaced by their tiles, or, when
            ``return_list_of_dicts`` is set, one dictionary per tile in which the tiled entries hold
            a single tile and all other entries are deep copies of the input values.
        """
        self.randomize()

        d = dict(data)
        tiled_keys: List[Hashable] = []
        for key in self.key_iterator(d):
            # re-seed before every key so that all keys get the same offset and tile subset
            self.splitter.set_random_state(seed=self.seed)
            d[key] = self.splitter(d[key])
            tiled_keys.append(key)

        if not self.return_list_of_dicts:
            return d

        if not tiled_keys:
            # Nothing was tiled (only possible with allow_missing_keys); the tile count is
            # undefined, so return the unmodified item as a single-element list.
            return [d]

        # Use the first key that was actually tiled: `self.keys[0]` may be absent from `d`
        # when missing keys are allowed.
        n_tiles = len(d[tiled_keys[0]])
        return [
            {k: v[i] if k in tiled_keys else copy.deepcopy(v) for k, v in d.items()} for i in range(n_tiles)
        ]
134+
135+
56136
# Alternative names bound to the same class object as SplitOnGridd.
SplitOnGridDict = SplitOnGridD = SplitOnGridd
137+
# Alternative names bound to the same class object as TileOnGridd.
TileOnGridDict = TileOnGridD = TileOnGridd

0 commit comments

Comments
 (0)