@@ -11,31 +11,42 @@


import math
-from typing import Tuple, Union
+from typing import Sequence, Union

+import numpy as np
import torch
import torch.nn as nn

-from monai.utils import optional_import
+from monai.networks.layers import Conv
+from monai.utils import ensure_tuple_rep, optional_import
+from monai.utils.module import look_up_option

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"}


class PatchEmbeddingBlock(nn.Module):
    """
    A patch embedding block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
+
+    Example::
+
+        >>> from monai.networks.blocks import PatchEmbeddingBlock
+        >>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv")
+
    """

    def __init__(
        self,
        in_channels: int,
-        img_size: Tuple[int, int, int],
-        patch_size: Tuple[int, int, int],
+        img_size: Union[Sequence[int], int],
+        patch_size: Union[Sequence[int], int],
        hidden_size: int,
        num_heads: int,
        pos_embed: str,
        dropout_rate: float = 0.0,
+        spatial_dims: int = 3,
    ) -> None:
        """
        Args:
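The relaxed Union[Sequence[int], int] signature relies on monai.utils.ensure_tuple_rep to broadcast a scalar into one size per spatial dimension, and on look_up_option to validate pos_embed against SUPPORTED_EMBEDDING_TYPES. A minimal sketch of the expected broadcasting behaviour (assuming MONAI is installed; exact return types may differ by version):

# Sketch: how the relaxed img_size/patch_size typing is resolved internally.
from monai.utils import ensure_tuple_rep

print(ensure_tuple_rep(32, 3))        # (32, 32, 32): scalar broadcast to 3 dims
print(ensure_tuple_rep((64, 48), 2))  # (64, 48): matching-length sequence kept as-is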
@@ -46,47 +57,44 @@ def __init__(
            num_heads: number of attention heads.
            pos_embed: position embedding layer type.
            dropout_rate: fraction of the input units to drop.
+            spatial_dims: number of spatial dimensions.
+

        """

-        super().__init__()
+        super(PatchEmbeddingBlock, self).__init__()

        if not (0 <= dropout_rate <= 1):
-            raise AssertionError("dropout_rate should be between 0 and 1.")
+            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
-            raise AssertionError("hidden size should be divisible by num_heads.")
+            raise ValueError("hidden size should be divisible by num_heads.")
+
+        self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)

+        img_size = ensure_tuple_rep(img_size, spatial_dims)
+        patch_size = ensure_tuple_rep(patch_size, spatial_dims)
        for m, p in zip(img_size, patch_size):
            if m < p:
-                raise AssertionError("patch_size should be smaller than img_size.")
+                raise ValueError("patch_size should be smaller than img_size.")
+            if self.pos_embed == "perceptron" and m % p != 0:
+                raise ValueError("img_size should be divisible by patch_size for perceptron.")
+        self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
+        self.patch_dim = in_channels * np.prod(patch_size)

-        if pos_embed not in ["conv", "perceptron"]:
-            raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.")
-
-        if pos_embed == "perceptron":
-            if img_size[0] % patch_size[0] != 0:
-                raise AssertionError("img_size should be divisible by patch_size for perceptron patch embedding.")
-
-        self.n_patches = (
-            (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1]) * (img_size[2] // patch_size[2])
-        )
-        self.patch_dim = in_channels * patch_size[0] * patch_size[1] * patch_size[2]
-
-        self.pos_embed = pos_embed
-        self.patch_embeddings: Union[nn.Conv3d, nn.Sequential]
+        self.patch_embeddings: nn.Module
        if self.pos_embed == "conv":
-            self.patch_embeddings = nn.Conv3d(
+            self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
                in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
            )
        elif self.pos_embed == "perceptron":
+            # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
+            chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
+            from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
+            to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
+            axes_len = {f"p{i + 1}": p for i, p in enumerate(patch_size)}
            self.patch_embeddings = nn.Sequential(
-                Rearrange(
-                    "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)",
-                    p1=patch_size[0],
-                    p2=patch_size[1],
-                    p3=patch_size[2],
-                ),
+                Rearrange(f"{from_chars} -> {to_chars}", **axes_len),
                nn.Linear(self.patch_dim, hidden_size),
            )
        self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
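The perceptron branch now builds the einops pattern string dynamically instead of hard-coding the 3D case. A small standalone sketch (plain Python, variable names copied from the diff) shows what the construction produces for spatial_dims=2:

# Sketch: the pattern-string construction above, traced for spatial_dims=2.
patch_size = (8, 8)
chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:2]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
print(f"{from_chars} -> {to_chars}")  # b c (h p1) (w p2) -> b (h w) (p1 p2 c)
print({f"p{i + 1}": p for i, p in enumerate(patch_size)})  # {'p1': 8, 'p2': 8}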
@@ -121,12 +129,9 @@ def norm_cdf(x):
        return tensor

    def forward(self, x):
+        x = self.patch_embeddings(x)
        if self.pos_embed == "conv":
-            x = self.patch_embeddings(x)
-            x = x.flatten(2)
-            x = x.transpose(-1, -2)
-        elif self.pos_embed == "perceptron":
-            x = self.patch_embeddings(x)
+            x = x.flatten(2).transpose(-1, -2)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
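Taken together, the changes let the block embed 2D as well as 3D inputs. A minimal smoke-test sketch (hyperparameters borrowed from the docstring example, with spatial_dims=2; assumes einops is installed):

# Sketch: 2D usage enabled by this change; shapes shown assume these hyperparameters.
import torch
from monai.networks.blocks import PatchEmbeddingBlock

block = PatchEmbeddingBlock(
    in_channels=4, img_size=32, patch_size=8, hidden_size=32,
    num_heads=4, pos_embed="perceptron", spatial_dims=2,
)
x = torch.randn(1, 4, 32, 32)  # (batch, channels, H, W)
print(block(x).shape)          # torch.Size([1, 16, 32]): 16 patches, hidden_size 32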