 import gc
 import logging
-from typing import TYPE_CHECKING, Sequence, cast
+from typing import Sequence

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from monai.networks.blocks import Convolution
-from monai.utils import optional_import
+from monai.networks.blocks.spatialattention import SpatialAttentionBlock
+from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
 from monai.utils.type_conversion import convert_to_tensor

-AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
-AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
-ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
-
-if TYPE_CHECKING:
-    from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
-else:
-    AutoencoderKLType = cast(type, AutoencoderKL)
-
 # Set up logging configuration
 logger = logging.getLogger(__name__)
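With this hunk the MAISI autoencoder pulls its attention and residual building blocks from MONAI core instead of the optional MONAI Generative package, so the optional_import scaffolding goes away. A minimal smoke test (my sketch, not part of the diff) is simply to resolve the same imports directly:

# Sketch, not part of the diff: these classes ship with MONAI core, so the
# optional "generative" package is no longer required to import this module.
from monai.networks.blocks.spatialattention import SpatialAttentionBlock
from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL

for cls in (SpatialAttentionBlock, AEKLResBlock, AutoencoderKL):
    print(f"{cls.__module__}.{cls.__name__}")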
@@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module):
         in_channels: Number of input channels.
         num_channels: Sequence of block output channels.
         out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: Whether to include the final linear layer in the attention block. Defaults to False.
+        use_combined_linear: Whether to use a single linear layer for qkv projection in the attention block. Defaults to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
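The two new arguments are simply forwarded to MONAI core's SpatialAttentionBlock. A minimal sketch of what they select, with illustrative sizes that are not taken from the diff:

# Sketch with illustrative sizes; the flag semantics follow the docstring above.
import torch

from monai.networks.blocks.spatialattention import SpatialAttentionBlock

attn = SpatialAttentionBlock(
    spatial_dims=3,
    num_channels=64,
    norm_num_groups=32,
    norm_eps=1e-6,
    include_fc=False,  # skip the final linear layer of the attention block
    use_combined_linear=False,  # separate q/k/v projections instead of one combined linear
    use_flash_attention=False,
)

x = torch.randn(1, 64, 8, 8, 8)  # (batch, channels, D, H, W)
print(attn(x).shape)  # self-attention plus residual keeps the shape: (1, 64, 8, 8, 8)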
@@ -547,6 +541,8 @@ def __init__(
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
@@ -603,11 +599,13 @@ def __init__(
                 input_channel = output_channel
                 if attention_levels[i]:
                     blocks.append(
-                        AttentionBlock(
+                        SpatialAttentionBlock(
                             spatial_dims=spatial_dims,
                             num_channels=input_channel,
                             norm_num_groups=norm_num_groups,
                             norm_eps=norm_eps,
+                            include_fc=include_fc,
+                            use_combined_linear=use_combined_linear,
                             use_flash_attention=use_flash_attention,
                         )
                     )
@@ -626,7 +624,7 @@ def __init__(

         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -636,16 +634,18 @@ def __init__(
             )

             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module):
         num_channels: Sequence of block output channels.
         in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: Number of output channels.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: Whether to include the final linear layer in the attention block. Defaults to False.
+        use_combined_linear: Whether to use a single linear layer for qkv projection in the attention block. Defaults to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
@@ -729,6 +731,8 @@ def __init__(
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_convtranspose: bool = False,
     ) -> None:
@@ -758,7 +762,7 @@ def __init__(

         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -767,16 +771,18 @@ def __init__(
                 )
             )
             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -812,11 +818,13 @@ def __init__(

             if reversed_attention_levels[i]:
                 blocks.append(
-                    AttentionBlock(
+                    SpatialAttentionBlock(
                         spatial_dims=spatial_dims,
                         num_channels=block_in_ch,
                         norm_num_groups=norm_num_groups,
                         norm_eps=norm_eps,
+                        include_fc=include_fc,
+                        use_combined_linear=use_combined_linear,
                         use_flash_attention=use_flash_attention,
                     )
                 )
@@ -870,7 +878,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-class AutoencoderKlMaisi(AutoencoderKLType):
+class AutoencoderKlMaisi(AutoencoderKL):
     """
     AutoencoderKL with custom MaisiEncoder and MaisiDecoder.

@@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
         norm_eps: Epsilon for the normalization.
         with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
+        include_fc: Whether to include the final linear layer. Defaults to False.
+        use_combined_linear: Whether to use a single linear layer for qkv projection. Defaults to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_checkpointing: If True, use activation checkpointing.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
@@ -909,6 +919,8 @@ def __init__(
         norm_eps: float = 1e-6,
         with_encoder_nonlocal_attn: bool = False,
         with_decoder_nonlocal_attn: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_checkpointing: bool = False,
         use_convtranspose: bool = False,
@@ -930,12 +942,14 @@ def __init__(
             norm_eps,
             with_encoder_nonlocal_attn,
             with_decoder_nonlocal_attn,
-            use_flash_attention,
             use_checkpointing,
             use_convtranspose,
+            include_fc,
+            use_combined_linear,
+            use_flash_attention,
         )

-        self.encoder = MaisiEncoder(
+        self.encoder: nn.Module = MaisiEncoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             num_channels=num_channels,
@@ -945,6 +959,8 @@ def __init__(
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_encoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             num_splits=num_splits,
             dim_split=dim_split,
@@ -953,7 +969,7 @@ def __init__(
             save_mem=save_mem,
         )

-        self.decoder = MaisiDecoder(
+        self.decoder: nn.Module = MaisiDecoder(
             spatial_dims=spatial_dims,
             num_channels=num_channels,
             in_channels=latent_channels,
@@ -963,6 +979,8 @@ def __init__(
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             use_convtranspose=use_convtranspose,
             num_splits=num_splits,
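Taken together, the constructor now exposes include_fc and use_combined_linear alongside use_flash_attention and forwards them to both the encoder and decoder. A minimal usage sketch, assuming the class lives at monai.apps.generation.maisi.networks.autoencoderkl_maisi; the sizes and split settings below are placeholders, not the official MAISI configuration:

# Sketch only: module path, sizes, and split settings are assumptions for illustration.
import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

model = AutoencoderKlMaisi(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    latent_channels=4,
    num_channels=(64, 128, 256),
    num_res_blocks=(2, 2, 2),
    attention_levels=(False, False, True),
    norm_num_groups=32,
    norm_eps=1e-6,
    include_fc=False,  # new flag: skip the final linear layer in attention blocks
    use_combined_linear=False,  # new flag: separate q/k/v projections
    use_flash_attention=False,
    num_splits=2,  # tensor-splitting settings here are placeholders
    dim_split=0,
)

with torch.no_grad():
    reconstruction, z_mu, z_sigma = model(torch.randn(1, 1, 64, 64, 64))
print(reconstruction.shape)  # expected to match the input: (1, 1, 64, 64, 64)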