Skip to content

Commit ee7f4fe

Browse files
committed
MIL component to extract patches
Signed-off-by: myron <[email protected]>
1 parent f08b625 commit ee7f4fe

File tree

6 files changed

+482
-11
lines changed

6 files changed

+482
-11
lines changed

monai/apps/pathology/transforms/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from .spatial.array import SplitOnGrid
13-
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
12+
from .spatial.array import SplitOnGrid, TileOnGrid
13+
from .spatial.dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict
1414
from .stain.array import ExtractHEStains, NormalizeHEStains
1515
from .stain.dictionary import (
1616
ExtractHEStainsd,

monai/apps/pathology/transforms/spatial/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from .array import SplitOnGrid
13-
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict
12+
from .array import SplitOnGrid, TileOnGrid
13+
from .dictionary import SplitOnGridd, SplitOnGridD, SplitOnGridDict, TileOnGridd, TileOnGridD, TileOnGridDict

monai/apps/pathology/transforms/spatial/array.py

Lines changed: 155 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Optional, Tuple, Union
12+
from typing import Optional, Sequence, Tuple, Union
1313

14+
import numpy as np
1415
import torch
16+
from numpy.lib.stride_tricks import as_strided
1517

16-
from monai.transforms.transform import Transform
18+
from monai.transforms.transform import Randomizable, Transform
1719

18-
__all__ = ["SplitOnGrid"]
20+
__all__ = ["SplitOnGrid", "TileOnGrid"]
1921

2022

2123
class SplitOnGrid(Transform):
@@ -73,3 +75,153 @@ def get_params(self, image_size):
7375
)
7476

