Skip to content

Commit 8f95ed6

Browse files
committed
transformer block local window attention
Signed-off-by: vgrau98 <[email protected]>
1 parent b3d7a48 commit 8f95ed6

File tree

1 file changed

+72
-1
lines changed

1 file changed

+72
-1
lines changed

monai/networks/blocks/transformerblock.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111

1212
from __future__ import annotations
1313

14+
from typing import Tuple
15+
16+
import torch
1417
import torch.nn as nn
18+
import torch.nn.functional as F
1519

1620
from monai.networks.blocks.mlp import MLPBlock
1721
from monai.networks.blocks.selfattention import SABlock
@@ -31,6 +35,7 @@ def __init__(
3135
dropout_rate: float = 0.0,
3236
qkv_bias: bool = False,
3337
save_attn: bool = False,
38+
window_size: int = 0,
3439
) -> None:
3540
"""
3641
Args:
@@ -40,6 +45,10 @@ def __init__(
4045
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
4146
qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
4247
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
48+
window_size (int): Window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643.
49+
If 0, global attention used. Only 2D inputs are supported for local attention (window_size > 0).
50+
If local attention is used, the input tensor should have the following shape during the forward pass: [B, H, W, C].
51+
See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
4352
4453
"""
4554

@@ -55,8 +64,70 @@ def __init__(
5564
self.norm1 = nn.LayerNorm(hidden_size)
5665
self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
5766
self.norm2 = nn.LayerNorm(hidden_size)
67+
self.window_size = window_size
5868

5969
def forward(self, x):
60-
x = x + self.attn(self.norm1(x))
70+
shortcut = x
71+
x = self.norm1(x)
72+
# Window partition
73+
if self.window_size > 0:
74+
H, W = x.shape[1], x.shape[2]
75+
x, pad_hw = window_partition(x, self.window_size)
76+
77+
x = self.attn(x)
78+
# Reverse window partition
79+
if self.window_size > 0:
80+
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
81+
82+
x = shortcut + x
6183
x = x + self.mlp(self.norm2(x))
6284
return x
85+
86+
87+
def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
88+
"""
89+
Partition into non-overlapping windows with padding if needed. Support only 2D.
90+
Args:
91+
x (tensor): input tokens with [B, H, W, C].
92+
window_size (int): window size.
93+
94+
Returns:
95+
windows: windows after partition with [B * num_windows, window_size, window_size, C].
96+
(Hp, Wp): padded height and width before partition
97+
"""
98+
B, H, W, C = x.shape
99+
100+
pad_h = (window_size - H % window_size) % window_size
101+
pad_w = (window_size - W % window_size) % window_size
102+
if pad_h > 0 or pad_w > 0:
103+
x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
104+
Hp, Wp = H + pad_h, W + pad_w
105+
106+
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
107+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
108+
return windows, (Hp, Wp)
109+
110+
111+
def window_unpartition(
112+
windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
113+
) -> torch.Tensor:
114+
"""
115+
Window unpartition into original sequences and removing padding.
116+
Args:
117+
windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
118+
window_size (int): window size.
119+
pad_hw (Tuple): padded height and width (Hp, Wp).
120+
hw (Tuple): original height and width (H, W) before padding.
121+
122+
Returns:
123+
x: unpartitioned sequences with [B, H, W, C].
124+
"""
125+
Hp, Wp = pad_hw
126+
H, W = hw
127+
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
128+
x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
129+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
130+
131+
if Hp > H or Wp > W:
132+
x = x[:, :H, :W, :].contiguous()
133+
return x

0 commit comments

Comments
 (0)