import torch.nn.functional as F
from torch import nn

+from monai.utils import optional_import
+
+rearrange, _ = optional_import("einops", name="rearrange")
+

def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
@@ -126,3 +130,162 @@ def add_decomposed_rel_pos(
    ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)

    return attn
+
+
+def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]:
+    """
+    Partition into non-overlapping windows with padding if needed. Supports 2D and 3D.
+
+    Args:
+        x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C], with n = 1...len(input_size).
+        window_size (int): window size.
+        input_size (Tuple): input spatial dimensions: (H, W) or (H, W, D).
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C],
+            with n = 1...len(input_size) and window_size_i == window_size.
+        (S_DIM_1p, ..., S_DIM_np): padded spatial dimensions before partition, with n = 1...len(input_size).
+    """
+    if x.shape[1] != int(torch.prod(torch.tensor(input_size))):
+        raise ValueError(f"Input sequence length {x.shape[1]} should equal the product of input_size {input_size}.")
+
+    if len(input_size) == 2:
+        x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1])
+        x, pad_hw = window_partition_2d(x, window_size)
+        x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size)
+        return x, pad_hw
+    elif len(input_size) == 3:
+        x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2])
+        x, pad_hwd = window_partition_3d(x, window_size)
+        x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size)
+        return x, pad_hwd
+    else:
+        raise ValueError(f"input_size must contain 2 or 3 elements, got {len(input_size)}.")
+
+
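+# Illustrative sketch (not from the original patch; sizes are assumptions):
+# window_partition re-flattens each window into a token sequence, and
+# window_unpartition (below) inverts it exactly, since the zero padding is cropped.
+# >>> x = torch.randn(2, 6 * 6 * 6, 32)                  # [B, H*W*D, C]
+# >>> windows, pad_hwd = window_partition(x, window_size=4, input_size=(6, 6, 6))
+# >>> windows.shape                                      # dims padded 6 -> 8; (8/4)**3 = 8 windows per item
+# torch.Size([16, 64, 32])
+# >>> torch.equal(window_unpartition(windows, 4, pad_hwd, (6, 6, 6)), x)
+# True
+
+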
+def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed. Supports 2D only.
+
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition.
+    """
+    batch, h, w, c = x.shape
+
+    pad_h = (window_size - h % window_size) % window_size
+    pad_w = (window_size - w % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    hp, wp = h + pad_h, w + pad_w
+
+    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
+    return windows, (hp, wp)
+
+
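+# Illustrative sketch (not from the original patch; sizes are assumptions): a
+# 10 x 12 grid with window_size=7 is padded to 14 x 14, giving 2 * 2 = 4 windows.
+# >>> x = torch.randn(1, 10, 12, 64)                     # [B, H, W, C]
+# >>> windows, (hp, wp) = window_partition_2d(x, 7)
+# >>> windows.shape, (hp, wp)
+# (torch.Size([4, 7, 7, 64]), (14, 14))
+
+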
+def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed. 3D implementation.
+
+    Args:
+        x (tensor): input tokens with [B, H, W, D, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C].
+        (Hp, Wp, Dp): padded height, width and depth before partition.
+    """
+    batch, h, w, d, c = x.shape
+
+    pad_h = (window_size - h % window_size) % window_size
+    pad_w = (window_size - w % window_size) % window_size
+    pad_d = (window_size - d % window_size) % window_size
+    if pad_h > 0 or pad_w > 0 or pad_d > 0:
+        x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h))
+    hp, wp, dp = h + pad_h, w + pad_w, d + pad_d
+
+    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c)
+    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c)
+    return windows, (hp, wp, dp)
+
+
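+# Illustrative sketch (not from the original patch; sizes are assumptions): a
+# 5 x 5 x 5 volume with window_size=3 is padded to 6 x 6 x 6, i.e. 2**3 = 8
+# windows per batch element.
+# >>> x = torch.randn(2, 5, 5, 5, 16)                    # [B, H, W, D, C]
+# >>> windows, (hp, wp, dp) = window_partition_3d(x, 3)
+# >>> windows.shape, (hp, wp, dp)
+# (torch.Size([16, 3, 3, 3, 16]), (6, 6, 6))
+
+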
+def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and remove padding. Supports 2D and 3D.
+
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size_1 * ... * window_size_n, C],
+            with n = 1...len(spatial_dims) and window_size_i == window_size.
+        window_size (int): window size.
+        pad (Tuple): padded spatial dims (H, W) or (H, W, D).
+        spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, s_dim_1 * ... * s_dim_n, C].
+    """
+    x: torch.Tensor
+    if len(spatial_dims) == 2:
+        x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size)
+        x = window_unpartition_2d(x, window_size, pad, spatial_dims)
+        x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1])
+        return x
+    elif len(spatial_dims) == 3:
+        x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size)
+        x = window_unpartition_3d(x, window_size, pad, spatial_dims)
+        x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2])
+        return x
+    else:
+        raise ValueError(f"spatial_dims must contain 2 or 3 elements, got {len(spatial_dims)}.")
+
+
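+# Illustrative sketch (not from the original patch; sizes are assumptions):
+# round-trip through the flattened-token interface.
+# >>> x = torch.randn(1, 10 * 12, 64)                    # [B, H*W, C]
+# >>> windows, pad_hw = window_partition(x, window_size=7, input_size=(10, 12))
+# >>> torch.equal(window_unpartition(windows, 7, pad_hw, (10, 12)), x)
+# True
+
+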
+def window_unpartition_2d(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and remove padding. Supports 2D only.
+
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    hp, wp = pad_hw
+    h, w = hw
+    batch = windows.shape[0] // (hp * wp // window_size // window_size)
+    x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1)
+
+    if hp > h or wp > w:
+        x = x[:, :h, :w, :].contiguous()
+    return x
+
+
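+# Illustrative sketch (not from the original patch; sizes are assumptions):
+# window_unpartition_2d inverts window_partition_2d by cropping the zero padding.
+# >>> x = torch.randn(1, 10, 12, 64)
+# >>> windows, pad_hw = window_partition_2d(x, 7)
+# >>> torch.equal(window_unpartition_2d(windows, 7, pad_hw, (10, 12)), x)
+# True
+
+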
+def window_unpartition_3d(
+    windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int]
+) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and remove padding. 3D implementation.
+
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hwd (Tuple): padded height, width and depth (Hp, Wp, Dp).
+        hwd (Tuple): original height, width and depth (H, W, D) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, D, C].
+    """
+    hp, wp, dp = pad_hwd
+    h, w, d = hwd
+    batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size)
+    x = windows.view(
+        batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1
+    )
+    x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(batch, hp, wp, dp, -1)
+
+    if hp > h or wp > w or dp > d:
+        x = x[:, :h, :w, :d, :].contiguous()
+    return x
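+
+
+# Illustrative sketch (not from the original patch; sizes are assumptions):
+# the 3D inverse, cropping the padded volume back to the original (H, W, D).
+# >>> x = torch.randn(2, 5, 5, 5, 16)
+# >>> windows, pad_hwd = window_partition_3d(x, 3)
+# >>> torch.equal(window_unpartition_3d(windows, 3, pad_hwd, (5, 5, 5)), x)
+# True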