# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Tuple

import torch
import torch.nn.functional as F

from monai.utils import optional_import

rearrange, _ = optional_import("einops", name="rearrange")


def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]:
    """
    Partition into non-overlapping windows with padding if needed. Supports 2D and 3D.

    Args:
        x (tensor): input tokens with shape [B, s_dim_1 * ... * s_dim_n, C], with n = len(input_size).
        window_size (int): window size.
        input_size (Tuple): input spatial dimensions: (H, W) or (H, W, D).

    Returns:
        windows: windows after partition with shape [B * num_windows, window_size_1 * ... * window_size_n, C],
            with n = len(input_size) and window_size_i == window_size.
        (S_DIM_1p, ..., S_DIM_np): padded spatial dimensions before partition, with n = len(input_size).
    """
    if x.shape[1] != int(torch.prod(torch.tensor(input_size))):
        raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to the product of {input_size}.")

    if len(input_size) == 2:
        x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1])
        x, pad_hw = window_partition_2d(x, window_size)
        x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size)
        return x, pad_hw
    elif len(input_size) == 3:
        x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2])
        x, pad_hwd = window_partition_3d(x, window_size)
        x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size)
        return x, pad_hwd
    else:
        raise ValueError(f"input_size cannot have length {len(input_size)}; it must contain 2 or 3 elements.")


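# Illustrative sketch (not part of the original diff): a [1, 25, 3] token
# sequence laid out on a 5x5 grid with window_size=4 is padded to 8x8 and
# split into four windows of 4*4 = 16 tokens each:
#   windows, pad_hw = window_partition(torch.zeros(1, 25, 3), 4, input_size=(5, 5))
#   # windows.shape == (4, 16, 3); pad_hw == (8, 8)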
def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed. Supports only 2D.

    Args:
        x (tensor): input tokens with shape [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with shape [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition.
    """
    batch, h, w, c = x.shape

    # pad on the bottom/right so that h and w become multiples of window_size
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    hp, wp = h + pad_h, w + pad_w

    # split each spatial dim into (num_windows, window_size) and flatten the windows into the batch dim
    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
    return windows, (hp, wp)


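# Illustrative sketch: a [1, 5, 5, 3] map with window_size=4 is padded to 8x8
# and split into four 4x4 windows:
#   windows, (hp, wp) = window_partition_2d(torch.zeros(1, 5, 5, 3), 4)
#   # windows.shape == (4, 4, 4, 3); (hp, wp) == (8, 8)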
def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
    """
    Partition into non-overlapping windows with padding if needed. 3D implementation.

    Args:
        x (tensor): input tokens with shape [B, H, W, D, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with shape [B * num_windows, window_size, window_size, window_size, C].
        (Hp, Wp, Dp): padded height, width and depth before partition.
    """
    batch, h, w, d, c = x.shape

    # pad so that h, w and d become multiples of window_size
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    pad_d = (window_size - d % window_size) % window_size
    if pad_h > 0 or pad_w > 0 or pad_d > 0:
        x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h))
    hp, wp, dp = h + pad_h, w + pad_w, d + pad_d

    # split each spatial dim into (num_windows, window_size) and flatten the windows into the batch dim
    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c)
    return windows, (hp, wp, dp)


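# Illustrative sketch: a [1, 5, 5, 5, 3] volume with window_size=4 is padded to
# 8x8x8 and split into eight 4x4x4 windows:
#   windows, (hp, wp, dp) = window_partition_3d(torch.zeros(1, 5, 5, 5, 3), 4)
#   # windows.shape == (8, 4, 4, 4, 3); (hp, wp, dp) == (8, 8, 8)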
def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor:
    """
    Unpartition windows back into the original sequence, removing padding. Supports 2D and 3D.

    Args:
        windows (tensor): input tokens with shape [B * num_windows, window_size_1 * ... * window_size_n, C],
            with n = len(spatial_dims) and window_size_i == window_size.
        window_size (int): window size.
        pad (Tuple): padded spatial dims (H, W) or (H, W, D).
        spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding.

    Returns:
        x: unpartitioned sequence with shape [B, s_dim_1 * ... * s_dim_n, C].
    """
    x: torch.Tensor
    if len(spatial_dims) == 2:
        x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size)
        x = window_unpartition_2d(x, window_size, pad, spatial_dims)
        x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1])
        return x
    elif len(spatial_dims) == 3:
        x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size)
        x = window_unpartition_3d(x, window_size, pad, spatial_dims)
        x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2])
        return x
    else:
        raise ValueError(f"spatial_dims cannot have length {len(spatial_dims)}; it must contain 2 or 3 elements.")


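# Illustrative sketch: inverting the window_partition example above restores
# the original [1, 25, 3] sequence:
#   x = window_unpartition(windows, 4, pad=(8, 8), spatial_dims=(5, 5))
#   # x.shape == (1, 25, 3)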
def window_unpartition_2d(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Unpartition windows back into the original tensor, removing padding. 2D implementation.

    Args:
        windows (tensor): input tokens with shape [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned tensor with shape [B, H, W, C].
    """
    hp, wp = pad_hw
    h, w = hw
    # recover the batch size from the number of windows per padded map
    batch = windows.shape[0] // (hp * wp // window_size // window_size)
    x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1)
    # interleave window-grid and in-window dims back into (hp, wp)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1)

    # strip the bottom/right padding added during partition
    if hp > h or wp > w:
        x = x[:, :h, :w, :].contiguous()
    return x


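# Illustrative sketch: merging the four 4x4 windows from the
# window_partition_2d example restores the original 5x5 map:
#   x = window_unpartition_2d(windows, 4, pad_hw=(8, 8), hw=(5, 5))
#   # x.shape == (1, 5, 5, 3)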
def window_unpartition_3d(
    windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int]
) -> torch.Tensor:
    """
    Unpartition windows back into the original tensor, removing padding. 3D implementation.

    Args:
        windows (tensor): input tokens with shape [B * num_windows, window_size, window_size, window_size, C].
        window_size (int): window size.
        pad_hwd (Tuple): padded height, width and depth (Hp, Wp, Dp).
        hwd (Tuple): original height, width and depth (H, W, D) before padding.

    Returns:
        x: unpartitioned tensor with shape [B, H, W, D, C].
    """
    hp, wp, dp = pad_hwd
    h, w, d = hwd
    # recover the batch size from the number of windows per padded volume
    batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size)
    x = windows.view(
        batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1
    )
    # interleave window-grid and in-window dims back into (hp, wp, dp); this is
    # the inverse of the permute(0, 1, 3, 5, 2, 4, 6, 7) used in window_partition_3d
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(batch, hp, wp, dp, -1)

    # strip the padding added during partition
    if hp > h or wp > w or dp > d:
        x = x[:, :h, :w, :d, :].contiguous()
    return x
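

# Minimal runnable round-trip sketch (illustrative only, not part of the
# original module): partition a 3D token sequence and verify that
# unpartitioning restores it exactly.
if __name__ == "__main__":
    x = torch.randn(2, 5 * 5 * 5, 8)  # [B, H*W*D, C] with (H, W, D) = (5, 5, 5)
    windows, pad_hwd = window_partition(x, window_size=4, input_size=(5, 5, 5))
    # 2 batches x (2*2*2) windows of 4*4*4 tokens each -> torch.Size([16, 64, 8])
    print(windows.shape)
    y = window_unpartition(windows, 4, pad_hwd, (5, 5, 5))
    assert torch.equal(x, y)  # padding is stripped, so the round trip is exact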