Commit 0132e6b

vgrau98 committed

refacto

Signed-off-by: vgrau98 <[email protected]>
1 parent 7d82d8a · commit 0132e6b

File tree

3 files changed: +195 -186 lines changed


monai/networks/blocks/attention_utils.py

Lines changed: 163 additions & 0 deletions
@@ -15,6 +15,10 @@
 import torch.nn.functional as F
 from torch import nn
 
+from monai.utils import optional_import
+
+rearrange, _ = optional_import("einops", name="rearrange")
+
 
 def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
     """
@@ -126,3 +130,162 @@ def add_decomposed_rel_pos(
     ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)
 
     return attn
+
+
+def window_partition(x: torch.Tensor, window_size: int, input_size: Tuple = ()) -> Tuple[torch.Tensor, Tuple]:
+    """
+    Partition into non-overlapping windows with padding if needed. Support 2D and 3D.
+    Args:
+        x (tensor): input tokens with [B, s_dim_1 * ... * s_dim_n, C]. with n = 1...len(input_size)
+        input_size (Tuple): input spatial dimension: (H, W) or (H, W, D)
+        window_size (int): window size
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size_1 * ... * window_size_n, C].
+            with n = 1...len(input_size) and window_size_i == window_size.
+        (S_DIM_1p, ..., S_DIM_np): padded spatial dimensions before partition with n = 1...len(input_size)
+    """
+    if x.shape[1] != int(torch.prod(torch.tensor(input_size))):
+        raise ValueError(f"Input tensor spatial dimension {x.shape[1]} should be equal to {input_size} product")
+
+    if len(input_size) == 2:
+        x = rearrange(x, "b (h w) c -> b h w c", h=input_size[0], w=input_size[1])
+        x, pad_hw = window_partition_2d(x, window_size)
+        x = rearrange(x, "b h w c -> b (h w) c", h=window_size, w=window_size)
+        return x, pad_hw
+    elif len(input_size) == 3:
+        x = rearrange(x, "b (h w d) c -> b h w d c", h=input_size[0], w=input_size[1], d=input_size[2])
+        x, pad_hwd = window_partition_3d(x, window_size)
+        x = rearrange(x, "b h w d c -> b (h w d) c", h=window_size, w=window_size, d=window_size)
+        return x, pad_hwd
+    else:
+        raise ValueError(f"input_size cannot be length {len(input_size)}. It can be composed of 2 or 3 elements only.")
+
+
+def window_partition_2d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed. Support only 2D.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition
+    """
+    batch, h, w, c = x.shape
+
+    pad_h = (window_size - h % window_size) % window_size
+    pad_w = (window_size - w % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    hp, wp = h + pad_h, w + pad_w
+
+    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, c)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
+    return windows, (hp, wp)
+
+
+def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed. 3d implementation.
+    Args:
+        x (tensor): input tokens with [B, H, W, D, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C].
+        (Hp, Wp, Dp): padded height, width and depth before partition
+    """
+    batch, h, w, d, c = x.shape
+
+    pad_h = (window_size - h % window_size) % window_size
+    pad_w = (window_size - w % window_size) % window_size
+    pad_d = (window_size - d % window_size) % window_size
+    if pad_h > 0 or pad_w > 0 or pad_d > 0:
+        x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h))
+    hp, wp, dp = h + pad_h, w + pad_w, d + pad_d
+
+    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c)
+    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c)
+    return windows, (hp, wp, dp)
+
+
+def window_unpartition(windows: torch.Tensor, window_size: int, pad: Tuple, spatial_dims: Tuple) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size_1, ..., window_size_n, C].
+            with n = 1...len(spatial_dims) and window_size == window_size_i
+        window_size (int): window size.
+        pad (Tuple): padded spatial dims (H, W) or (H, W, D)
+        spatial_dims (Tuple): original spatial dimensions - (H, W) or (H, W, D) - before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, s_dim_1, ..., s_dim_n, C].
+    """
+    x: torch.Tensor
+    if len(spatial_dims) == 2:
+        x = rearrange(windows, "b (h w) c -> b h w c", h=window_size, w=window_size)
+        x = window_unpartition_2d(x, window_size, pad, spatial_dims)
+        x = rearrange(x, "b h w c -> b (h w) c", h=spatial_dims[0], w=spatial_dims[1])
+        return x
+    elif len(spatial_dims) == 3:
+        x = rearrange(windows, "b (h w d) c -> b h w d c", h=window_size, w=window_size, d=window_size)
+        x = window_unpartition_3d(x, window_size, pad, spatial_dims)
+        x = rearrange(x, "b h w d c -> b (h w d) c", h=spatial_dims[0], w=spatial_dims[1], d=spatial_dims[2])
+        return x
+    else:
+        raise ValueError()
+
+
+def window_unpartition_2d(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and removing padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (hp, wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    hp, wp = pad_hw
+    h, w = hw
+    batch = windows.shape[0] // (hp * wp // window_size // window_size)
+    x = windows.view(batch, hp // window_size, wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, hp, wp, -1)
+
+    if hp > h or wp > w:
+        x = x[:, :h, :w, :].contiguous()
+    return x
+
+
+def window_unpartition_3d(
+    windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int]
+) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and removing padding. 3d implementation.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hwd (Tuple): padded height, width and depth (hp, wp, dp).
+        hwd (Tuple): original height, width and depth (H, W, D) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, D, C].
+    """
+    hp, wp, dp = pad_hwd
+    h, w, d = hwd
+    batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size)
+    x = windows.view(
+        batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1
+    )
+    x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(batch, hp, wp, dp, -1)
+
+    if hp > h or wp > w or dp > d:
+        x = x[:, :h, :w, :d, :].contiguous()
+    return x
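
As a quick sanity check (not part of the commit), here is a minimal round-trip sketch of the helpers above, assuming this branch is installed so that window_partition and window_unpartition are importable from monai.networks.blocks.attention_utils, and that torch and einops are available; the tensor sizes are illustrative only:

    import torch

    from monai.networks.blocks.attention_utils import window_partition, window_unpartition

    # Illustrative 2D case: batch of 2, a 12 x 10 token grid flattened to 120 tokens, 16 channels.
    x = torch.randn(2, 12 * 10, 16)

    # Partition into 4 x 4 windows; the width of 10 is padded up to 12 internally.
    windows, pad = window_partition(x, window_size=4, input_size=(12, 10))
    # windows: [2 * 9, 4 * 4, 16] = [18, 16, 16]; pad: (12, 12)

    # Unpartition removes the padding and restores the original flattened layout.
    y = window_unpartition(windows, 4, pad, (12, 10))
    assert y.shape == x.shape and torch.equal(y, x)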

monai/networks/blocks/selfattention.py

Lines changed: 28 additions & 5 deletions
@@ -18,6 +18,7 @@
 import torch.nn as nn
 
 from monai.networks.layers.utils import get_rel_pos_embedding_layer
+from monai.networks.blocks.attention_utils import window_partition, window_unpartition
 from monai.utils import optional_import
 
 xops, has_xformers = optional_import("xformers.ops")
@@ -26,9 +27,14 @@
 
 class SABlock(nn.Module):
     """
-    A self-attention block, based on: "Dosovitskiy et al.,
-    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
-    One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
+    A self-attention block, based on: "Dosovitskiy et al.,
+    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+<<<<<<< HEAD
+    One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
+=======
+    and some additional features:
+    - local window attention
+>>>>>>> f7aca872 (refacto)
     """
 
     def __init__(
@@ -43,6 +49,7 @@ def __init__(
         causal: bool = False,
        sequence_length: int | None = None,
         use_flash_attention: bool = False,
+        window_size: int = 0,
     ) -> None:
         """
         Args:
@@ -53,11 +60,13 @@ def __init__(
             rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
                 For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
             input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
-                positional parameter size.
+                positional parameter size. Has to be set if local window attention is used
             causal (bool): wether to use causal attention. If true `sequence_length` has to be set
             sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
-
+            window_size (int): Window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643.
+                If 0, global attention used.
+                See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
         """
 
         super().__init__()
@@ -81,6 +90,10 @@ def __init__(
 
         if use_flash_attention and not has_xformers:
             raise ValueError("use_flash_attention is True but xformers is not installed.")
+        if window_size > 0 and len(input_size) not in [2, 3]:
+            raise ValueError(
+                "If local window attention is used (window_size > 0), input_size should be specified: (h, w) or (h, w, d)"
+            )
 
         self.num_heads = num_heads
         self.out_proj = nn.Linear(hidden_size, hidden_size)
@@ -101,6 +114,7 @@ def __init__(
             if rel_pos_embedding is not None
             else None
         )
+        self.window_size = window_size
         self.input_size = input_size
 
         if causal and sequence_length is not None:
@@ -119,6 +133,10 @@ def forward(self, x: torch.Tensor):
         Return:
             torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
         """
+
+        if self.window_size > 0:
+            x, pad = window_partition(x, self.window_size, self.input_size)
+
         _, t, _ = x.size()
         output = self.input_rearrange(self.qkv(x))  # 3 x B x (s_dim_1 * ... * s_dim_n) x h x C/h
         q, k, v = output[0], output[1], output[2]
@@ -156,4 +174,9 @@ def forward(self, x: torch.Tensor):
         x = self.out_rearrange(x)
         x = self.out_proj(x)
         x = self.drop_output(x)
+
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad, self.input_size)
+
         return x
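
For completeness, a hedged usage sketch of the new window_size option (again assuming this branch is installed; hidden_size, num_heads and the spatial sizes below are arbitrary illustrative values, not values prescribed by the commit):

    import torch

    from monai.networks.blocks.selfattention import SABlock

    # Local window attention over 4 x 4 windows of a 14 x 14 feature map.
    # input_size must be provided so forward() can re-window the flattened sequence.
    block = SABlock(hidden_size=96, num_heads=4, input_size=(14, 14), window_size=4)

    tokens = torch.randn(1, 14 * 14, 96)   # B x (H * W) x C
    out = block(tokens)                    # window partition -> attention -> unpartition
    print(out.shape)                       # torch.Size([1, 196, 96])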
