
Commit 62425d7

2499 adds 2d/3d support of patchembedding, ViT, and unetr model (#2698)
* 2d/3d patchembedding
* minor updates for selfattention
* 2d vit
* fixes type hint
* update unetr
* fixes unit test

Signed-off-by: Wenqi Li <[email protected]>
1 parent 4298b14 commit 62425d7

14 files changed: +180 / -182 lines

monai/networks/blocks/mlp.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def __init__(
         super().__init__()
 
         if not (0 <= dropout_rate <= 1):
-            raise AssertionError("dropout_rate should be between 0 and 1.")
+            raise ValueError("dropout_rate should be between 0 and 1.")
 
         self.linear1 = nn.Linear(hidden_size, mlp_dim)
         self.linear2 = nn.Linear(mlp_dim, hidden_size)
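With this change an out-of-range rate raises ValueError instead of AssertionError, so callers can treat it as an ordinary argument-validation error. A doctest-style sketch (the sizes are illustrative, not from the diff):

    >>> from monai.networks.blocks import MLPBlock
    >>> MLPBlock(hidden_size=128, mlp_dim=256, dropout_rate=1.5)
    Traceback (most recent call last):
        ...
    ValueError: dropout_rate should be between 0 and 1.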

monai/networks/blocks/patchembedding.py

Lines changed: 39 additions & 34 deletions
@@ -11,31 +11,42 @@
 
 
 import math
-from typing import Tuple, Union
+from typing import Sequence, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 
-from monai.utils import optional_import
+from monai.networks.layers import Conv
+from monai.utils import ensure_tuple_rep, optional_import
+from monai.utils.module import look_up_option
 
 Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"}
 
 
 class PatchEmbeddingBlock(nn.Module):
     """
     A patch embedding block, based on: "Dosovitskiy et al.,
     An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+
+    Example::
+
+        >>> from monai.networks.blocks import PatchEmbeddingBlock
+        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv")
+
     """
 
     def __init__(
         self,
         in_channels: int,
-        img_size: Tuple[int, int, int],
-        patch_size: Tuple[int, int, int],
+        img_size: Union[Sequence[int], int],
+        patch_size: Union[Sequence[int], int],
         hidden_size: int,
         num_heads: int,
         pos_embed: str,
         dropout_rate: float = 0.0,
+        spatial_dims: int = 3,
     ) -> None:
         """
         Args:
@@ -46,47 +57,44 @@ def __init__(
             num_heads: number of attention heads.
             pos_embed: position embedding layer type.
             dropout_rate: fraction of the input units to drop.
+            spatial_dims: number of spatial dimensions.
+
 
         """
 
-        super().__init__()
+        super(PatchEmbeddingBlock, self).__init__()
 
         if not (0 <= dropout_rate <= 1):
-            raise AssertionError("dropout_rate should be between 0 and 1.")
+            raise ValueError("dropout_rate should be between 0 and 1.")
 
         if hidden_size % num_heads != 0:
-            raise AssertionError("hidden size should be divisible by num_heads.")
+            raise ValueError("hidden size should be divisible by num_heads.")
+
+        self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)
 
+        img_size = ensure_tuple_rep(img_size, spatial_dims)
+        patch_size = ensure_tuple_rep(patch_size, spatial_dims)
         for m, p in zip(img_size, patch_size):
             if m < p:
-                raise AssertionError("patch_size should be smaller than img_size.")
+                raise ValueError("patch_size should be smaller than img_size.")
+            if self.pos_embed == "perceptron" and m % p != 0:
+                raise ValueError("img_size should be divisible by patch_size for perceptron.")
+        self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
+        self.patch_dim = in_channels * np.prod(patch_size)
 
-        if pos_embed not in ["conv", "perceptron"]:
-            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
-
-        if pos_embed == "perceptron":
-            if img_size[0] % patch_size[0] != 0:
-                raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.")
-
-        self.n_patches = (
-            (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])
-        )
-        self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2]
-
-        self.pos_embed = pos_embed
-        self.patch_embeddings: Union[nn.Conv3d, nn.Sequential]
+        self.patch_embeddings: nn.Module
         if self.pos_embed == "conv":
-            self.patch_embeddings = nn.Conv3d(
+            self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
                 in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
             )
         elif self.pos_embed == "perceptron":
+            # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
+            chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
+            from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
+            to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
+            axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
             self.patch_embeddings = nn.Sequential(
-                Rearrange(
-                    "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)",
-                    p1=patch_size[0],
-                    p2=patch_size[1],
-                    p3=patch_size[2],
-                ),
+                Rearrange(f"{from_chars} -> {to_chars}", **axes_len),
                 nn.Linear(self.patch_dim, hidden_size),
             )
         self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
@@ -121,12 +129,9 @@ def norm_cdf(x):
         return tensor
 
     def forward(self, x):
+        x = self.patch_embeddings(x)
         if self.pos_embed == "conv":
-            x = self.patch_embeddings(x)
-            x = x.flatten(2)
-            x = x.transpose(-1, -2)
-        elif self.pos_embed == "perceptron":
-            x = self.patch_embeddings(x)
+            x = x.flatten(2).transpose(-1, -2)
         embeddings = x + self.position_embeddings
         embeddings = self.dropout(embeddings)
         return embeddings
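For spatial_dims=2 the generated einops pattern becomes "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", mirroring the 3d pattern quoted in the code comment. A minimal 2D usage sketch based on the new signature (channel and patch sizes are illustrative, not from the diff):

    >>> import torch
    >>> from monai.networks.blocks import PatchEmbeddingBlock
    >>> block = PatchEmbeddingBlock(
    ...     in_channels=1, img_size=64, patch_size=16, hidden_size=96,
    ...     num_heads=8, pos_embed="perceptron", spatial_dims=2,
    ... )
    >>> block(torch.randn(2, 1, 64, 64)).shape  # (64 // 16) ** 2 = 16 patch tokens
    torch.Size([2, 16, 96])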

monai/networks/blocks/selfattention.py

Lines changed: 6 additions & 10 deletions
@@ -14,7 +14,7 @@
 
 from monai.utils import optional_import
 
-einops, has_einops = optional_import("einops")
+einops, _ = optional_import("einops")
 
 
 class SABlock(nn.Module):
@@ -37,13 +37,13 @@ def __init__(
 
         """
 
-        super().__init__()
+        super(SABlock, self).__init__()
 
         if not (0 <= dropout_rate <= 1):
-            raise AssertionError("dropout_rate should be between 0 and 1.")
+            raise ValueError("dropout_rate should be between 0 and 1.")
 
         if hidden_size % num_heads != 0:
-            raise AssertionError("hidden size should be divisible by num_heads.")
+            raise ValueError("hidden size should be divisible by num_heads.")
 
         self.num_heads = num_heads
         self.out_proj = nn.Linear(hidden_size, hidden_size)
@@ -52,17 +52,13 @@ def __init__(
         self.drop_weights = nn.Dropout(dropout_rate)
         self.head_dim = hidden_size // num_heads
         self.scale = self.head_dim ** -0.5
-        if has_einops:
-            self.rearrange = einops.rearrange
-        else:
-            raise ValueError('"Requires einops.')
 
     def forward(self, x):
-        q, k, v = self.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads)
+        q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads)
         att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
         att_mat = self.drop_weights(att_mat)
         x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
-        x = self.rearrange(x, "b h l d -> b l (h d)")
+        x = einops.rearrange(x, "b h l d -> b l (h d)")
         x = self.out_proj(x)
         x = self.drop_output(x)
         return x
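Because einops is now used directly through the optional_import handle, a missing einops install surfaces on first attribute access in forward rather than as a ValueError in the constructor. A quick shape-preserving check (sizes are illustrative):

    >>> import torch
    >>> from monai.networks.blocks import SABlock
    >>> attn = SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1)
    >>> attn(torch.randn(2, 32, 128)).shape
    torch.Size([2, 32, 128])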

monai/networks/blocks/transformerblock.py

Lines changed: 2 additions & 2 deletions
@@ -40,10 +40,10 @@ def __init__(
         super().__init__()
 
         if not (0 <= dropout_rate <= 1):
-            raise AssertionError("dropout_rate should be between 0 and 1.")
+            raise ValueError("dropout_rate should be between 0 and 1.")
 
         if hidden_size % num_heads != 0:
-            raise AssertionError("hidden size should be divisible by num_heads.")
+            raise ValueError("hidden_size should be divisible by num_heads.")
 
         self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
         self.norm1 = nn.LayerNorm(hidden_size)
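The transformer block applies the same validation and composes the SABlock and MLPBlock above, so its output shape matches its input. A usage sketch (sizes are illustrative, not from the diff):

    >>> import torch
    >>> from monai.networks.blocks import TransformerBlock
    >>> blk = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, dropout_rate=0.0)
    >>> blk(torch.randn(2, 32, 128)).shape
    torch.Size([2, 32, 128])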

monai/networks/blocks/unetr_block.py

Lines changed: 4 additions & 7 deletions
@@ -28,9 +28,8 @@ def __init__(
         self,
         spatial_dims: int,
         in_channels: int,
-        out_channels: int,  # type: ignore
+        out_channels: int,
         kernel_size: Union[Sequence[int], int],
-        stride: Union[Sequence[int], int],
         upsample_kernel_size: Union[Sequence[int], int],
         norm_name: Union[Tuple, str],
         res_block: bool = False,
@@ -41,7 +40,6 @@ def __init__(
             in_channels: number of input channels.
             out_channels: number of output channels.
             kernel_size: convolution kernel size.
-            stride: convolution stride.
             upsample_kernel_size: convolution kernel size for transposed convolution layers.
             norm_name: feature normalization type and arguments.
             res_block: bool argument to determine if residual block is used.
@@ -148,7 +146,7 @@ def __init__(
                 is_transposed=True,
             ),
             UnetResBlock(
-                spatial_dims=3,
+                spatial_dims=spatial_dims,
                 in_channels=out_channels,
                 out_channels=out_channels,
                 kernel_size=kernel_size,
@@ -173,7 +171,7 @@ def __init__(
                 is_transposed=True,
             ),
             UnetBasicBlock(
-                spatial_dims=3,
+                spatial_dims=spatial_dims,
                 in_channels=out_channels,
                 out_channels=out_channels,
                 kernel_size=kernel_size,
@@ -257,5 +255,4 @@ def __init__(
         )
 
     def forward(self, inp):
-        out = self.layer(inp)
-        return out
+        return self.layer(inp)
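With spatial_dims now threaded through to the inner UnetResBlock/UnetBasicBlock (previously hard-coded to 3), these UNETR blocks can be built for 2D networks as well. A construction sketch using only arguments shown in this diff (channel counts are illustrative):

    >>> from monai.networks.blocks.unetr_block import UnetrUpBlock
    >>> up = UnetrUpBlock(
    ...     spatial_dims=2, in_channels=64, out_channels=32, kernel_size=3,
    ...     upsample_kernel_size=2, norm_name="instance", res_block=True,
    ... )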
