@@ -11,7 +11,11 @@
 
 from __future__ import annotations
 
+from typing import Tuple
+
+import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from monai.networks.blocks.mlp import MLPBlock
 from monai.networks.blocks.selfattention import SABlock
@@ -31,6 +35,7 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        window_size: int = 0,
     ) -> None:
         """
         Args:
@@ -40,6 +45,10 @@ def __init__(
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+            window_size (int): window size for local attention as used in Segment Anything (https://arxiv.org/abs/2304.02643).
+                If 0, global attention is used. Only 2D inputs are supported for local attention (window_size > 0).
+                When local attention is used, the input tensor passed to forward must have shape [B, H, W, C].
+                See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
 
         """
 
@@ -55,8 +64,70 @@ def __init__(
         self.norm1 = nn.LayerNorm(hidden_size)
         self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
         self.norm2 = nn.LayerNorm(hidden_size)
+        self.window_size = window_size
 
     def forward(self, x):
-        x = x + self.attn(self.norm1(x))
+        shortcut = x
+        x = self.norm1(x)
+        # Window partition
+        if self.window_size > 0:
+            H, W = x.shape[1], x.shape[2]
+            x, pad_hw = window_partition(x, self.window_size)
+
+        x = self.attn(x)
+        # Reverse window partition
+        if self.window_size > 0:
+            x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+        x = shortcut + x
         x = x + self.mlp(self.norm2(x))
         return x
+
+
+def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+    """
+    Partition into non-overlapping windows with padding if needed. Supports 2D inputs only.
+    Args:
+        x (tensor): input tokens with [B, H, W, C].
+        window_size (int): window size.
+
+    Returns:
+        windows: windows after partition with [B * num_windows, window_size, window_size, C].
+        (Hp, Wp): padded height and width before partition.
+    """
+    B, H, W, C = x.shape
+
+    pad_h = (window_size - H % window_size) % window_size
+    pad_w = (window_size - W % window_size) % window_size
+    if pad_h > 0 or pad_w > 0:
+        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+    Hp, Wp = H + pad_h, W + pad_w
+
+    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows, (Hp, Wp)
+
+
+def window_unpartition(
+    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+) -> torch.Tensor:
+    """
+    Window unpartition into original sequences and remove padding.
+    Args:
+        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+        window_size (int): window size.
+        pad_hw (Tuple): padded height and width (Hp, Wp).
+        hw (Tuple): original height and width (H, W) before padding.
+
+    Returns:
+        x: unpartitioned sequences with [B, H, W, C].
+    """
+    Hp, Wp = pad_hw
+    H, W = hw
+    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+    if Hp > H or Wp > W:
+        x = x[:, :H, :W, :].contiguous()
+    return x
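
The two windowing helpers added above are pure tensor reshapes, so a round trip on a dummy tensor is enough to sanity-check the padding and cropping logic. This is a minimal sketch, not part of the patch, and it assumes the patched module is importable as `monai.networks.blocks.transformerblock` (otherwise copy the two functions into a local file):

```python
# Round-trip sanity check for the windowing helpers introduced in this diff.
# Assumption: the patched file is importable as monai.networks.blocks.transformerblock.
import torch

from monai.networks.blocks.transformerblock import window_partition, window_unpartition

B, H, W, C = 2, 10, 12, 8      # spatial size deliberately not a multiple of the window size
window_size = 7

x = torch.randn(B, H, W, C)

# Partition: pads (10, 12) up to (14, 14), then yields 2 * 2 = 4 windows per batch element.
windows, (Hp, Wp) = window_partition(x, window_size)
print(windows.shape)           # torch.Size([8, 7, 7, 8]) -> [B * num_windows, ws, ws, C]
print(Hp, Wp)                  # 14 14

# Unpartition: stitches the windows back together and crops the padding away.
x_restored = window_unpartition(windows, window_size, (Hp, Wp), (H, W))
print(x_restored.shape)        # torch.Size([2, 10, 12, 8])
assert torch.equal(x, x_restored)
```

With H=10, W=12 and window_size=7, the input is zero-padded to 14x14, split into four 7x7 windows per batch element, and restored exactly once the padding is cropped, which is the behaviour the forward pass relies on when window_size > 0.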