
Commit 67ccdbe

attention-rel-pos-embedd
Signed-off-by: vgrau98 <[email protected]>
1 parent b3d7a48 commit 67ccdbe

1 file changed: +114 −2 lines changed

monai/networks/blocks/selfattention.py

Lines changed: 114 additions & 2 deletions
@@ -11,8 +11,11 @@
 
 from __future__ import annotations
 
+from typing import Optional, Tuple
+
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from monai.utils import optional_import
 
@@ -23,6 +26,7 @@ class SABlock(nn.Module):
     """
     A self-attention block, based on: "Dosovitskiy et al.,
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    One can set up relative positional embeddings as described in <https://arxiv.org/abs/2112.01526>
     """
 
     def __init__(
@@ -32,13 +36,18 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        use_rel_pos: bool = False,
+        input_size: Optional[Tuple[int, int]] = None,
     ) -> None:
         """
         Args:
             hidden_size (int): dimension of hidden layer.
             num_heads (int): number of attention heads.
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
+            use_rel_pos (bool): If True, add relative positional embeddings to the attention map. Only supports 2D inputs.
+            input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                positional parameter size.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
 
         """
@@ -62,11 +71,43 @@ def __init__(
         self.scale = self.head_dim**-0.5
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
+        self.use_rel_pos = use_rel_pos
+        self.input_size = input_size
+
+        if self.use_rel_pos:
+            assert input_size is not None, "Input size must be provided if using relative positional encoding."
+            assert len(input_size) == 2, "Relative positional embedding is only supported for 2D"
+            # initialize relative positional embeddings
+            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, self.head_dim))
+            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, self.head_dim))
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
 
-    def forward(self, x):
+        Return:
+            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
+        """
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
-        att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
+        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+        if self.use_rel_pos:
+            batch = x.shape[0]
+            h, w = self.input_size
+            att_mat = add_decomposed_rel_pos(
+                att_mat.view(batch * self.num_heads, h * w, h * w),
+                q.view(batch * self.num_heads, h * w, -1),
+                self.rel_pos_h,
+                self.rel_pos_w,
+                (h, w),
+                (h, w),
+            )
+            att_mat = att_mat.reshape(batch, self.num_heads, h * w, h * w)
+
+        att_mat = att_mat.softmax(dim=-1)
+
         if self.save_attn:
             # no gradients and new tensor;
             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
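
For orientation, a minimal usage sketch of the block with the options added in this commit. The constructor and forward signatures follow the diff above; the concrete sizes (hidden_size=96, num_heads=4, a 14 x 14 feature map) are illustrative assumptions, not values taken from the commit.

import torch

from monai.networks.blocks.selfattention import SABlock

# hidden_size must be divisible by num_heads; head_dim here is 96 / 4 = 24
block = SABlock(hidden_size=96, num_heads=4, use_rel_pos=True, input_size=(14, 14))

# input is B x (H * W) x C, as described in the forward docstring above
x = torch.randn(2, 14 * 14, 96)
out = block(x)
print(out.shape)  # torch.Size([2, 196, 96])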
@@ -78,3 +119,74 @@ def forward(self, x):
         x = self.out_proj(x)
         x = self.drop_output(x)
         return x
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+    """
+    Get relative positional embeddings according to the relative positions of
+    query and key sizes.
+    Args:
+        q_size (int): size of query q.
+        k_size (int): size of key k.
+        rel_pos (Tensor): relative position embeddings (L, C).
+
+    Returns:
+        Extracted positional embeddings according to relative positions.
+    """
+    rel_pos_resized: torch.Tensor = torch.Tensor()
+    max_rel_dist = int(2 * max(q_size, k_size) - 1)
+    # Interpolate rel pos if needed.
+    if rel_pos.shape[0] != max_rel_dist:
+        # Interpolate rel pos.
+        rel_pos_resized = F.interpolate(
+            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
+        )
+        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+    else:
+        rel_pos_resized = rel_pos
+
+    # Scale the coords with short length if shapes for q and k are different.
+    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+    return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+    attn: torch.Tensor,
+    q: torch.Tensor,
+    rel_pos_h: torch.Tensor,
+    rel_pos_w: torch.Tensor,
+    q_size: Tuple[int, int],
+    k_size: Tuple[int, int],
+) -> torch.Tensor:
+    """
+    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
+    Args:
+        attn (Tensor): attention map.
+        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+
+    Returns:
+        attn (Tensor): attention map with added relative positional embeddings.
+    """
+    q_h, q_w = q_size
+    k_h, k_w = k_size
+    rh = get_rel_pos(q_h, k_h, rel_pos_h)
+    rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+    batch, _, dim = q.shape
+    r_q = q.reshape(batch, q_h, q_w, dim)
+    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
+    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)
+
+    attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+        batch, q_h * q_w, k_h * k_w
+    )
+
+    return attn
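
And a rough shape check for the helpers above (again a sketch with assumed sizes, assuming they land at module level in selfattention.py as in the diff): attn and q are passed in flattened to (B * num_heads, q_h * q_w, k_h * k_w) and (B * num_heads, q_h * q_w, head_dim), matching how forward calls add_decomposed_rel_pos.

import torch

from monai.networks.blocks.selfattention import add_decomposed_rel_pos

bh, h, w, dim = 2, 7, 7, 16  # assumed sizes: 2 flattened batch*head slices, a 7x7 map, head_dim 16
attn = torch.zeros(bh, h * w, h * w)
q = torch.randn(bh, h * w, dim)
rel_pos_h = torch.randn(2 * h - 1, dim)  # one row per relative offset -(h-1)..(h-1)
rel_pos_w = torch.randn(2 * w - 1, dim)

out = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (h, w), (h, w))
print(out.shape)  # torch.Size([2, 49, 49])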
