|
9 | 9 | # See the License for the specific language governing permissions and
|
10 | 10 | # limitations under the License.
|
11 | 11 |
|
12 |
| -from typing import Optional, Tuple, Union |
| 12 | +from typing import Optional, Sequence, Tuple, Union |
13 | 13 |
|
| 14 | +import numpy as np |
14 | 15 | import torch
|
| 16 | +from numpy.lib.stride_tricks import as_strided |
15 | 17 |
|
16 |
| -from monai.transforms.transform import Transform |
| 18 | +from monai.transforms.transform import Randomizable, Transform |
17 | 19 |
|
18 |
| -__all__ = ["SplitOnGrid"] |
| 20 | +__all__ = ["SplitOnGrid", "TileOnGrid"] |
19 | 21 |
|
20 | 22 |
|
21 | 23 | class SplitOnGrid(Transform):
|
@@ -73,3 +75,160 @@ def get_params(self, image_size):
|
73 | 75 | )
|
74 | 76 |
|
75 | 77 | return patch_size, steps
|
| 78 | + |
| 79 | + |
| 80 | +class TileOnGrid(Randomizable, Transform): |
| 81 | + """ |
| 82 | + Tile the 2D image into patches on a grid and maintain a subset of it. |
| 83 | + This transform works only with np.ndarray inputs for 2D images. |
| 84 | +
|
| 85 | + Args: |
| 86 | + tile_count: number of tiles to extract |
| 87 | + tile_size: size of the square tile |
| 88 | + Defaults to ``256``. |
| 89 | + tile_step: step size |
| 90 | + Defaults to None (same as tile_size) |
| 91 | + tile_all: Extract all non-background tiles, instead of tile_count. |
| 92 | + Defaults to ``False``. |
| 93 | + tile_random_offset: Randomize position of tile grid, instead of starting from the top-left corner |
| 94 | + Defaults to ``False``. |
| 95 | + tile_pad_full: pad image to the size evenly divisible by tile_size |
| 96 | + Defaults to ``False``. |
| 97 | + tile_background_val: the background constant (e.g. 255 for white background) |
| 98 | + Defaults to ``255``. |
| 99 | + tile_filter_mode: mode must be in ["min", "max", None]. If total number of tiles is more then tile_size, |
| 100 | + then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for None) subset |
| 101 | + Defaults to ``min`` (which assumes background is white, high value) |
| 102 | +
|
| 103 | + """ |
| 104 | + |
| 105 | + def __init__( |
| 106 | + self, |
| 107 | + tile_count: Optional[int] = None, |
| 108 | + tile_size: int = 256, |
| 109 | + tile_step: Optional[int] = None, |
| 110 | + tile_all: bool = False, |
| 111 | + tile_random_offset: bool = False, |
| 112 | + tile_pad_full: bool = False, |
| 113 | + tile_background_val: int = 255, |
| 114 | + tile_filter_mode: Optional[str] = "min", |
| 115 | + ): |
| 116 | + self.tile_count = tile_count |
| 117 | + self.tile_size = tile_size |
| 118 | + self.tile_step = tile_step |
| 119 | + self.tile_all = tile_all |
| 120 | + self.tile_random_offset = tile_random_offset |
| 121 | + self.tile_pad_full = tile_pad_full |
| 122 | + self.tile_background_val = tile_background_val |
| 123 | + self.tile_filter_mode = tile_filter_mode |
| 124 | + |
| 125 | + if self.tile_count is None: |
| 126 | + self.tile_count = max((44 * 256 ** 2) // (tile_size ** 2), 1) |
| 127 | + |
| 128 | + if self.tile_step is None: |
| 129 | + self.tile_step = self.tile_size # non-overlapping grid |
| 130 | + |
| 131 | + self.random_offset = (0, 0) |
| 132 | + self.random_idxs = [0] |
| 133 | + |
| 134 | + def randomize(self, img_size: Sequence[int]) -> None: |
| 135 | + |
| 136 | + c, h, w = img_size |
| 137 | + tile_count: int = self.tile_count # type: ignore |
| 138 | + tile_step: int = self.tile_step # type: ignore |
| 139 | + |
| 140 | + if self.tile_random_offset: |
| 141 | + pad_h = h % self.tile_size |
| 142 | + pad_w = w % self.tile_size |
| 143 | + if pad_h > 0 and pad_w > 0: |
| 144 | + self.random_offset = (self.R.randint(pad_h), self.R.randint(pad_w)) |
| 145 | + h = h - self.random_offset[0] |
| 146 | + w = w - self.random_offset[1] |
| 147 | + else: |
| 148 | + self.random_offset = (0, 0) |
| 149 | + |
| 150 | + if self.tile_pad_full: |
| 151 | + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size |
| 152 | + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size |
| 153 | + h = h + pad_h |
| 154 | + w = w + pad_w |
| 155 | + |
| 156 | + h_n = (h - self.tile_size + tile_step) // tile_step |
| 157 | + w_n = (w - self.tile_size + tile_step) // tile_step |
| 158 | + tile_total = h_n * w_n |
| 159 | + |
| 160 | + if tile_total > tile_count: |
| 161 | + self.random_idxs = self.R.choice(range(tile_total), tile_count, replace=False) # type: ignore |
| 162 | + else: |
| 163 | + self.random_idxs = [0] # type: ignore |
| 164 | + |
| 165 | + def __call__(self, image: np.ndarray) -> np.ndarray: |
| 166 | + |
| 167 | + # add random offset |
| 168 | + self.randomize(img_size=image.shape) |
| 169 | + tile_count: int = self.tile_count # type: ignore |
| 170 | + tile_step: int = self.tile_step # type: ignore |
| 171 | + |
| 172 | + if self.tile_random_offset and self.random_offset is not None: |
| 173 | + image = image[:, self.random_offset[0] :, self.random_offset[1] :] |
| 174 | + |
| 175 | + # pad to full size, divisible by tile_size |
| 176 | + if self.tile_pad_full: |
| 177 | + c, h, w = image.shape |
| 178 | + pad_h = (self.tile_size - h % self.tile_size) % self.tile_size |
| 179 | + pad_w = (self.tile_size - w % self.tile_size) % self.tile_size |
| 180 | + image = np.pad( |
| 181 | + image, |
| 182 | + [[0, 0], [pad_h // 2, pad_h - pad_h // 2], [pad_w // 2, pad_w - pad_w // 2]], |
| 183 | + constant_values=self.tile_background_val, |
| 184 | + ) |
| 185 | + |
| 186 | + # extact tiles (new way) |
| 187 | + xstep, ystep = tile_step, tile_step |
| 188 | + xsize, ysize = self.tile_size, self.tile_size |
| 189 | + clen, xlen, ylen = image.shape |
| 190 | + cstride, xstride, ystride = image.strides |
| 191 | + llw = as_strided( |
| 192 | + image, |
| 193 | + shape=((xlen - xsize) // xstep + 1, (ylen - ysize) // ystep + 1, clen, xsize, ysize), |
| 194 | + strides=(xstride * xstep, ystride * ystep, cstride, xstride, ystride), |
| 195 | + writeable=False, |
| 196 | + ) |
| 197 | + image = llw.reshape(-1, clen, xsize, ysize) |
| 198 | + |
| 199 | + # if keep all patches |
| 200 | + if self.tile_all: |
| 201 | + # retain only patches with significant foreground content to speed up inference |
| 202 | + # FYI, this returns a variable number of tiles, so the batch_size much be 1 (per gpu). Used during inference |
| 203 | + thresh = 0.999 * 3 * self.tile_background_val * self.tile_size * self.tile_size |
| 204 | + if self.tile_filter_mode == "min": |
| 205 | + # default, keep non-background tiles (small values) |
| 206 | + idxs = np.argwhere(image.sum(axis=(1, 2, 3)) < thresh) |
| 207 | + image = image[idxs.reshape(-1)] |
| 208 | + elif self.tile_filter_mode == "max": |
| 209 | + idxs = np.argwhere(image.sum(axis=(1, 2, 3)) >= thresh) |
| 210 | + image = image[idxs.reshape(-1)] |
| 211 | + |
| 212 | + else: |
| 213 | + if len(image) >= tile_count: |
| 214 | + |
| 215 | + if self.tile_filter_mode == "min": |
| 216 | + # default, keep non-background tiles (smallest values) |
| 217 | + idxs = np.argsort(image.sum(axis=(1, 2, 3)))[:tile_count] |
| 218 | + image = image[idxs] |
| 219 | + elif self.tile_filter_mode == "max": |
| 220 | + idxs = np.argsort(image.sum(axis=(1, 2, 3)))[-tile_count:] |
| 221 | + image = image[idxs] |
| 222 | + elif len(image) > tile_count: |
| 223 | + # random subset (more appropriate for WSIs without distinct background) |
| 224 | + if self.random_idxs is not None: |
| 225 | + image = image[self.random_idxs] |
| 226 | + |
| 227 | + else: |
| 228 | + image = np.pad( |
| 229 | + image, |
| 230 | + [[0, tile_count - len(image)], [0, 0], [0, 0], [0, 0]], |
| 231 | + constant_values=self.tile_background_val, |
| 232 | + ) |
| 233 | + |
| 234 | + return image |
0 commit comments