Commit dbbab7c

refacto

Signed-off-by: vgrau98 <[email protected]>

1 parent 7b971bf commit dbbab7c

File tree

3 files changed: +210 -181 lines changed
monai/networks/blocks/attention_utils.py

Lines changed: 178 additions & 0 deletions

@@ -0,0 +1,178 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Tuple

import torch
import torch.nn.functional as F

from monai.utils import optional_import

rearrange, _ = optional_import("einops", name="rearrange")

def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]:
    """
    Partition into non-overlapping windows with padding if needed. Supports 2D and 3D.

    Args:
        x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C], with n = 1...len(input_size).
        window_size (int): window size.
        input_size (Tuple): input spatial dimensions: (H, W) or (H, W, D).

    Returns:
        windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C],
            with n = 1...len(input_size) and window_size_i == window_size.
        (S_DIM_1p, ..., S_DIM_np): padded spatial dimensions before partition, with n = 1...len(input_size).
    """
    if x.shape[1] != int(torch.prod(torch.tensor(input_size))):
        raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to the product of {input_size}")

    if len(input_size) == 2:
        x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1])
        x, pad_hw = window_partition_2d(x, window_size)
        x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size)
        return x, pad_hw
    elif len(input_size) == 3:
        x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2])
        x, pad_hwd = window_partition_3d(x, window_size)
        x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size)
        return x, pad_hwd
    else:
        raise ValueError(f"input_size cannot be length {len(input_size)}. It can only have 2 or 3 elements.")

def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """
    Partition into non-overlapping windows with padding if needed. Supports 2D only.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition.
    """
    batch, h, w, c = x.shape

    # Pad each spatial dim up to the next multiple of window_size.
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    hp, wp = h + pad_h, w + pad_w

    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
    return windows, (hp, wp)

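For intuition, a quick shape check of window_partition_2d (an illustrative sketch, not part of the commit; the sizes are arbitrary):

# A 6 x 6 map with window_size=4 is padded to 8 x 8, yielding
# (8 // 4) * (8 // 4) = 4 windows per batch element.
x = torch.randn(1, 6, 6, 32)  # [B, H, W, C]
windows, (hp, wp) = window_partition_2d(x, window_size=4)
assert (hp, wp) == (8, 8)  # padded spatial dims, needed later for unpartitioning
assert windows.shape == (4, 4, 4, 32)  # [B * num_windows, window_size, window_size, C]
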
def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
    """
    Partition into non-overlapping windows with padding if needed. 3D implementation.

    Args:
        x (tensor): input tokens with [B, H, W, D, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C].
        (Hp, Wp, Dp): padded height, width and depth before partition.
    """
    batch, h, w, d, c = x.shape

    # Pad each spatial dim up to the next multiple of window_size.
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    pad_d = (window_size - d % window_size) % window_size
    if pad_h > 0 or pad_w > 0 or pad_d > 0:
        x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h))
    hp, wp, dp = h + pad_h, w + pad_w, d + pad_d

    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c)
    return windows, (hp, wp, dp)

def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor:
    """
    Unpartition windows into the original sequences, removing padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size_1 * ... * window_size_n, C],
            with n = 1...len(spatial_dims) and window_size_i == window_size.
        window_size (int): window size.
        pad (Tuple): padded spatial dims (H, W) or (H, W, D).
        spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding.

    Returns:
        x: unpartitioned sequences with [B, s_dim_1 * ... * s_dim_n, C].
    """
    x: torch.Tensor
    if len(spatial_dims) == 2:
        x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size)
        x = window_unpartition_2d(x, window_size, pad, spatial_dims)
        x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1])
        return x
    elif len(spatial_dims) == 3:
        x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size)
        x = window_unpartition_3d(x, window_size, pad, spatial_dims)
        x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2])
        return x
    else:
        raise ValueError(f"spatial_dims cannot be length {len(spatial_dims)}. It can only have 2 or 3 elements.")

def window_unpartition_2d(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """
    Unpartition windows into the original sequences, removing padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    hp, wp = pad_hw
    h, w = hw
    # Recover the batch size from the total window count.
    batch = windows.shape[0] // (hp * wp // window_size // window_size)
    x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1)

    if hp > h or wp > w:
        x = x[:, :h, :w, :].contiguous()
    return x

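window_unpartition_2d recovers the batch size from the window count and crops the padding away, so it exactly inverts window_partition_2d. A round-trip sketch (illustrative, not part of the commit):

x = torch.randn(3, 10, 10, 8)  # [B, H, W, C]
windows, pad_hw = window_partition_2d(x, window_size=4)  # padded to 12 x 12 -> 9 windows each
restored = window_unpartition_2d(windows, 4, pad_hw, (10, 10))
assert restored.shape == (3, 10, 10, 8)
assert torch.allclose(restored, x)  # zero padding is cropped, original values intact
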
def window_unpartition_3d(
    windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int]
) -> torch.Tensor:
    """
    Unpartition windows into the original sequences, removing padding. 3D implementation.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C].
        window_size (int): window size.
        pad_hwd (Tuple): padded height, width and depth (Hp, Wp, Dp).
        hwd (Tuple): original height, width and depth (H, W, D) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, D, C].
    """
    hp, wp, dp = pad_hwd
    h, w, d = hwd
    # Recover the batch size from the total window count.
    batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size)
    x = windows.view(
        batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(batch, hp, wp, dp, -1)

    if hp > h or wp > w or dp > d:
        x = x[:, :h, :w, :d, :].contiguous()
    return x
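
Taken together, the two dispatchers round-trip a flattened token sequence through local windows. A 3D sketch (illustrative, not part of the commit; assumes torch and the functions above are in scope):

input_size = (6, 6, 6)
x = torch.randn(2, 6 * 6 * 6, 16)  # [B, H*W*D, C]
windows, pad = window_partition(x, 4, input_size)
assert pad == (8, 8, 8)  # each axis padded from 6 up to the next multiple of 4
assert windows.shape == (2 * 8, 4 * 4 * 4, 16)  # 2*2*2 = 8 windows per batch element
restored = window_unpartition(windows, 4, pad, input_size)
assert restored.shape == x.shape
assert torch.allclose(restored, x)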

monai/networks/blocks/selfattention.py

Lines changed: 28 additions & 0 deletions
@@ -11,9 +11,12 @@

 from __future__ import annotations

+from typing import Tuple
+
 import torch
 import torch.nn as nn

+from monai.networks.blocks.attention_utils import window_partition, window_unpartition
 from monai.utils import optional_import

 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
@@ -23,6 +26,9 @@ class SABlock(nn.Module):
     """
     A self-attention block, based on: "Dosovitskiy et al.,
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    and some additional features:
+        - local window attention
+
     """

     def __init__(
@@ -32,6 +38,8 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        window_size: int = 0,
+        input_size: Tuple = (),
     ) -> None:
         """
         Args:
@@ -40,6 +48,10 @@
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            window_size (int): window size for local attention as used in Segment Anything
+                (https://arxiv.org/abs/2304.02643); if 0, global attention is used. See also
+                https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
+            input_size (Tuple): spatial input dimensions (h, w, and d). Must be set if local window attention is used.

         """
@@ -51,6 +63,11 @@
         if hidden_size % num_heads != 0:
             raise ValueError("hidden size should be divisible by num_heads.")

+        if window_size > 0 and len(input_size) not in [2, 3]:
+            raise ValueError(
+                "If local window attention is used (window_size > 0), input_size should be specified: (h, w) or (h, w, d)"
+            )
+
         self.num_heads = num_heads
         self.out_proj = nn.Linear(hidden_size, hidden_size)
         self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
@@ -62,12 +79,18 @@
         self.scale = self.head_dim**-0.5
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
+        self.window_size = window_size
+        self.input_size = input_size

     def forward(self, x):
         """
         Args:
             x (Tensor): [b x (s_dim_1 * … * s_dim_n) x dim]
         """
+        # Window partition
+        if self.window_size > 0:
+            x, pad = window_partition(x, self.window_size, self.input_size)
+
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)

@@ -81,4 +104,9 @@
         x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad, self.input_size)
+
         return x
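
A minimal usage sketch of the windowed SABlock (illustrative; the hyperparameter values are arbitrary, not from the commit):

import torch
from monai.networks.blocks.selfattention import SABlock

# window_size=7 on a 14 x 14 token grid: attention runs within 2 * 2 = 4 local
# 7 x 7 windows per batch element instead of over all 196 tokens at once.
block = SABlock(hidden_size=96, num_heads=4, window_size=7, input_size=(14, 14))
x = torch.randn(2, 14 * 14, 96)  # [B, H*W, C]
out = block(x)
assert out.shape == x.shape  # the windowing is transparent to the caller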
