Skip to content

Commit 940cf15

Browse files
committed
MIL component to extract patches
Signed-off-by: myron <[email protected]>
1 parent cfe64aa commit 940cf15

File tree

6 files changed

+478
-11
lines changed

6 files changed

+478
-11
lines changed

monai/apps/pathology/transforms/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from .spatial.array import SplitOnGrid
13-
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
12+
from .spatial.array import SplitOnGrid, TileOnGrid
13+
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict
1414
from .stain.array import ExtractHEStains, NormalizeHEStains
1515
from .stain.dictionary import (
1616
ExtractHEStainsd,

monai/apps/pathology/transforms/spatial/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from .array import SplitOnGrid
13-
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
12+
from .array import SplitOnGrid, TileOnGrid
13+
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict

monai/apps/pathology/transforms/spatial/array.py

Lines changed: 161 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Optional, Tuple, Union
12+
from typing import Optional, Sequence, Tuple, Union
1313

14+
import numpy as np
1415
import torch
16+
from numpy.lib.stride_tricks import as_strided
1517

16-
from monai.transforms.transform import Transform
18+
from monai.transforms.transform import Randomizable, Transform
1719

18-
__all__ = ["SplitOnGrid"]
20+
__all__ = ["SplitOnGrid", "TileOnGrid"]
1921

2022

2123
class SplitOnGrid(Transform):
@@ -73,3 +75,159 @@ def get_params(self, image_size):
7375
)
7476

7577
return patch_size, steps
78+
79+
80+
class TileOnGrid(Randomizable, Transform):
81+
"""
82+
Tile the 2D image into patches on a grid and maintain a subset of it.
83+
This transform works only with np.ndarray inputs for 2D images.
84+
85+
Args:
86+
tile_count: number of tiles to extract, if None Extract all non-background tiles
87+
Defaults to ``None``.
88+
tile_size: size of the square tile
89+
Defaults to ``256``.
90+
step: step size
91+
Defaults to None (same as tile_size)
92+
random_offset: Randomize position of tile grid, instead of starting from the top-left corner
93+
Defaults to ``False``.
94+
pad_full: pad image to the size evenly divisible by tile_size
95+
Defaults to ``False``.
96+
background_val: the background constant (e.g. 255 for white background)
97+
Defaults to ``255``.
98+
filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more then tile_size,
99+
then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset
100+
Defaults to ``min`` (which assumes background is white, high value)
101+
102+
"""
103+
104+
def __init__(
105+
self,
106+
tile_count: Optional[int] = None,
107+
tile_size: int = 256,
108+
step: Optional[int] = None,
109+
random_offset: bool = False,
110+
pad_full: bool = False,
111+
background_val: int = 255,
112+
filter_mode: Optional[str] = "min",
113+
):
114+
self.tile_count = tile_count
115+
self.tile_size = tile_size
116+
self.step = step
117+
self.random_offset = random_offset
118+
self.pad_full = pad_full
119+
self.background_val = background_val
120+
self.filter_mode = filter_mode
121+
122+
# self.tile_all = (self.tile_count is None)
123+
124+
# if self.tile_count is None:
125+
# self.tile_count = max((44 * 256 ** 2) // (tile_size ** 2), 1)
126+
127+
if self.step is None:
128+
self.step = self.tile_size # non-overlapping grid
129+
130+
self.offset = (0, 0)
131+
self.random_idxs = [0]
132+
133+
def randomize(self, img_size: Sequence[int]) -> None:
134+
135+
c, h, w = img_size
136+
# tile_count: int = self.tile_count # type: ignore
137+
tile_step: int = self.step # type: ignore
138+
139+
if self.random_offset:
140+
pad_h = h % self.tile_size
141+
pad_w = w % self.tile_size
142+
if pad_h > 0 and pad_w > 0:
143+
self.offset = (self.R.randint(pad_h), self.R.randint(pad_w))
144+
h = h - self.offset[0]
145+
w = w - self.offset[1]
146+
else:
147+
self.offset = (0, 0)
148+
149+
if self.pad_full:
150+
pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
151+
pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
152+
h = h + pad_h
153+
w = w + pad_w
154+
155+
h_n = (h - self.tile_size + tile_step) // tile_step
156+
w_n = (w - self.tile_size + tile_step) // tile_step
157+
tile_total = h_n * w_n
158+
159+
if self.tile_count is not None and tile_total > self.tile_count:
160+
self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False) # type: ignore
161+
else:
162+
self.random_idxs = [0] # type: ignore
163+
164+
def __call__(self, image: np.ndarray) -> np.ndarray:
165+
166+
# add random offset
167+
self.randomize(img_size=image.shape)
168+
# tile_count: int = self.tile_count # type: ignore
169+
tile_step: int = self.step # type: ignore
170+
171+
if self.random_offset and self.offset is not None:
172+
image = image[:, self.offset[0] :, self.offset[1] :]
173+
174+
# pad to full size, divisible by tile_size
175+
if self.pad_full:
176+
c, h, w = image.shape
177+
pad_h = (self.tile_size - h % self.tile_size) % self.tile_size
178+
pad_w = (self.tile_size - w % self.tile_size) % self.tile_size
179+
image = np.pad(
180+
image,
181+
[[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
182+
constant_values=self.background_val,
183+
)
184+
185+
# extact tiles (new way)
186+
xstep, ystep = tile_step, tile_step
187+
xsize, ysize = self.tile_size, self.tile_size
188+
clen, xlen, ylen = image.shape
189+
cstride, xstride, ystride = image.strides
190+
llw = as_strided(
191+
image,
192+
shape=((xlen - xsize) // xstep + 1, (ylen - ysize) // ystep + 1, clen, xsize, ysize),
193+
strides=(xstride * xstep, ystride * ystep, cstride, xstride, ystride),
194+
writeable=False,
195+
)
196+
image = llw.reshape(-1, clen, xsize, ysize)
197+
198+
# if keep all patches
199+
if self.tile_count is None:
200+
# retain only patches with significant foreground content to speed up inference
201+
# FYI, this returns a variable number of tiles, so the batch_size much be 1 (per gpu). Used during inference
202+
thresh = 0.999 * 3 * self.background_val * self.tile_size * self.tile_size
203+
if self.filter_mode == "min":
204+
# default, keep non-background tiles (small values)
205+
idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh)
206+
image = image[idxs.reshape(-1)]
207+
elif self.filter_mode == "max":
208+
idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh)
209+
image = image[idxs.reshape(-1)]
210+
211+
else:
212+
if len(image) >= self.tile_count:
213+
214+
if self.filter_mode == "min":
215+
# default, keep non-background tiles (smallest values)
216+
idxs = np.argsort(image.sum(axis=(1, 2, 3)))[: self.tile_count]
217+
image = image[idxs]
218+
elif self.filter_mode == "max":
219+
idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-self.tile_count :]
220+
image = image[idxs]
221+
elif len(image) > self.tile_count:
222+
# random subset (more appropriate for WSIs without distinct background)
223+
if self.random_idxs is not None:
224+
image = image[self.random_idxs]
225+
226+
else:
227+
image = np.pad(
228+
image,
229+
[[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]],
230+
constant_values=self.background_val,
231+
)
232+
233+
return image

