
Commit b622b7c

guopengf, pre-commit-ci[bot], and KumoLiu authored and committed
Refactor AutoencoderKlMaisi (Project-MONAI#7993)
Fixes Project-MONAI#7988.

### Description

Refactor AutoencoderKlMaisi to use monai core components.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Pengfei Guo <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Signed-off-by: Pengfei Guo <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
1 parent 61ef0da commit b622b7c
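The practical effect of the refactor is that `AutoencoderKlMaisi` no longer depends on the optional `monai-generative` package: its building blocks (`AutoencoderKL`, `AEKLResBlock`, `SpatialAttentionBlock`) now come from MONAI core. A minimal sketch of what this enables, assuming a MONAI build that contains this commit (not part of the commit itself):

```python
# Minimal sketch (illustrative, not from the commit): after the refactor the class
# can be imported and inspected without monai-generative installed.
from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
from monai.networks.nets.autoencoderkl import AutoencoderKL

# The MAISI autoencoder now subclasses MONAI core's AutoencoderKL directly.
assert issubclass(AutoencoderKlMaisi, AutoencoderKL)
```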

File tree

3 files changed: +46 −32 lines changed


monai/apps/generation/maisi/networks/autoencoderkl_maisi.py

Lines changed: 43 additions & 25 deletions
@@ -13,25 +13,17 @@

 import gc
 import logging
-from typing import TYPE_CHECKING, Sequence, cast
+from typing import Sequence

 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from monai.networks.blocks import Convolution
-from monai.utils import optional_import
+from monai.networks.blocks.spatialattention import SpatialAttentionBlock
+from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
 from monai.utils.type_conversion import convert_to_tensor

-AttentionBlock, has_attentionblock = optional_import("generative.networks.nets.autoencoderkl", name="AttentionBlock")
-AutoencoderKL, has_autoencoderkl = optional_import("generative.networks.nets.autoencoderkl", name="AutoencoderKL")
-ResBlock, has_resblock = optional_import("generative.networks.nets.autoencoderkl", name="ResBlock")
-
-if TYPE_CHECKING:
-    from generative.networks.nets.autoencoderkl import AutoencoderKL as AutoencoderKLType
-else:
-    AutoencoderKLType = cast(type, AutoencoderKL)
-
 # Set up logging configuration
 logger = logging.getLogger(__name__)

@@ -518,11 +510,13 @@ class MaisiEncoder(nn.Module):
         in_channels: Number of input channels.
         num_channels: Sequence of block output channels.
         out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
@@ -547,6 +541,8 @@ def __init__(
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
     ) -> None:
         super().__init__()
@@ -603,11 +599,13 @@ def __init__(
                 input_channel = output_channel
                 if attention_levels[i]:
                     blocks.append(
-                        AttentionBlock(
+                        SpatialAttentionBlock(
                             spatial_dims=spatial_dims,
                             num_channels=input_channel,
                             norm_num_groups=norm_num_groups,
                             norm_eps=norm_eps,
+                            include_fc=include_fc,
+                            use_combined_linear=use_combined_linear,
                             use_flash_attention=use_flash_attention,
                         )
                     )
@@ -626,7 +624,7 @@ def __init__(

         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -636,16 +634,18 @@
             )

             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=num_channels[-1],
                     norm_num_groups=norm_num_groups,
@@ -699,11 +699,13 @@ class MaisiDecoder(nn.Module):
         num_channels: Sequence of block output channels.
         in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: Number of output channels.
-        num_res_blocks: Number of residual blocks (see ResBlock) per level.
+        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
         norm_num_groups: Number of groups for the group norm layers.
         norm_eps: Epsilon for the normalization.
         attention_levels: Indicate which level from num_channels contain an attention block.
         with_nonlocal_attn: If True, use non-local attention block.
+        include_fc: whether to include the final linear layer in the attention block. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
@@ -729,6 +731,8 @@ def __init__(
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_convtranspose: bool = False,
     ) -> None:
@@ -758,7 +762,7 @@ def __init__(

         if with_nonlocal_attn:
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -767,16 +771,18 @@ def __init__(
                 )
             )
             blocks.append(
-                AttentionBlock(
+                SpatialAttentionBlock(
                     spatial_dims=spatial_dims,
                     num_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
                     norm_eps=norm_eps,
+                    include_fc=include_fc,
+                    use_combined_linear=use_combined_linear,
                     use_flash_attention=use_flash_attention,
                 )
             )
             blocks.append(
-                ResBlock(
+                AEKLResBlock(
                     spatial_dims=spatial_dims,
                     in_channels=reversed_block_out_channels[0],
                     norm_num_groups=norm_num_groups,
@@ -812,11 +818,13 @@ def __init__(

                 if reversed_attention_levels[i]:
                     blocks.append(
-                        AttentionBlock(
+                        SpatialAttentionBlock(
                             spatial_dims=spatial_dims,
                             num_channels=block_in_ch,
                             norm_num_groups=norm_num_groups,
                             norm_eps=norm_eps,
+                            include_fc=include_fc,
+                            use_combined_linear=use_combined_linear,
                             use_flash_attention=use_flash_attention,
                         )
                     )
@@ -870,7 +878,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-class AutoencoderKlMaisi(AutoencoderKLType):
+class AutoencoderKlMaisi(AutoencoderKL):
     """
     AutoencoderKL with custom MaisiEncoder and MaisiDecoder.

@@ -886,6 +894,8 @@ class AutoencoderKlMaisi(AutoencoderKLType):
         norm_eps: Epsilon for the normalization.
         with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
+        include_fc: whether to include the final linear layer. Default to False.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         use_checkpointing: If True, use activation checkpointing.
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
@@ -909,6 +919,8 @@ def __init__(
         norm_eps: float = 1e-6,
         with_encoder_nonlocal_attn: bool = False,
         with_decoder_nonlocal_attn: bool = False,
+        include_fc: bool = False,
+        use_combined_linear: bool = False,
         use_flash_attention: bool = False,
         use_checkpointing: bool = False,
         use_convtranspose: bool = False,
@@ -930,12 +942,14 @@ def __init__(
             norm_eps,
             with_encoder_nonlocal_attn,
             with_decoder_nonlocal_attn,
-            use_flash_attention,
             use_checkpointing,
             use_convtranspose,
+            include_fc,
+            use_combined_linear,
+            use_flash_attention,
         )

-        self.encoder = MaisiEncoder(
+        self.encoder: nn.Module = MaisiEncoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             num_channels=num_channels,
@@ -945,6 +959,8 @@ def __init__(
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_encoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             num_splits=num_splits,
             dim_split=dim_split,
@@ -953,7 +969,7 @@ def __init__(
             save_mem=save_mem,
         )

-        self.decoder = MaisiDecoder(
+        self.decoder: nn.Module = MaisiDecoder(
             spatial_dims=spatial_dims,
             num_channels=num_channels,
             in_channels=latent_channels,
@@ -963,6 +979,8 @@ def __init__(
             norm_eps=norm_eps,
             attention_levels=attention_levels,
             with_nonlocal_attn=with_decoder_nonlocal_attn,
+            include_fc=include_fc,
+            use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
             use_convtranspose=use_convtranspose,
             num_splits=num_splits,
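The two keywords threaded through this file, `include_fc` and `use_combined_linear`, are forwarded to MONAI core's `SpatialAttentionBlock`, which replaces the former `generative` `AttentionBlock` in both `MaisiEncoder` and `MaisiDecoder`. A hedged, self-contained sketch of that block in isolation (not from the commit; the parameter values are illustrative):

```python
# Illustrative only: exercising SpatialAttentionBlock with the keywords that
# MaisiEncoder/MaisiDecoder now pass through. Values are arbitrary examples.
import torch

from monai.networks.blocks.spatialattention import SpatialAttentionBlock

attn = SpatialAttentionBlock(
    spatial_dims=2,
    num_channels=8,
    norm_num_groups=4,
    norm_eps=1e-6,
    include_fc=False,           # drop the final linear layer of the attention block
    use_combined_linear=False,  # keep separate q/k/v projections
    use_flash_attention=False,
)

out = attn(torch.randn(1, 8, 16, 16))  # residual attention preserves the input shape
print(out.shape)  # expected: torch.Size([1, 8, 16, 16])
```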

monai/networks/nets/autoencoderkl.py

Lines changed: 2 additions & 2 deletions
@@ -532,7 +532,7 @@ def __init__(
                 "`num_channels`."
             )

-        self.encoder = Encoder(
+        self.encoder: nn.Module = Encoder(
             spatial_dims=spatial_dims,
             in_channels=in_channels,
             channels=channels,
@@ -546,7 +546,7 @@ def __init__(
             use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
         )
-        self.decoder = Decoder(
+        self.decoder: nn.Module = Decoder(
             spatial_dims=spatial_dims,
             channels=channels,
             in_channels=latent_channels,
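The only change in `autoencoderkl.py` is widening the declared attribute types to `nn.Module`. This presumably lets `AutoencoderKlMaisi` replace them with `MaisiEncoder`/`MaisiDecoder` without static type checkers flagging an incompatible assignment. A small hedged illustration of the pattern (hypothetical class names, not MONAI code):

```python
# Hypothetical illustration: why annotating the attribute with the broad
# nn.Module type helps a subclass swap in a different module implementation.
import torch.nn as nn


class Base(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder: nn.Module = nn.Conv2d(1, 8, 3)  # declared as generic nn.Module


class Custom(Base):
    def __init__(self) -> None:
        super().__init__()
        # Reassigning with another nn.Module subtype now satisfies type checkers.
        self.encoder = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU())
```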

tests/test_autoencoderkl_maisi.py

Lines changed: 1 addition & 5 deletions
@@ -16,16 +16,13 @@
 import torch
 from parameterized import parameterized

+from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
 from monai.networks import eval_mode
 from monai.utils import optional_import
 from tests.utils import SkipIfBeforePyTorchVersion

 tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
 _, has_einops = optional_import("einops")
-_, has_generative = optional_import("generative")
-
-if has_generative:
-    from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi

 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@@ -79,7 +76,6 @@
     CASES = CASES_NO_ATTENTION


-@unittest.skipUnless(has_generative, "monai-generative required")
 class TestAutoencoderKlMaisi(unittest.TestCase):

     @parameterized.expand(CASES)
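With the `generative` gate removed, the test imports `AutoencoderKlMaisi` unconditionally and runs against MONAI core alone. A hedged sketch of the kind of round trip the test exercises (not the test itself; the constructor values are illustrative and loosely follow a small no-attention configuration, so they may need adjusting):

```python
# Illustrative sketch: build a small 2D AutoencoderKlMaisi and run one forward
# pass; keyword names follow the diff above, values are example placeholders.
import torch

from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
from monai.networks import eval_mode

net = AutoencoderKlMaisi(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    num_channels=(4, 4, 4),
    num_res_blocks=(1, 1, 1),
    attention_levels=(False, False, False),
    latent_channels=4,
    norm_num_groups=4,
    num_splits=2,
    dim_split=1,
)

with eval_mode(net):
    # AutoencoderKL-style forward: reconstruction plus latent mean and sigma.
    recon, z_mu, z_sigma = net(torch.randn(1, 1, 16, 16))
    print(recon.shape, z_mu.shape, z_sigma.shape)
```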
