
Commit c61c6ac

6676 port generative networks transformer (#7300)
Towards #6676.

### Description

Adds a simple decoder-only transformer architecture.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Mark Graham <[email protected]>
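For orientation, here is a minimal usage sketch of the new network (not part of the commit itself; the hyperparameter values below are illustrative only):

```python
import torch

from monai.networks.nets import DecoderOnlyTransformer

# Illustrative sizes only; attn_layers_dim must be divisible by attn_layers_heads.
model = DecoderOnlyTransformer(
    num_tokens=10,        # vocabulary size
    max_seq_len=16,       # maximum sequence length (also sizes the causal mask)
    attn_layers_dim=8,    # hidden dimension of the attention layers
    attn_layers_depth=2,  # number of transformer blocks
    attn_layers_heads=2,  # attention heads per block
)

tokens = torch.randint(0, 10, (1, 16))  # (batch, sequence) of token indices
logits = model(tokens)                  # (1, 16, 10): per-position logits over the vocabulary
```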
1 parent b3fdfdd commit c61c6ac

4 files changed: +393 -0 lines changed


docs/source/networks.rst

Lines changed: 5 additions & 0 deletions
@@ -613,6 +613,11 @@ Nets
 .. autoclass:: VarAutoEncoder
     :members:
 
+`DecoderOnlyTransformer`
+~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: DecoderOnlyTransformer
+    :members:
+
 `ViT`
 ~~~~~
 .. autoclass:: ViT

monai/networks/nets/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -106,6 +106,7 @@
 from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
 from .torchvision_fc import TorchVisionFCModel
 from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
+from .transformer import DecoderOnlyTransformer
 from .unet import UNet, Unet
 from .unetr import UNETR
 from .varautoencoder import VarAutoEncoder

monai/networks/nets/transformer.py

Lines changed: 314 additions & 0 deletions
@@ -0,0 +1,314 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from monai.networks.blocks.mlp import MLPBlock
from monai.utils import optional_import

xops, has_xformers = optional_import("xformers.ops")
__all__ = ["DecoderOnlyTransformer"]


class _SABlock(nn.Module):
    """
    NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
    use of this block as support is not guaranteed. For more information see:
    https://github.com/Project-MONAI/MONAI/issues/7227

    A self-attention block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Args:
        hidden_size: dimension of hidden layer.
        num_heads: number of attention heads.
        dropout_rate: dropout ratio. Defaults to no dropout.
        qkv_bias: bias term for the qkv linear layer.
        causal: whether to use causal attention.
        sequence_length: if causal is True, it is necessary to specify the sequence length.
        with_cross_attention: whether to use cross attention for conditioning.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        qkv_bias: bool = False,
        causal: bool = False,
        sequence_length: int | None = None,
        with_cross_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)
        self.causal = causal
        self.sequence_length = sequence_length
        self.with_cross_attention = with_cross_attention
        self.use_flash_attention = use_flash_attention

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")
        self.dropout_rate = dropout_rate

        if hidden_size % num_heads != 0:
            raise ValueError("hidden size should be divisible by num_heads.")

        if causal and sequence_length is None:
            raise ValueError("sequence_length is necessary for causal attention.")

        if use_flash_attention and not has_xformers:
            raise ValueError("use_flash_attention is True but xformers is not installed.")

        # key, query, value projections
        self.to_q = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.to_k = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)
        self.to_v = nn.Linear(hidden_size, hidden_size, bias=qkv_bias)

        # regularization
        self.drop_weights = nn.Dropout(dropout_rate)
        self.drop_output = nn.Dropout(dropout_rate)

        # output projection
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        if causal and sequence_length is not None:
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer(
                "causal_mask",
                torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
            )
            self.causal_mask: torch.Tensor

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        b, t, c = x.size()  # batch size, sequence length, embedding dimensionality (hidden_size)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query = self.to_q(x)

        kv = context if context is not None else x
        _, kv_t, _ = kv.size()
        key = self.to_k(kv)
        value = self.to_v(kv)

        query = query.view(b, t, self.num_heads, c // self.num_heads)  # (b, t, nh, hs)
        key = key.view(b, kv_t, self.num_heads, c // self.num_heads)  # (b, kv_t, nh, hs)
        value = value.view(b, kv_t, self.num_heads, c // self.num_heads)  # (b, kv_t, nh, hs)
        y: torch.Tensor
        if self.use_flash_attention:
            query = query.contiguous()
            key = key.contiguous()
            value = value.contiguous()
            y = xops.memory_efficient_attention(
                query=query,
                key=key,
                value=value,
                scale=self.scale,
                p=self.dropout_rate,
                attn_bias=xops.LowerTriangularMask() if self.causal else None,
            )

        else:
            query = query.transpose(1, 2)  # (b, nh, t, hs)
            key = key.transpose(1, 2)  # (b, nh, kv_t, hs)
            value = value.transpose(1, 2)  # (b, nh, kv_t, hs)

            # manual implementation of attention
            query = query * self.scale
            attention_scores = query @ key.transpose(-2, -1)

            if self.causal:
                attention_scores = attention_scores.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))

            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = self.drop_weights(attention_probs)
            y = attention_probs @ value  # (b, nh, t, kv_t) x (b, nh, kv_t, hs) -> (b, nh, t, hs)

            y = y.transpose(1, 2)  # (b, nh, t, hs) -> (b, t, nh, hs)

        y = y.contiguous().view(b, t, c)  # re-assemble all head outputs side by side

        y = self.out_proj(y)
        y = self.drop_output(y)
        return y


class _TransformerBlock(nn.Module):
    """
    NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
    use of this block as support is not guaranteed. For more information see:
    https://github.com/Project-MONAI/MONAI/issues/7227

    A transformer block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Args:
        hidden_size: dimension of hidden layer.
        mlp_dim: dimension of feedforward layer.
        num_heads: number of attention heads.
        dropout_rate: fraction of the input units to drop.
        qkv_bias: apply bias term for the qkv linear layer.
        causal: whether to use causal attention.
        sequence_length: if causal is True, it is necessary to specify the sequence length.
        with_cross_attention: whether to use cross attention for conditioning.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        hidden_size: int,
        mlp_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        qkv_bias: bool = False,
        causal: bool = False,
        sequence_length: int | None = None,
        with_cross_attention: bool = False,
        use_flash_attention: bool = False,
    ) -> None:
        self.with_cross_attention = with_cross_attention
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            raise ValueError("hidden_size should be divisible by num_heads.")

        self.norm1 = nn.LayerNorm(hidden_size)
        self.attn = _SABlock(
            hidden_size=hidden_size,
            num_heads=num_heads,
            dropout_rate=dropout_rate,
            qkv_bias=qkv_bias,
            causal=causal,
            sequence_length=sequence_length,
            use_flash_attention=use_flash_attention,
        )

        if self.with_cross_attention:
            self.norm2 = nn.LayerNorm(hidden_size)
            self.cross_attn = _SABlock(
                hidden_size=hidden_size,
                num_heads=num_heads,
                dropout_rate=dropout_rate,
                qkv_bias=qkv_bias,
                with_cross_attention=with_cross_attention,
                causal=False,
                use_flash_attention=use_flash_attention,
            )
        self.norm3 = nn.LayerNorm(hidden_size)
        self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        x = x + self.attn(self.norm1(x))
        if self.with_cross_attention:
            x = x + self.cross_attn(self.norm2(x), context=context)
        x = x + self.mlp(self.norm3(x))
        return x


class AbsolutePositionalEmbedding(nn.Module):
    """Absolute positional embedding.

    Args:
        max_seq_len: Maximum sequence length.
        embedding_dim: Dimensionality of the embedding.
    """

    def __init__(self, max_seq_len: int, embedding_dim: int) -> None:
        super().__init__()
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        self.embedding = nn.Embedding(max_seq_len, embedding_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = x.size()
        positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1)
        embedding: torch.Tensor = self.embedding(positions)
        return embedding


class DecoderOnlyTransformer(nn.Module):
    """Decoder-only (Autoregressive) Transformer model.

    Args:
        num_tokens: Number of tokens in the vocabulary.
        max_seq_len: Maximum sequence length.
        attn_layers_dim: Dimensionality of the attention layers.
        attn_layers_depth: Number of attention layers.
        attn_layers_heads: Number of attention heads.
        with_cross_attention: Whether to use cross attention for conditioning.
        embedding_dropout_rate: Dropout rate for the embedding.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
    """

    def __init__(
        self,
        num_tokens: int,
        max_seq_len: int,
        attn_layers_dim: int,
        attn_layers_depth: int,
        attn_layers_heads: int,
        with_cross_attention: bool = False,
        embedding_dropout_rate: float = 0.0,
        use_flash_attention: bool = False,
    ) -> None:
        super().__init__()
        self.num_tokens = num_tokens
        self.max_seq_len = max_seq_len
        self.attn_layers_dim = attn_layers_dim
        self.attn_layers_depth = attn_layers_depth
        self.attn_layers_heads = attn_layers_heads
        self.with_cross_attention = with_cross_attention

        self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim)
        self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim)
        self.embedding_dropout = nn.Dropout(embedding_dropout_rate)

        self.blocks = nn.ModuleList(
            [
                _TransformerBlock(
                    hidden_size=attn_layers_dim,
                    mlp_dim=attn_layers_dim * 4,
                    num_heads=attn_layers_heads,
                    dropout_rate=0.0,
                    qkv_bias=False,
                    causal=True,
                    sequence_length=max_seq_len,
                    with_cross_attention=with_cross_attention,
                    use_flash_attention=use_flash_attention,
                )
                for _ in range(attn_layers_depth)
            ]
        )

        self.to_logits = nn.Linear(attn_layers_dim, num_tokens)

    def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
        tok_emb = self.token_embeddings(x)
        pos_emb = self.position_embeddings(x)
        x = self.embedding_dropout(tok_emb + pos_emb)

        for block in self.blocks:
            x = block(x, context=context)
        logits: torch.Tensor = self.to_logits(x)
        return logits
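
For reference, a sketch of how the optional cross-attention conditioning defined above could be exercised; the sizes are hypothetical, and the context tensor's last dimension must match attn_layers_dim:

```python
import torch

from monai.networks.nets import DecoderOnlyTransformer

# Hypothetical sizes, chosen only for demonstration.
model = DecoderOnlyTransformer(
    num_tokens=100,
    max_seq_len=32,
    attn_layers_dim=64,
    attn_layers_depth=2,
    attn_layers_heads=4,
    with_cross_attention=True,  # each block gains a cross-attention sub-layer
)

tokens = torch.randint(0, 100, (2, 32))  # (batch, seq_len) token indices
context = torch.randn(2, 7, 64)          # conditioning sequence; last dim equals attn_layers_dim
logits = model(tokens, context=context)  # (2, 32, 100)
```

When with_cross_attention is False, the context argument is simply ignored by each block, so the same forward signature covers both the conditioned and unconditioned cases.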