monai/apps/pathology/transforms/spatial/dictionary.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Dict, Hashable, Mapping, Optional, Tuple, Union
12+
import copy
13+
from typing import Any, Dict, Hashable, List, Mapping, Optional, Tuple, Union
1314

15+
import numpy as np
1416
import torch
1517

1618
from monai.config import KeysCollection
17-
from monai.transforms.transform import MapTransform
19+
from monai.transforms.transform import MapTransform, Randomizable
1820

19-
from .array import SplitOnGrid
21+
from .array import SplitOnGrid, TileOnGrid
2022

21-
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"]
23+
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"]
2224

2325

2426
class SplitOnGridd(MapTransform):
@@ -53,4 +55,78 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
5355
return d
5456

5557

58+
class TileOnGridd(Randomizable, MapTransform):
59+
"""
60+
Tile the 2D image into patches on a grid and maintain a subset of it.
61+
This transform works only with np.ndarray inputs for 2D images.
62+
63+
Args:
64+
tile_count: number of tiles to extract, if None Extract all non-background tiles
65+
Defaults to ``None``.
66+
tile_size: size of the square tile
67+
Defaults to ``256``.
68+
step: step size
69+
Defaults to None (same as tile_size)
70+
random_offset: Randomize position of tile grid, instead of starting from the top-left corner
71+
Defaults to ``False``.
72+
pad_full: pad image to the size evenly divisible by tile_size
73+
Defaults to ``False``.
74+
background_val: the background constant (e.g. 255 for white background)
75+
Defaults to ``255``.
76+
filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more then tile_size,
77+
then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset
78+
Defaults to ``min`` (which assumes background is white, high value)
79+
80+
"""
81+
82+
def __init__(
83+
self,
84+
keys: KeysCollection,
85+
tile_count: Optional[int] = None,
86+
tile_size: int = 256,
87+
step: Optional[int] = None,
88+
random_offset: bool = False,
89+
pad_full: bool = False,
90+
background_val: int = 255,
91+
filter_mode: Optional[str] = "min",
92+
allow_missing_keys: bool = False,
93+
return_list_of_dicts: bool = False,
94+
):
95+
super().__init__(keys, allow_missing_keys)
96+
97+
self.return_list_of_dicts = return_list_of_dicts
98+
self.seed = None
99+
100+
self.splitter = TileOnGrid(
101+
tile_count=tile_count,
102+
tile_size=tile_size,
103+
step=step,
104+
random_offset=random_offset,
105+
pad_full=pad_full,
106+
background_val=background_val,
107+
filter_mode=filter_mode,
108+
)
109+
110+
def randomize(self, data: Any = None) -> None:
111+
self.seed = self.R.randint(10000) # type: ignore
112+
113+
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], List[Dict]]:
114+
115+
self.randomize()
116+
117+
d = dict(data)
118+
for key in self.key_iterator(d):
119+
self.splitter.set_random_state(seed=self.seed) # same random seed for all keys
120+
d[key] = self.splitter(d[key])
121+
122+
if self.return_list_of_dicts:
123+
d_list = []
124+
for i in range(len(d[self.keys[0]])):
125+
d_list.append({k: d[k][i] if k in self.keys else copy.deepcopy(d[k]) for k in d.keys()})
126+
d = d_list # type: ignore
127+
128+
return d
129+
130+
56131
SplitOnGridDict = SplitOnGridD = SplitOnGridd
132+
TileOnGridDict = TileOnGridD = TileOnGridd

0 commit comments

Comments
 (0)