7577
return patch_size, steps
78+
79+
80+
class TileOnGrid(Randomizable, Transform):
    """
    Tile the image into square patches on a regular grid and keep a subset of them.
    This transform works only with channel-first ``np.ndarray`` inputs of shape ``(C, H, W)``.

    Args:
        tile_count: number of tiles to return. When ``None`` every tile that passes the
            background filter is returned (same behaviour as ``tile_all=True``).
            Defaults to ``None``.
        tile_size: size of the square tile
            Defaults to ``256``.
        tile_step: step size between the top-left corners of neighbouring tiles
            Defaults to None (same as tile_size, i.e. a non-overlapping grid)
        tile_all: Extract all non-background tiles, instead of tile_count.
            Defaults to ``False``.
        tile_random_offset: Randomize position of tile grid, instead of starting from the top-left corner
            Defaults to ``False``.
        tile_pad_full: pad image to the size evenly divisible by tile_size
            Defaults to ``False``.
        tile_background_val: the background constant (e.g. 255 for white background)
            Defaults to ``255``.
        tile_filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more than tile_count,
            then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset
            Defaults to ``min`` (which assumes background is white, high value)

    Returns (from ``__call__``):
        an array of shape ``(N, C, tile_size, tile_size)``. ``N`` equals ``tile_count`` when a count
        is given and ``tile_all`` is ``False`` (missing tiles are filled with ``tile_background_val``);
        otherwise ``N`` varies with the image content.

    Raises:
        ValueError: when the input of ``__call__`` is not a 3-dimensional array.
    """

    def __init__(
        self,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        tile_step: Optional[int] = None,
        tile_all: bool = False,
        tile_random_offset: bool = False,
        tile_pad_full: bool = False,
        tile_background_val: int = 255,
        tile_filter_mode: Optional[str] = "min",  # None, min, max
    ):
        self.tile_count = tile_count
        self.tile_size = tile_size
        # a missing step means a non-overlapping grid
        self.tile_step = tile_size if tile_step is None else tile_step
        self.tile_all = tile_all
        self.tile_random_offset = tile_random_offset
        self.tile_pad_full = tile_pad_full
        self.tile_background_val = tile_background_val
        self.tile_filter_mode = tile_filter_mode

        # state filled in by ``randomize`` on every call
        self.random_offset = None
        self.random_idxs = None

    def _pad_amount(self, length: int) -> int:
        """Return how many pixels must be added to ``length`` to make it a multiple of ``tile_size``."""
        return (self.tile_size - length % self.tile_size) % self.tile_size

    def _grid_len(self, length: int) -> int:
        """Return the number of tiles that fit along an axis of ``length`` pixels (never negative)."""
        return max(0, (length - self.tile_size) // self.tile_step + 1)

    def randomize(self, img_size: Optional[Sequence[int]] = None) -> None:
        """
        Draw the random grid offset and the random tile subset for an image of shape ``img_size``.

        Args:
            img_size: ``(C, H, W)`` shape of the image about to be tiled. When ``None`` the
                random state is only reset and nothing is drawn.
        """
        self.random_offset = None
        self.random_idxs = None
        if img_size is None:
            return

        _, h, w = img_size

        # Only shift the grid when requested, so that the tile total computed below
        # always agrees with what ``__call__`` actually extracts.
        if self.tile_random_offset:
            slack_h = h % self.tile_size
            slack_w = w % self.tile_size
            if slack_h > 0 or slack_w > 0:
                # ``randint(0)`` is invalid, so an axis without slack keeps a zero offset
                off_h = int(self.R.randint(slack_h)) if slack_h > 0 else 0
                off_w = int(self.R.randint(slack_w)) if slack_w > 0 else 0
                self.random_offset = (off_h, off_w)
                h -= off_h
                w -= off_w

        if self.tile_pad_full:
            h += self._pad_amount(h)
            w += self._pad_amount(w)

        tile_total = self._grid_len(h) * self._grid_len(w)

        # ``tile_count`` may legitimately be None (keep-all mode); nothing to sub-sample then
        if self.tile_count is not None and tile_total > self.tile_count:
            self.random_idxs = self.R.choice(range(tile_total), self.tile_count, replace=False)

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """
        Args:
            image: channel-first array of shape ``(C, H, W)``.

        Returns:
            the selected tiles stacked along a new first axis, shape ``(N, C, tile_size, tile_size)``.
        """
        if image.ndim != 3:
            raise ValueError(f"TileOnGrid expects a (C, H, W) array, got shape {image.shape}.")

        self.randomize(img_size=image.shape)

        # add random offset
        if self.tile_random_offset and self.random_offset is not None:
            off_h, off_w = self.random_offset
            image = image[:, off_h:, off_w:]

        # pad to full size, divisible by tile_size (padding is split evenly on both sides)
        if self.tile_pad_full:
            _, h, w = image.shape
            pad_h = self._pad_amount(h)
            pad_w = self._pad_amount(w)
            image = np.pad(
                image,
                [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]],
                constant_values=self.tile_background_val,
            )

        # extract tiles as a strided (read-only) window view, then flatten the grid axes
        size, step = self.tile_size, self.tile_step
        channels, height, width = image.shape
        c_stride, h_stride, w_stride = image.strides
        rows, cols = self._grid_len(height), self._grid_len(width)
        windows = as_strided(
            image,
            shape=(rows, cols, channels, size, size),
            strides=(h_stride * step, w_stride * step, c_stride, h_stride, w_stride),
            writeable=False,
        )
        image = windows.reshape(rows * cols, channels, size, size)

        if self.tile_all or self.tile_count is None:
            # retain only patches with significant foreground content to speed up inference
            # FYI, this returns a variable number of tiles, so the batch_size must be 1 (per gpu).
            # The threshold is the intensity sum of an (almost) pure-background tile; it scales
            # with the actual channel count rather than assuming RGB.
            thresh = 0.999 * channels * self.tile_background_val * size * size
            if self.tile_filter_mode == "min":
                # default, keep non-background tiles (small values)
                image = image[image.sum(axis=(1, 2, 3)) < thresh]
            elif self.tile_filter_mode == "max":
                image = image[image.sum(axis=(1, 2, 3)) >= thresh]

        elif len(image) >= self.tile_count:
            if self.tile_filter_mode == "min":
                # default, keep non-background tiles (smallest values)
                order = np.argsort(image.sum(axis=(1, 2, 3)))
                image = image[order[: self.tile_count]]
            elif self.tile_filter_mode == "max":
                order = np.argsort(image.sum(axis=(1, 2, 3)))
                # explicit start index so that tile_count == 0 yields no tiles
                image = image[order[len(order) - self.tile_count :]]
            elif len(image) > self.tile_count and self.random_idxs is not None:
                # random subset (more appropriate for WSIs without distinct background)
                image = image[self.random_idxs]

        else:
            # fewer tiles than requested: append pure-background tiles up to tile_count
            image = np.pad(
                image,
                [[0, self.tile_count - len(image)], [0, 0], [0, 0], [0, 0]],
                constant_values=self.tile_background_val,
            )

        return image

monai/apps/pathology/transforms/spatial/dictionary.py

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import Dict, Hashable, Mapping, Optional, Tuple, Union
12+
import copy
13+
from typing import Any, Dict, Hashable, Mapping, Optional, Tuple, Union
1314

15+
import numpy as np
1416
import torch
1517

1618
from monai.config import KeysCollection
17-
from monai.transforms.transform import MapTransform
19+
from monai.transforms.transform import MapTransform, Randomizable
1820

