
Commit 67aa4cf

ahatamiz and monai-bot authored
add classification support for ViT Model (#2861)
* add classification support for ViT Model
  Signed-off-by: ahatamizadeh <[email protected]>
* add classification support for ViT Model
  Signed-off-by: ahatamizadeh <[email protected]>
* add classification support for ViT Model
  Signed-off-by: ahatamizadeh <[email protected]>
* [MONAI] python code formatting
  Signed-off-by: monai-bot <[email protected]>

Co-authored-by: monai-bot <[email protected]>
1 parent c99bd41 commit 67aa4cf

File tree

4 files changed: +44 -34 lines changed


monai/apps/deepedit/transforms.py
Lines changed: 1 addition & 1 deletion

@@ -6,10 +6,10 @@
 
 from monai.config import KeysCollection
 from monai.transforms.transform import MapTransform, Randomizable, Transform
+from monai.utils import optional_import
 
 logger = logging.getLogger(__name__)
 
-from monai.utils import optional_import
 
 distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
 
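
Why the relocated import is safe at the top of the module: optional_import does not require SciPy at import time. It returns the requested attribute together with a boolean availability flag, and the attribute only raises if it is actually used while the dependency is missing. A minimal sketch of that contract (the has_scipy name is illustrative, not part of this commit):

    from monai.utils import optional_import

    # optional_import returns an (attribute, flag) pair: `flag` reports whether
    # the import succeeded; the attribute raises only when actually invoked.
    distance_transform_cdt, has_scipy = optional_import(
        "scipy.ndimage.morphology", name="distance_transform_cdt"
    )

    if has_scipy:
        import numpy as np

        mask = np.pad(np.ones((3, 3), dtype=np.uint8), 1)  # 5x5 binary blob
        print(distance_transform_cdt(mask))  # chamfer distances inside the blob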

monai/networks/nets/unetr.py
Lines changed: 5 additions & 5 deletions

@@ -34,9 +34,9 @@ def __init__(
         hidden_size: int = 768,
         mlp_dim: int = 3072,
         num_heads: int = 12,
-        pos_embed: str = "perceptron",
+        pos_embed: str = "conv",
         norm_name: Union[Tuple, str] = "instance",
-        conv_block: bool = False,
+        conv_block: bool = True,
         res_block: bool = True,
         dropout_rate: float = 0.0,
         spatial_dims: int = 3,
@@ -59,13 +59,13 @@ def __init__(
 
         Examples::
 
-            # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm
+            # for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
             >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch')
 
-            # for single channel input 4-channel output with patch size of (96,96), feature size of 32 and batch norm
+            # for single channel input 4-channel output with image size of (96,96), feature size of 32 and batch norm
             >>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)
 
-            # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm
+            # for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
             >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
 
         """

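With the two defaults changed above, a bare UNETR(...) construction now uses the convolutional position embedding and the convolutional block. A quick shape check under those new defaults, as a sketch rather than part of this commit (UNETR is a segmentation network, so the output keeps the input spatial size):

    import torch

    from monai.networks.nets import UNETR

    # No pos_embed or conv_block arguments: the new defaults ("conv", True) apply.
    net = UNETR(in_channels=1, out_channels=4, img_size=(96, 96, 96), feature_size=32, norm_name="batch")

    with torch.no_grad():
        y = net(torch.randn(1, 1, 96, 96, 96))
    print(y.shape)  # expected: torch.Size([1, 4, 96, 96, 96])
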
monai/networks/nets/vit.py
Lines changed: 12 additions & 4 deletions

@@ -12,6 +12,7 @@
 
 from typing import Sequence, Union
 
+import torch
 import torch.nn as nn
 
 from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
@@ -33,7 +34,7 @@ def __init__(
         mlp_dim: int = 3072,
         num_layers: int = 12,
         num_heads: int = 12,
-        pos_embed: str = "perceptron",
+        pos_embed: str = "conv",
         classification: bool = False,
         num_classes: int = 2,
         dropout_rate: float = 0.0,
@@ -56,12 +57,15 @@ def __init__(
 
         Examples::
 
-            # for single channel input with patch size of (96,96,96), conv position embedding and segmentation backbone
+            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
             >>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
 
-            # for 3-channel with patch size of (128,128,128), 24 layers and classification backbone
+            # for 3-channel with image size of (128,128,128), 24 layers and classification backbone
             >>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
 
+            # for 3-channel with image size of (224,224), 12 layers and classification backbone
+            >>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
+
         """
 
         super(ViT, self).__init__()
@@ -88,10 +92,14 @@ def __init__(
         )
         self.norm = nn.LayerNorm(hidden_size)
         if self.classification:
-            self.classification_head = nn.Linear(hidden_size, num_classes)
+            self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
+            self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
 
     def forward(self, x):
         x = self.patch_embedding(x)
+        if self.classification:
+            cls_token = self.cls_token.expand(x.shape[0], -1, -1)
+            x = torch.cat((cls_token, x), dim=1)
         hidden_states_out = []
         for blk in self.blocks:
             x = blk(x)
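
The forward change above is the standard ViT class-token mechanism: a single learnable (1, 1, hidden_size) parameter is broadcast across the batch and prepended to the patch-token sequence, so position 0 can aggregate global context through the transformer blocks and later feed the classification head. A standalone sketch of just that step, with illustrative sizes:

    import torch
    import torch.nn as nn

    hidden_size = 768
    cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))  # as in the diff above

    # 216 patch tokens, e.g. a (96, 96, 96) volume with 16x16x16 patches: (96 // 16) ** 3.
    patches = torch.randn(2, 216, hidden_size)

    expanded = cls_token.expand(patches.shape[0], -1, -1)  # one copy per batch item
    tokens = torch.cat((expanded, patches), dim=1)
    print(tokens.shape)  # torch.Size([2, 217, 768]); index 0 is the class token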

tests/test_vit.py
Lines changed: 26 additions & 24 deletions

@@ -26,30 +26,32 @@
                     for num_heads in [12]:
                         for mlp_dim in [3072]:
                             for num_layers in [4]:
-                                for num_classes in [2]:
+                                for num_classes in [8]:
                                     for pos_embed in ["conv"]:
-                                        # for classification in [False, True]:  # TODO: test classification
-                                        for nd in (2, 3):
-                                            test_case = [
-                                                {
-                                                    "in_channels": in_channels,
-                                                    "img_size": (img_size,) * nd,
-                                                    "patch_size": (patch_size,) * nd,
-                                                    "hidden_size": hidden_size,
-                                                    "mlp_dim": mlp_dim,
-                                                    "num_layers": num_layers,
-                                                    "num_heads": num_heads,
-                                                    "pos_embed": pos_embed,
-                                                    "classification": False,
-                                                    "num_classes": num_classes,
-                                                    "dropout_rate": dropout_rate,
-                                                },
-                                                (2, in_channels, *([img_size] * nd)),
-                                                (2, (img_size // patch_size) ** nd, hidden_size),
-                                            ]
-                                            if nd == 2:
-                                                test_case[0]["spatial_dims"] = 2  # type: ignore
-                                            TEST_CASE_Vit.append(test_case)
+                                        for classification in [False, True]:
+                                            for nd in (2, 3):
+                                                test_case = [
+                                                    {
+                                                        "in_channels": in_channels,
+                                                        "img_size": (img_size,) * nd,
+                                                        "patch_size": (patch_size,) * nd,
+                                                        "hidden_size": hidden_size,
+                                                        "mlp_dim": mlp_dim,
+                                                        "num_layers": num_layers,
+                                                        "num_heads": num_heads,
+                                                        "pos_embed": pos_embed,
+                                                        "classification": classification,
+                                                        "num_classes": num_classes,
+                                                        "dropout_rate": dropout_rate,
+                                                    },
+                                                    (2, in_channels, *([img_size] * nd)),
+                                                    (2, (img_size // patch_size) ** nd, hidden_size),
+                                                ]
+                                                if nd == 2:
+                                                    test_case[0]["spatial_dims"] = 2  # type: ignore
+                                                if test_case[0]["classification"]:  # type: ignore
+                                                    test_case[2] = (2, test_case[0]["num_classes"])  # type: ignore
+                                                TEST_CASE_Vit.append(test_case)
 
 
 class TestPatchEmbeddingBlock(unittest.TestCase):
@@ -113,7 +115,7 @@ def test_ill_arg(self):
                 num_layers=12,
                 num_heads=8,
                 pos_embed="perceptron",
-                classification=False,
+                classification=True,
                 dropout_rate=0.3,
             )
 
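
What the widened grid now exercises, sketched as a direct forward pass: with classification=True the expected output shape becomes (batch, num_classes) instead of (batch, n_patches, hidden_size). The sketch assumes, as the updated test cases do, that ViT.forward returns an (output, hidden_states) pair and that the classification head consumes the class token:

    import torch

    from monai.networks.nets import ViT

    net = ViT(
        in_channels=1,
        img_size=(96, 96, 96),
        patch_size=(16, 16, 16),
        pos_embed="conv",
        classification=True,
        num_classes=8,
    )

    with torch.no_grad():
        out, hidden_states = net(torch.randn(2, 1, 96, 96, 96))
    print(out.shape)  # expected: torch.Size([2, 8]), matching test_case[2] above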
