 
 from __future__ import annotations
 
+from typing import Optional, Tuple
+
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from monai.utils import optional_import
 
@@ -23,6 +26,7 @@ class SABlock(nn.Module):
     """
     A self-attention block, based on: "Dosovitskiy et al.,
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    One can set up relative positional embedding as described in <https://arxiv.org/abs/2112.01526>.
     """
 
     def __init__(
@@ -32,13 +36,18 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        use_rel_pos: bool = False,
+        input_size: Optional[Tuple[int, int]] = None,
     ) -> None:
         """
         Args:
             hidden_size (int): dimension of hidden layer.
             num_heads (int): number of attention heads.
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map. Only supports 2D inputs.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
 
         """
@@ -62,11 +71,43 @@ def __init__(
         self.scale = self.head_dim**-0.5
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
+        self.use_rel_pos = use_rel_pos
+        self.input_size = input_size
+
+        if self.use_rel_pos:
+            assert input_size is not None, "Input size must be provided if using relative positional encoding."
+            assert len(input_size) == 2, "Relative positional embedding is only supported for 2D"
+            # initialize relative positional embeddings
+            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
+            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
 
-    def forward(self, x):
+        Return:
+            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
+        """
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
-        att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
+        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+        if self.use_rel_pos:
+            batch = x.shape[0]
+            h, w = self.input_size
+            att_mat = add_decomposed_rel_pos(
+                att_mat.view(batch * self.num_heads, h * w, h * w),
+                q.view(batch * self.num_heads, h * w, -1),
+                self.rel_pos_h,
+                self.rel_pos_w,
+                (h, w),
+                (h, w),
+            )
+            att_mat = att_mat.reshape(batch, self.num_heads, h * w, h * w)
+
+        att_mat = att_mat.softmax(dim=-1)
+
         if self.save_attn:
             # no gradients and new tensor;
             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
@@ -78,3 +119,74 @@ def forward(self, x):
         x = self.out_proj(x)
         x = self.drop_output(x)
         return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+    """
+    Get relative positional embeddings according to the relative positions of
+    query and key sizes.
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (Tensor): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    rel_pos_resized: torch.Tensor = torch.Tensor()
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
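
A quick check one could run against get_rel_pos above (illustrative sizes, not from the original): with q_size = k_size = 4 the table holds 2 * 4 - 1 = 7 learnable offsets, and the returned tensor is indexed by the relative coordinate (i - j) + 3.

import torch

head_dim = 8
rel_pos = torch.randn(2 * 4 - 1, head_dim)  # (L, C) with L = 2 * max(q_size, k_size) - 1
out = get_rel_pos(4, 4, rel_pos)            # lengths match, so no interpolation happens
print(out.shape)                            # torch.Size([4, 4, 8])
assert torch.equal(out[0, 0], rel_pos[3])   # offset 0 picks the centre embedding
assert torch.equal(out[3, 0], rel_pos[6])   # offset +3 picks the last embedding
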
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor,
+    q: torch.Tensor,
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> torch.Tensor:
+    """
+    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+    Returns:
+        attn (Tensor): attention map with added relative positional embeddings.
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    batch, _, dim = q.shape
+    r_q = q.reshape(batch, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)
+
+    attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+        batch, q_h * q_w, k_h * k_w
+    )
+
+    return attn
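
And a corresponding shape check for add_decomposed_rel_pos, mirroring how forward calls it (again with assumed, illustrative sizes; attn is flattened over batch and heads, and zero-initialised embeddings add a zero bias):

import torch

b_heads, h, w, head_dim = 8, 14, 14, 24        # e.g. batch=2 and num_heads=4
attn = torch.zeros(b_heads, h * w, h * w)      # (B * num_heads, L, L) attention logits
q = torch.randn(b_heads, h * w, head_dim)      # queries flattened the same way
rel_pos_h = torch.zeros(2 * h - 1, head_dim)   # (Lh, C), as initialised in __init__
rel_pos_w = torch.zeros(2 * w - 1, head_dim)   # (Lw, C)

attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (h, w), (h, w))
print(attn.shape)  # torch.Size([8, 196, 196])
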