19-
from .array import SplitOnGrid
21+
from .array import SplitOnGrid, TileOnGrid
2022

21-
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict"]
23+
__all__ = ["SplitOnGridd", "SplitOnGridD", "SplitOnGridDict", "TileOnGridd", "TileOnGridD", "TileOnGridDict"]
2224

2325

2426
class SplitOnGridd(MapTransform):
@@ -53,4 +55,86 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
5355
return d
5456

5557

58+
class TileOnGridd(Randomizable, MapTransform):
    """
    Dictionary-based wrapper of ``TileOnGrid``: tile the image into patches on a grid and
    maintain a subset of it. This transform works only with np.ndarray inputs.

    Args:
        keys: keys of the items to be tiled.
        tile_count: number of tiles to extract
        tile_size: size of the square tile
            Defaults to ``256``.
        tile_step: step size
            Defaults to None (same as tile_size)
        tile_all: Extract all non-background tiles, instead of tile_count.
            Defaults to ``False``.
        tile_random_offset: Randomize position of tile grid, instead of starting from the top-left corner
            Defaults to ``False``.
        tile_pad_full: pad image to the size evenly divisible by tile_size
            Defaults to ``False``.
        tile_background_val: the background constant (e.g. 255 for white background)
            Defaults to ``255``.
        tile_filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more than tile_count,
            then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset
            Defaults to ``min`` (which assumes background is white, high value)
        allow_missing_keys: forwarded to ``MapTransform``.
            Defaults to ``False``
        return_list_of_dicts: return each tile in a separate dictionary, as a list of dicts
            Defaults to ``False``

    Returns (from ``__call__``):
        a dictionary whose tiled entries hold stacked tiles, or, when ``return_list_of_dicts``
        is ``True``, a list with one dictionary per tile.

    Raises:
        KeyError: when ``return_list_of_dicts`` is ``True`` and none of ``keys`` is present in the data.
    """

    def __init__(
        self,
        keys: KeysCollection,
        tile_count: Optional[int] = None,
        tile_size: int = 256,
        tile_step: Optional[int] = None,
        tile_all: bool = False,
        tile_random_offset: bool = False,
        tile_pad_full: bool = False,
        tile_background_val: int = 255,
        tile_filter_mode: Optional[str] = "min",
        allow_missing_keys: bool = False,
        return_list_of_dicts: bool = False,
    ):
        super().__init__(keys, allow_missing_keys)

        self.return_list_of_dicts = return_list_of_dicts
        self.seed = None

        self.splitter = TileOnGrid(
            tile_count=tile_count,
            tile_size=tile_size,
            tile_step=tile_step,
            tile_all=tile_all,
            tile_random_offset=tile_random_offset,
            tile_pad_full=tile_pad_full,
            tile_background_val=tile_background_val,
            tile_filter_mode=tile_filter_mode,
        )

    def randomize(self, data: Any = None) -> None:
        """Draw the seed that is handed to the wrapped ``TileOnGrid`` for every key of one call."""
        self.seed = self.R.randint(10000)

    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Union[Dict[Hashable, np.ndarray], list]:
        """
        Tile every selected entry of ``data`` with identical random decisions.

        Args:
            data: mapping holding the arrays to tile under ``keys``; other entries are passed through.
        """
        self.randomize()

        d = dict(data)
        for key in self.key_iterator(d):
            # re-seeding before each key makes all keys share the same offset / subset
            self.splitter.set_random_state(seed=self.seed)
            d[key] = self.splitter(d[key])

        if not self.return_list_of_dicts:
            return d

        # Only keys actually present were tiled (some may be missing with allow_missing_keys).
        tiled_keys = [key for key in self.keys if key in d]
        if not tiled_keys:
            raise KeyError(f"None of the keys {self.keys} were found in the data; cannot split into tiles.")
        tiled = set(tiled_keys)

        # One dictionary per tile: tiled entries are indexed by the tile number, every other
        # entry is deep-copied so that the resulting dictionaries do not share mutable state.
        tile_total = len(d[tiled_keys[0]])
        return [
            {name: value[tile_idx] if name in tiled else copy.deepcopy(value) for name, value in d.items()}
            for tile_idx in range(tile_total)
        ]
137+
138+
56139
# Alternative names bound to the same SplitOnGridd class.
SplitOnGridDict = SplitOnGridD = SplitOnGridd
140+
# Alternative names bound to the same TileOnGridd class.
TileOnGridDict = TileOnGridD = TileOnGridd

0 commit comments

Comments
 (0)