Commit fdbc611

vgrau98, ericspod, pre-commit-ci[bot], and KumoLiu authored and committed
[Attention block] relative positional embedding (Project-MONAI#7346)
Fixes Project-MONAI#7356

### Description
Add relative positional embedding in the attention block, as described in https://arxiv.org/pdf/2112.01526.pdf. Largely inspired by https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py. Can be useful for Project-MONAI#6357.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: vgrau98 <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
Signed-off-by: Mark Graham <[email protected]>
1 parent e15b570 commit fdbc611

7 files changed: +262 additions, -10 deletions
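For orientation before the per-file diffs, a minimal usage sketch of the feature this commit adds (illustrative only; the hidden size, head count, and 16x16 token grid below are arbitrary choices, not values taken from the diff):

import torch
from monai.networks.blocks.selfattention import SABlock

# enable decomposed relative positional embedding on a 16x16 grid of tokens
block = SABlock(hidden_size=128, num_heads=4, rel_pos_embedding="decomposed", input_size=(16, 16))
x = torch.randn(2, 16 * 16, 128)  # B x (H*W) x C token sequence
out = block(x)                    # same shape as the input: (2, 256, 128)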

docs/source/networks.rst

Lines changed: 6 additions & 0 deletions
@@ -248,6 +248,12 @@ Blocks
 .. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock
     :members:
 
+`Attention utilities`
+~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: monai.networks.blocks.attention_utils
+.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos
+.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos
+
 N-Dim Fourier Transform
 ~~~~~~~~~~~~~~~~~~~~~~~~
 .. automodule:: monai.networks.blocks.fft_utils_t
monai/networks/blocks/attention_utils.py

Lines changed: 128 additions & 0 deletions (new file)

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    rel_pos_resized: torch.Tensor = torch.Tensor()
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        # Interpolate rel pos.
        rel_pos_resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
        )
        rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
    attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
) -> torch.Tensor:
    r"""
    Calculate decomposed Relative Positional Embeddings from mvitv2 implementation:
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

    Only 2D and 3D are supported.

    Encoding the relative position of tokens in the attention matrix: tokens spaced a distance
    `d` apart will have the same embedding value (unlike absolute positional embedding).

    .. math::
        Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale

    where

    .. math::
        E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)}

    with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`,
    respectively spatial positions of element :math:`i` and :math:`j`

    When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follows:

    .. math::
        R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}

    with :math:`n = 1...dim`

    Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to
    :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
        rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.
        q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
        k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).

    Returns:
        attn (Tensor): attention logits with added relative positional embeddings.
    """
    rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
    rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])

    batch, _, dim = q.shape

    if len(rel_pos_lst) == 2:
        q_h, q_w = q_size[:2]
        k_h, k_w = k_size[:2]
        r_q = q.reshape(batch, q_h, q_w, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
        rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)

        attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
            batch, q_h * q_w, k_h * k_w
        )
    elif len(rel_pos_lst) == 3:
        q_h, q_w, q_d = q_size[:3]
        k_h, k_w, k_d = k_size[:3]

        rd = get_rel_pos(q_d, k_d, rel_pos_lst[2])

        r_q = q.reshape(batch, q_h, q_w, q_d, dim)
        # one embedding term per spatial axis; each indexes its per-axis table by the matching query coordinate
        rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh)
        rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw)
        rel_d = torch.einsum("bhwdc,dkc->bhwdk", r_q, rd)

        # broadcast each axis term over the two key axes it does not index
        attn = (
            attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d)
            + rel_h[:, :, :, :, :, None, None]
            + rel_w[:, :, :, :, None, :, None]
            + rel_d[:, :, :, :, None, None, :]
        ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)

    return attn
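A minimal shape check for the two helpers above (a sketch, not part of the commit; the embedding tables are zero-initialized here, whereas in a model they would be learned parameters):

import torch
from torch import nn
from monai.networks.blocks.attention_utils import add_decomposed_rel_pos, get_rel_pos

q_h, q_w, dim = 4, 6, 8
rel_pos_h = nn.Parameter(torch.zeros(2 * q_h - 1, dim))  # (2*H - 1, C) table for the height axis
rel_pos_w = nn.Parameter(torch.zeros(2 * q_w - 1, dim))  # (2*W - 1, C) table for the width axis

# get_rel_pos indexes the table with the (q_size, k_size) grid of relative offsets
print(get_rel_pos(q_h, q_h, rel_pos_h).shape)  # torch.Size([4, 4, 8])

q = torch.randn(2, q_h * q_w, dim)           # (B, H*W, C) queries for a single head
attn = torch.randn(2, q_h * q_w, q_h * q_w)  # (B, H*W, H*W) attention logits
attn = add_decomposed_rel_pos(attn, q, nn.ParameterList([rel_pos_h, rel_pos_w]), (q_h, q_w), (q_h, q_w))
print(attn.shape)  # shape is unchanged: torch.Size([2, 24, 24])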
monai/networks/blocks/rel_pos_embedding.py

Lines changed: 56 additions & 0 deletions (new file)

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Iterable, Tuple

import torch
from torch import nn

from monai.networks.blocks.attention_utils import add_decomposed_rel_pos
from monai.utils.misc import ensure_tuple_size


class DecomposedRelativePosEmbedding(nn.Module):
    def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None:
        """
        Args:
            s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D)
            c_dim (int): channel dimension
            num_heads (int): number of attention heads
        """
        super().__init__()

        # validate inputs
        if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]:
            raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)")

        self.s_input_dims = s_input_dims
        self.c_dim = c_dim
        self.num_heads = num_heads
        self.rel_pos_arr = nn.ParameterList(
            [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims]
        )

    def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
        """Add decomposed relative positional embeddings to the attention map `att_mat` given the query `q`."""
        batch = x.shape[0]
        h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1)

        att_mat = add_decomposed_rel_pos(
            att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d),
            q.contiguous().view(batch * self.num_heads, h * w * d, -1),
            self.rel_pos_arr,
            (h, w) if d == 1 else (h, w, d),
            (h, w) if d == 1 else (h, w, d),
        )

        att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d)
        return att_mat
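A hedged sketch of calling the module above directly, outside of SABlock (shapes follow the convention used in forward: a B x num_heads x seq x seq attention map and a query with c_dim channels per head; all sizes below are arbitrary):

import torch
from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding

emb = DecomposedRelativePosEmbedding(s_input_dims=(8, 8, 8), c_dim=16, num_heads=2)
x = torch.randn(3, 8 * 8 * 8, 32)              # B x (H*W*D) x hidden; only the batch size is read from x
q = torch.randn(3, 2, 8 * 8 * 8, 16)           # B x heads x seq x c_dim
att = torch.randn(3, 2, 8 * 8 * 8, 8 * 8 * 8)  # B x heads x seq x seq attention logits
att = emb(x, att, q)
print(att.shape)  # torch.Size([3, 2, 512, 512]), with relative position terms added to the logits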

monai/networks/blocks/selfattention.py

Lines changed: 31 additions & 2 deletions
@@ -11,9 +11,12 @@
 
 from __future__ import annotations
 
+from typing import Optional, Tuple
+
 import torch
 import torch.nn as nn
 
+from monai.networks.layers.utils import get_rel_pos_embedding_layer
 from monai.utils import optional_import
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
@@ -23,6 +26,7 @@ class SABlock(nn.Module):
     """
     A self-attention block, based on: "Dosovitskiy et al.,
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+    One can set up relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
     """
 
     def __init__(
@@ -32,13 +36,19 @@ def __init__(
         dropout_rate: float = 0.0,
         qkv_bias: bool = False,
         save_attn: bool = False,
+        rel_pos_embedding: Optional[str] = None,
+        input_size: Optional[Tuple] = None,
     ) -> None:
         """
         Args:
             hidden_size (int): dimension of hidden layer.
             num_heads (int): number of attention heads.
             dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
             qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
+            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
+                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
+                positional parameter size.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
 
         """
@@ -62,11 +72,30 @@ def __init__(
         self.scale = self.head_dim**-0.5
         self.save_attn = save_attn
         self.att_mat = torch.Tensor()
+        self.rel_positional_embedding = (
+            get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
+            if rel_pos_embedding is not None
+            else None
+        )
+        self.input_size = input_size
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+
+        Return:
+            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
+        """
         output = self.input_rearrange(self.qkv(x))
         q, k, v = output[0], output[1], output[2]
-        att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
+        att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+        # apply relative positional embedding if defined
+        att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat
+
+        att_mat = att_mat.softmax(dim=-1)
+
         if self.save_attn:
             # no gradients and new tensor;
             # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
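Putting the SABlock changes together, a hedged 3D sketch (the (8, 8, 8) grid and the layer sizes are illustrative values, not taken from the diff; 512 tokens = 8*8*8):

import torch
from monai.networks.blocks.selfattention import SABlock

block = SABlock(
    hidden_size=96,
    num_heads=4,
    dropout_rate=0.0,
    rel_pos_embedding="decomposed",  # only "decomposed" is registered for now
    input_size=(8, 8, 8),            # spatial grid used to size the per-axis embedding tables
)
x = torch.randn(1, 8 * 8 * 8, 96)  # B x (H*W*D) x hidden_size
out = block(x)
print(out.shape)  # torch.Size([1, 512, 96])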

monai/networks/layers/factories.py

Lines changed: 12 additions & 1 deletion
@@ -70,7 +70,7 @@ def use_factory(fact_args):
 from monai.networks.utils import has_nvfuser_instance_norm
 from monai.utils import ComponentStore, look_up_option, optional_import
 
-__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
+__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"]
 
 
 class LayerFactory(ComponentStore):
@@ -201,6 +201,10 @@ def split_args(args):
 Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.")
 Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.")
 Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.")
+RelPosEmbedding = LayerFactory(
+    name="Relative positional embedding layers",
+    description="Factory for creating relative positional embedding layers.",
+)
 
 
 @Dropout.factory_function("dropout")
@@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d | nn.ConstantPad3d]:
     """
     types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
     return types[dim - 1]
+
+
+@RelPosEmbedding.factory_function("decomposed")
+def decomposed_rel_pos_embedding() -> type[nn.Module]:
+    from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding
+
+    return DecomposedRelativePosEmbedding
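A small sketch of how the new factory entry resolves; the lookup mirrors the existing Act/Norm/Pool factories, and the uppercase constant is the same one used by the tests below:

from monai.networks.layers.factories import RelPosEmbedding

rel_pos_cls = RelPosEmbedding[RelPosEmbedding.DECOMPOSED]  # equivalent to RelPosEmbedding["decomposed"]
print(rel_pos_cls.__name__)  # DecomposedRelativePosEmbedding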

monai/networks/layers/utils.py

Lines changed: 14 additions & 1 deletion
@@ -11,9 +11,11 @@
 
 from __future__ import annotations
 
+from typing import Optional
+
 import torch.nn
 
-from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args
+from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args
 from monai.utils import has_option
 
 __all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"]
@@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1):
     pool_name, pool_args = split_args(name)
     pool_type = Pool[pool_name, spatial_dims]
     return pool_type(**pool_args)
+
+
+def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int):
+    embedding_name, embedding_args = split_args(name)
+    embedding_type = RelPosEmbedding[embedding_name]
+    # create a dictionary with the default values which can be overridden by embedding_args
+    kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args}
+    # filter out unused argument names
+    kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)}
+
+    return embedding_type(**kw_args)
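A hedged sketch of the helper above; the (name, args) tuple form goes through split_args so callers can override the positional defaults, as with get_norm_layer and the other getters (all values below are arbitrary):

from monai.networks.layers.utils import get_rel_pos_embedding_layer

layer = get_rel_pos_embedding_layer("decomposed", s_input_dims=(16, 16), c_dim=32, num_heads=4)
print(type(layer).__name__)  # DecomposedRelativePosEmbedding

# tuple form: the dict overrides the num_heads value passed positionally
layer = get_rel_pos_embedding_layer(("decomposed", {"num_heads": 8}), (16, 16), 32, 4)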

tests/test_selfattention.py

Lines changed: 15 additions & 6 deletions
@@ -20,6 +20,7 @@
 
 from monai.networks import eval_mode
 from monai.networks.blocks.selfattention import SABlock
+from monai.networks.layers.factories import RelPosEmbedding
 from monai.utils import optional_import
 
 einops, has_einops = optional_import("einops")
@@ -28,12 +29,20 @@
 for dropout_rate in np.linspace(0, 1, 4):
     for hidden_size in [360, 480, 600, 768]:
         for num_heads in [4, 6, 8, 12]:
-            test_case = [
-                {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate},
-                (2, 512, hidden_size),
-                (2, 512, hidden_size),
-            ]
-            TEST_CASE_SABLOCK.append(test_case)
+            for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
+                for input_size in [(16, 32), (8, 8, 8)]:
+                    test_case = [
+                        {
+                            "hidden_size": hidden_size,
+                            "num_heads": num_heads,
+                            "dropout_rate": dropout_rate,
+                            "rel_pos_embedding": rel_pos_embedding,
+                            "input_size": input_size,
+                        },
+                        (2, 512, hidden_size),
+                        (2, 512, hidden_size),
+                    ]
+                    TEST_CASE_SABLOCK.append(test_case)
 
 
 class TestResBlock(unittest.TestCase):
