From 5e3a3aab06b30becebfcaa048b320f12dc1ed610 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 16 May 2024 10:34:53 +0100 Subject: [PATCH 01/10] Addition of: - spade_network, and SPADENet (VAE-GAN) - test_spade_vaegan (to test previously mentioned model) Modification of: - spade_diffusion_model_unet.py to change namings. - SPADE normalisation layer, to use get_norm_layer function instead of defining such layers directly. --- monai/networks/blocks/spade_norm.py | 7 +- monai/networks/nets/__init__.py | 1 + .../nets/spade_diffusion_model_unet.py | 8 +- monai/networks/nets/spade_network.py | 409 ++++++++++++++++++ tests/test_spade_autoencoderkl.py | 1 + tests/test_spade_vaegan.py | 131 ++++++ 6 files changed, 549 insertions(+), 8 deletions(-) create mode 100644 monai/networks/nets/spade_network.py create mode 100644 tests/test_spade_vaegan.py diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py index 8e082defe0..343dfa9ec0 100644 --- a/monai/networks/blocks/spade_norm.py +++ b/monai/networks/blocks/spade_norm.py @@ -15,7 +15,8 @@ import torch.nn as nn import torch.nn.functional as F -from monai.networks.blocks import ADN, Convolution +from monai.networks.blocks import Convolution +from monai.networks.layers.utils import get_norm_layer class SPADE(nn.Module): @@ -50,9 +51,7 @@ def __init__( norm_params = {} if len(norm_params) != 0: norm = (norm, norm_params) - self.param_free_norm = ADN( - act=None, dropout=0.0, norm=norm, norm_dim=spatial_dims, ordering="N", in_channels=norm_nc - ) + self.param_free_norm = get_norm_layer(norm, spatial_dims=spatial_dims, channels=norm_nc) self.mlp_shared = Convolution( spatial_dims=spatial_dims, in_channels=label_nc, diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py index 9101ab862e..c777fe6442 100644 --- a/monai/networks/nets/__init__.py +++ b/monai/networks/nets/__init__.py @@ -110,6 +110,7 @@ ) from .spade_autoencoderkl import SPADEAutoencoderKL from .spade_diffusion_model_unet import SPADEDiffusionModelUNet +from .spade_network import SPADENet from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR from .torchvision_fc import TorchVisionFCModel from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py index e019d21c11..ffb7d3cd67 100644 --- a/monai/networks/nets/spade_diffusion_model_unet.py +++ b/monai/networks/nets/spade_diffusion_model_unet.py @@ -53,7 +53,7 @@ __all__ = ["SPADEDiffusionModelUNet"] -class SPADEResnetBlock(nn.Module): +class SPADEDiffResBlock(nn.Module): """ Residual block with timestep conditioning and SPADE norm. Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE) @@ -238,7 +238,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - SPADEResnetBlock( + SPADEDiffResBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -356,7 +356,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - SPADEResnetBlock( + SPADEDiffResBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, @@ -491,7 +491,7 @@ def __init__( resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - SPADEResnetBlock( + SPADEDiffResBlock( spatial_dims=spatial_dims, in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py new file mode 100644 index 0000000000..2cb9efa0e4 --- /dev/null +++ b/monai/networks/nets/spade_network.py @@ -0,0 +1,409 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Sequence + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from monai.networks.blocks import Convolution +from monai.networks.layers import Act +from monai.utils.enums import StrEnum + +from monai.networks.blocks.spade_norm import SPADE + + +class UpsamplingModes(StrEnum): + bicubic = "bicubic" + nearest = "nearest" + bilinear = "bilinear" + + +class SPADENetResBlock(nn.Module): + """ + Creates a Residual Block with SPADE normalisation. + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks + spade_intermediate_channels: number of intermediate channels in the middle conv. layers in SPADE normalisation blocks + norm: base normalisation type used on top of SPADE + kernel_size: convolutional kernel size + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + spade_intermediate_channels: int = 128, + norm: str | tuple = "INSTANCE", + kernel_size: int = 3, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.int_channels = min(self.in_channels, self.out_channels) + self.learned_shortcut = self.in_channels != self.out_channels + self.conv_0 = Convolution( + spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.int_channels, act=None, norm=None + ) + self.conv_1 = Convolution( + spatial_dims=spatial_dims, + in_channels=self.int_channels, + out_channels=self.out_channels, + act=None, + norm=None, + ) + self.activation = nn.LeakyReLU(0.2, False) + self.norm_0 = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + self.norm_1 = SPADE( + label_nc=label_nc, + norm_nc=self.int_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + + if self.learned_shortcut: + self.conv_s = Convolution( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + act=None, + norm=None, + kernel_size=1, + ) + self.norm_s = SPADE( + label_nc=label_nc, + norm_nc=self.in_channels, + kernel_size=kernel_size, + spatial_dims=spatial_dims, + hidden_channels=spade_intermediate_channels, + norm=norm, + ) + + def forward(self, x, seg): + x_s = self.shortcut(x, seg) + dx = self.conv_0(self.activation(self.norm_0(x, seg))) + dx = self.conv_1(self.activation(self.norm_1(dx, seg))) + out = x_s + dx + return out + + def shortcut(self, x, seg): + if self.learned_shortcut: + x_s = self.conv_s(self.norm_s(x, seg)) + else: + x_s = x + return x_s + + +class SPADEEncoder(nn.Module): + """ + Encoding branch of a VAE compatible with a SPADE-like generator + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + z_dim: latent space dimension of the VAE containing the image sytle information + channels: number of output after each downsampling block + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + of the autoencoder (HxWx[D]) + kernel_size: convolutional kernel size + norm: normalisation layer type + act: activation type + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + z_dim: int, + channels: Sequence[int], + input_shape: Sequence[int], + kernel_size: int = 3, + norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), + ): + super().__init__() + self.in_channels = in_channels + self.z_dim = z_dim + self.channels = channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)): + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. " % (s_ind, s_, len(channels)) + ) + self.input_shape = input_shape + self.latent_spatial_shape = [s_ // (2 ** len(self.channels)) for s_ in self.input_shape] + blocks = [] + ch_init = self.in_channels + for _, ch_value in enumerate(channels): + blocks.append( + Convolution( + spatial_dims=spatial_dims, + in_channels=ch_init, + out_channels=ch_value, + strides=2, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + ) + ch_init = ch_value + + self.blocks = nn.ModuleList(blocks) + self.fc_mu = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim + ) + self.fc_var = nn.Linear( + in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return mu, logvar + + def encode(self, x): + for block in self.blocks: + x = block(x) + x = x.view(x.size(0), -1) + mu = self.fc_mu(x) + logvar = self.fc_var(x) + return self.reparameterize(mu, logvar) + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return eps.mul(std) + mu + + +class SPADEDecoder(nn.Module): + """ + Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch, + behaving like a GAN, or coupled to a SPADE encoder. + + Args: + label_nc: number of semantic labels + spatial_dims: number of spatial dimensions + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_gan: whether the decoder is going to be coupled to an autoencoder or not (true: not, false: yes) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) + """ + + def __init__( + self, + spatial_dims: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + channels: Sequence[int], + z_dim: int | None = None, + is_gan: bool = False, + spade_intermediate_channels: int = 128, + norm: str | tuple = "INSTANCE", + last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): + super().__init__() + self.is_gan = is_gan + self.out_channels = out_channels + self.label_nc = label_nc + self.num_channels = channels + if len(input_shape) != spatial_dims: + raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape)) + for s_ind, s_ in enumerate(input_shape): + if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)): + raise ValueError( + "Each dimension of your input must be divisible by 2 ** (autoencoder depth)." + "The shape in position %d, %d is not divisible by %d. " % (s_ind, s_, len(channels)) + ) + self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape] + + if self.is_gan: + self.fc = nn.Linear(label_nc, np.prod(self.latent_spatial_shape) * channels[0]) + else: + self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * channels[0]) + + blocks = [] + channels.append(self.out_channels) + self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode) + for ch_ind, ch_value in enumerate(channels[:-1]): + blocks.append( + SPADENetResBlock( + spatial_dims=spatial_dims, + in_channels=ch_value, + out_channels=channels[ch_ind + 1], + label_nc=label_nc, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + kernel_size=kernel_size, + ) + ) + + self.blocks = torch.nn.ModuleList(blocks) + self.last_conv = Convolution( + spatial_dims=spatial_dims, + in_channels=channels[-1], + out_channels=out_channels, + padding=(kernel_size - 1) // 2, + kernel_size=kernel_size, + norm=None, + act=last_act, + ) + + def forward(self, seg, z: torch.Tensor = None): + if self.is_gan: + x = F.interpolate(seg, size=tuple(self.latent_spatial_shape)) + x = self.fc(x) + else: + if z is None: + z = torch.randn(seg.size(0), self.opt.z_dim, dtype=torch.float32, device=seg.get_device()) + x = self.fc(z) + x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape) + + for res_block in self.blocks: + x = res_block(x, seg) + x = self.upsampling(x) + + x = self.last_conv(x) + return x + + +class SPADENet(nn.Module): + + """ + SPADE Network, implemented based on the code by Park, T et al. in + "Semantic Image Synthesis with Spatially-Adaptive Normalization" + (https://github.com/NVlabs/SPADE) + + Args: + spatial_dims: number of spatial dimensions + in_channels: number of input channels + out_channels: number of output channels + label_nc: number of semantic channels used for the SPADE normalisation blocks + input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers + channels: number of output after each downsampling block + z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) + is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false) + spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks + norm: base normalisation type + act: activation layer type + last_act: activation layer type for the last layer of the network (can differ from previous) + kernel_size: convolutional kernel size + upsampling_mode: upsampling mode (nearest, bilinear etc.) + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + label_nc: int, + input_shape: Sequence[int], + channels: Sequence[int], + z_dim: int | None = None, + is_vae: bool = True, + spade_intermediate_channels: int = 128, + norm: str | tuple = "INSTANCE", + act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + kernel_size: int = 3, + upsampling_mode: str = UpsamplingModes.nearest.value, + ): + super().__init__() + self.is_vae = is_vae + if self.is_vae and z_dim is None: + ValueError("The latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.") + + self.in_channels = in_channels + self.out_channels = out_channels + self.channels = channels + self.label_nc = label_nc + self.input_shape = input_shape + + if self.is_vae: + self.encoder = SPADEEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + z_dim=z_dim, + channels=channels, + input_shape=input_shape, + kernel_size=kernel_size, + norm=norm, + act=act, + ) + + decoder_channels = channels + decoder_channels.reverse() + + self.decoder = SPADEDecoder( + spatial_dims=spatial_dims, + out_channels=out_channels, + label_nc=label_nc, + input_shape=input_shape, + channels=decoder_channels, + z_dim=z_dim, + is_gan=not is_vae, + spade_intermediate_channels=spade_intermediate_channels, + norm=norm, + act=act, + last_act=last_act, + kernel_size=kernel_size, + upsampling_mode=upsampling_mode, + ) + + def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None): + z = None + if self.is_vae: + z_mu, z_logvar = self.encoder(x) + z = self.encoder.reparameterize(z_mu, z_logvar) + kld_loss = self.kld_loss(z_mu, z_logvar) + return self.decoder(seg, z), kld_loss + else: + return (self.decoder(seg, z),) + + def encode(self, x: torch.Tensor): + return self.encoder.encode(x) + + def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None): + return self.decoder(seg, z) diff --git a/tests/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py index 9353ceedc2..e6395946f5 100644 --- a/tests/test_spade_autoencoderkl.py +++ b/tests/test_spade_autoencoderkl.py @@ -15,6 +15,7 @@ from unittest import skipUnless import torch + from parameterized import parameterized from monai.networks import eval_mode diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py new file mode 100644 index 0000000000..92bf29a501 --- /dev/null +++ b/tests/test_spade_vaegan.py @@ -0,0 +1,131 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from monai.networks import eval_mode +from parameterized import parameterized + +from monai.networks.nets import SPADENet + +CASE_2D = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] +CASE_2D_BIS = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] +CASE_3D = [[[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]]] + + +def create_semantic_data(shape: list, semantic_regions: int): + """ + To create semantic and image mock inputs for the network. + Args: + shape: input shape + semantic_regions: number of semantic regions + Returns: + """ + out_label = torch.zeros(shape) + out_image = torch.zeros(shape) + torch.randn(shape) * 0.01 + for i in range(1, semantic_regions): + shape_square = [i // np.random.choice(list(range(2, i // 2))) for i in shape] + start_point = [np.random.choice(list(range(shape[ind] - shape_square[ind]))) for ind, i in enumerate(shape)] + if len(shape) == 2: + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[ + start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1]) + ] = (base_intensity + torch.randn(shape_square) * 0.1) + elif len(shape) == 3: + out_label[ + start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = i + base_intensity = torch.ones(shape_square) * np.random.randn() + out_image[ + start_point[0] : (start_point[0] + shape_square[0]), + start_point[1] : (start_point[1] + shape_square[1]), + start_point[2] : (start_point[2] + shape_square[2]), + ] = (base_intensity + torch.randn(shape_square) * 0.1) + else: + ValueError("Supports only 2D and 3D tensors") + + # One hot encode label + out_label_ = torch.zeros([semantic_regions] + list(out_label.shape)) + for ch in range(semantic_regions): + out_label_[ch, ...] = out_label == ch + + return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0) + + +class TestDiffusionModelUNet2D(unittest.TestCase): + @parameterized.expand(CASE_2D) + def test_forward_2d(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADENet(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + out, kld = net(in_label, in_image) + self.assertEqual( + False, + True in torch.isnan(out) + or True in torch.isinf(out) + or True in torch.isinf(kld) + or True in torch.isinf(kld), + ) + self.assertEqual(list(out.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_2D_BIS) + def test_encoder_decoder(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADENet(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + out_z = net.encode(in_image) + self.assertEqual(list(out_z.shape), [1, 16]) + out_i = net.decode(in_label, out_z) + self.assertEqual(list(out_i.shape), [1, 1, 64, 64]) + + @parameterized.expand(CASE_3D) + def test_forward_3d(self, input_param): + """ + Check that forward method is called correctly and output shape matches. + """ + net = SPADENet(*input_param) + in_label, in_image = create_semantic_data(input_param[4], input_param[3]) + with eval_mode(net): + out, kld = net(in_label, in_image) + self.assertEqual( + False, + True in torch.isnan(out) + or True in torch.isinf(out) + or True in torch.isinf(kld) + or True in torch.isinf(kld), + ) + self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) + + def test_shape_wrong(self): + """ + We input an input shape that isn't divisible by 2**(n downstream steps) + """ + with self.assertRaises(ValueError): + _ = SPADENet(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True) + + +if __name__ == "__main__": + unittest.main() From a4547fac26ec18866948a6035a5f78b8f5c3e61b Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 16 May 2024 13:15:13 +0100 Subject: [PATCH 02/10] Autofix changes. --- monai/networks/nets/spade_network.py | 5 ++--- tests/test_spade_autoencoderkl.py | 1 - tests/test_spade_vaegan.py | 2 +- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py index 2cb9efa0e4..87e0469f59 100644 --- a/monai/networks/nets/spade_network.py +++ b/monai/networks/nets/spade_network.py @@ -17,12 +17,12 @@ import torch import torch.nn as nn import torch.nn.functional as F + from monai.networks.blocks import Convolution +from monai.networks.blocks.spade_norm import SPADE from monai.networks.layers import Act from monai.utils.enums import StrEnum -from monai.networks.blocks.spade_norm import SPADE - class UpsamplingModes(StrEnum): bicubic = "bicubic" @@ -310,7 +310,6 @@ def forward(self, seg, z: torch.Tensor = None): class SPADENet(nn.Module): - """ SPADE Network, implemented based on the code by Park, T et al. in "Semantic Image Synthesis with Spatially-Adaptive Normalization" diff --git a/tests/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py index e6395946f5..9353ceedc2 100644 --- a/tests/test_spade_autoencoderkl.py +++ b/tests/test_spade_autoencoderkl.py @@ -15,7 +15,6 @@ from unittest import skipUnless import torch - from parameterized import parameterized from monai.networks import eval_mode diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py index 92bf29a501..8929043a08 100644 --- a/tests/test_spade_vaegan.py +++ b/tests/test_spade_vaegan.py @@ -15,9 +15,9 @@ import numpy as np import torch -from monai.networks import eval_mode from parameterized import parameterized +from monai.networks import eval_mode from monai.networks.nets import SPADENet CASE_2D = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] From c5834deaf91d6bb99d8a14e5976f6622018d97c6 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 16 May 2024 20:40:58 +0100 Subject: [PATCH 03/10] Addition of option for activation layer in the SPADE Decoder. It was, before, an argument, but it wasn't being used, set instead to LeakyRelu. Now, activation is a parameter that is passed. It can't be None. In addition, removal of the KLD loss object. Instead, if network is VAE, it outputs the mu and log variance in the forward so that the KLD can be calculated externally. --- monai/networks/nets/spade_network.py | 11 +++++++---- tests/test_spade_vaegan.py | 16 ++++++++++------ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py index 87e0469f59..9538d0f7c0 100644 --- a/monai/networks/nets/spade_network.py +++ b/monai/networks/nets/spade_network.py @@ -22,6 +22,7 @@ from monai.networks.blocks.spade_norm import SPADE from monai.networks.layers import Act from monai.utils.enums import StrEnum +from monai.networks.layers.utils import get_act_layer class UpsamplingModes(StrEnum): @@ -52,6 +53,7 @@ def __init__( label_nc: int, spade_intermediate_channels: int = 128, norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), kernel_size: int = 3, ): super().__init__() @@ -69,7 +71,7 @@ def __init__( act=None, norm=None, ) - self.activation = nn.LeakyReLU(0.2, False) + self.activation = get_act_layer(act) self.norm_0 = SPADE( label_nc=label_nc, norm_nc=self.in_channels, @@ -240,6 +242,7 @@ def __init__( is_gan: bool = False, spade_intermediate_channels: int = 128, norm: str | tuple = "INSTANCE", + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), kernel_size: int = 3, upsampling_mode: str = UpsamplingModes.nearest.value, @@ -277,6 +280,7 @@ def __init__( spade_intermediate_channels=spade_intermediate_channels, norm=norm, kernel_size=kernel_size, + act = act ) ) @@ -344,7 +348,7 @@ def __init__( is_vae: bool = True, spade_intermediate_channels: int = 128, norm: str | tuple = "INSTANCE", - act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), + act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}), kernel_size: int = 3, upsampling_mode: str = UpsamplingModes.nearest.value, @@ -396,8 +400,7 @@ def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None): if self.is_vae: z_mu, z_logvar = self.encoder(x) z = self.encoder.reparameterize(z_mu, z_logvar) - kld_loss = self.kld_loss(z_mu, z_logvar) - return self.decoder(seg, z), kld_loss + return self.decoder(seg, z), z_mu, z_logvar else: return (self.decoder(seg, z),) diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py index 8929043a08..56c057cda3 100644 --- a/tests/test_spade_vaegan.py +++ b/tests/test_spade_vaegan.py @@ -78,13 +78,15 @@ def test_forward_2d(self, input_param): net = SPADENet(*input_param) in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): - out, kld = net(in_label, in_image) + out, z_mu, z_logvar = net(in_label, in_image) self.assertEqual( False, True in torch.isnan(out) or True in torch.isinf(out) - or True in torch.isinf(kld) - or True in torch.isinf(kld), + or True in torch.isinf(z_mu) + or True in torch.isnan(z_mu) + or True in torch.isinf(z_logvar) + or True in torch.isnan(z_logvar), ) self.assertEqual(list(out.shape), [1, 1, 64, 64]) @@ -109,13 +111,15 @@ def test_forward_3d(self, input_param): net = SPADENet(*input_param) in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): - out, kld = net(in_label, in_image) + out, z_mu, z_logvar = net(in_label, in_image) self.assertEqual( False, True in torch.isnan(out) or True in torch.isinf(out) - or True in torch.isinf(kld) - or True in torch.isinf(kld), + or True in torch.isinf(z_mu) + or True in torch.isnan(z_mu) + or True in torch.isinf(z_logvar) + or True in torch.isnan(z_logvar), ) self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) From 02613868c4b0c59e027dd82e2c0a2cbb119b61f8 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Thu, 16 May 2024 20:42:01 +0100 Subject: [PATCH 04/10] Ran autofix! --- monai/networks/nets/spade_network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py index 9538d0f7c0..018fc92923 100644 --- a/monai/networks/nets/spade_network.py +++ b/monai/networks/nets/spade_network.py @@ -21,8 +21,8 @@ from monai.networks.blocks import Convolution from monai.networks.blocks.spade_norm import SPADE from monai.networks.layers import Act -from monai.utils.enums import StrEnum from monai.networks.layers.utils import get_act_layer +from monai.utils.enums import StrEnum class UpsamplingModes(StrEnum): @@ -280,7 +280,7 @@ def __init__( spade_intermediate_channels=spade_intermediate_channels, norm=norm, kernel_size=kernel_size, - act = act + act=act, ) ) From e5db55b939b0246ee9e66452326e17de44e047f4 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 17 May 2024 14:42:47 +0100 Subject: [PATCH 05/10] mypy fixes Signed-off-by: Mark Graham --- monai/networks/nets/spade_network.py | 35 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py index 018fc92923..8f708fe90b 100644 --- a/monai/networks/nets/spade_network.py +++ b/monai/networks/nets/spade_network.py @@ -237,8 +237,8 @@ def __init__( out_channels: int, label_nc: int, input_shape: Sequence[int], - channels: Sequence[int], - z_dim: int | None = None, + channels: list[int], + z_dim: int | None = None, is_gan: bool = False, spade_intermediate_channels: int = 128, norm: str | tuple = "INSTANCE", @@ -265,6 +265,7 @@ def __init__( if self.is_gan: self.fc = nn.Linear(label_nc, np.prod(self.latent_spatial_shape) * channels[0]) else: + assert z_dim is not None self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * channels[0]) blocks = [] @@ -295,7 +296,7 @@ def __init__( act=last_act, ) - def forward(self, seg, z: torch.Tensor = None): + def forward(self, seg, z: torch.Tensor | None = None): if self.is_gan: x = F.interpolate(seg, size=tuple(self.latent_spatial_shape)) x = self.fc(x) @@ -343,7 +344,7 @@ def __init__( out_channels: int, label_nc: int, input_shape: Sequence[int], - channels: Sequence[int], + channels: list[int], z_dim: int | None = None, is_vae: bool = True, spade_intermediate_channels: int = 128, @@ -355,9 +356,6 @@ def __init__( ): super().__init__() self.is_vae = is_vae - if self.is_vae and z_dim is None: - ValueError("The latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.") - self.in_channels = in_channels self.out_channels = out_channels self.channels = channels @@ -365,16 +363,19 @@ def __init__( self.input_shape = input_shape if self.is_vae: - self.encoder = SPADEEncoder( - spatial_dims=spatial_dims, - in_channels=in_channels, - z_dim=z_dim, - channels=channels, - input_shape=input_shape, - kernel_size=kernel_size, - norm=norm, - act=act, - ) + if z_dim is None: + ValueError("The latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.") + else: + self.encoder = SPADEEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + z_dim=z_dim, + channels=channels, + input_shape=input_shape, + kernel_size=kernel_size, + norm=norm, + act=act, + ) decoder_channels = channels decoder_channels.reverse() From 0fdf990be36d57962e01e7db5c8fd87066beb186 Mon Sep 17 00:00:00 2001 From: Mark Graham Date: Fri, 17 May 2024 14:52:30 +0100 Subject: [PATCH 06/10] formatting Signed-off-by: Mark Graham --- monai/networks/nets/spade_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py index 8f708fe90b..d16ef0c96d 100644 --- a/monai/networks/nets/spade_network.py +++ b/monai/networks/nets/spade_network.py @@ -238,7 +238,7 @@ def __init__( label_nc: int, input_shape: Sequence[int], channels: list[int], - z_dim: int | None = None, + z_dim: int | None = None, is_gan: bool = False, spade_intermediate_channels: int = 128, norm: str | tuple = "INSTANCE", From 9ada8d2502ef0db06a486f4a724fedc10745a7a1 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Fri, 17 May 2024 15:27:17 +0100 Subject: [PATCH 07/10] DCO Remediation Commit for virginiafdez I, virginiafdez , hereby add my Signed-off-by to this commit: 5e3a3aab06b30becebfcaa048b320f12dc1ed610 I, virginiafdez , hereby add my Signed-off-by to this commit: a4547fac26ec18866948a6035a5f78b8f5c3e61b I, virginiafdez , hereby add my Signed-off-by to this commit: c5834deaf91d6bb99d8a14e5976f6622018d97c6 I, virginiafdez , hereby add my Signed-off-by to this commit: 02613868c4b0c59e027dd82e2c0a2cbb119b61f8 Signed-off-by: virginiafdez --- tests/test_spade_vaegan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py index 56c057cda3..56d4f54d6d 100644 --- a/tests/test_spade_vaegan.py +++ b/tests/test_spade_vaegan.py @@ -30,7 +30,7 @@ def create_semantic_data(shape: list, semantic_regions: int): To create semantic and image mock inputs for the network. Args: shape: input shape - semantic_regions: number of semantic regions + semantic_regions: number of semantic region Returns: """ out_label = torch.zeros(shape) From 77a042cd23bf5ab94a6733c5c22163dada67cb81 Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Mon, 20 May 2024 11:29:18 +0100 Subject: [PATCH 08/10] Modification of test_spade_vaegan.py as per review (change of asserts). Signed-off-by: virginiafdez --- tests/test_spade_vaegan.py | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py index 56d4f54d6d..dae40510e5 100644 --- a/tests/test_spade_vaegan.py +++ b/tests/test_spade_vaegan.py @@ -79,15 +79,9 @@ def test_forward_2d(self, input_param): in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): out, z_mu, z_logvar = net(in_label, in_image) - self.assertEqual( - False, - True in torch.isnan(out) - or True in torch.isinf(out) - or True in torch.isinf(z_mu) - or True in torch.isnan(z_mu) - or True in torch.isinf(z_logvar) - or True in torch.isnan(z_logvar), - ) + self.assertTrue(torch.all(torch.isfinite(out))) + self.assertTrue(torch.all(torch.isfinite(z_mu))) + self.assertTrue(torch.all(torch.isfinite(z_logvar))) self.assertEqual(list(out.shape), [1, 1, 64, 64]) @parameterized.expand(CASE_2D_BIS) @@ -112,15 +106,9 @@ def test_forward_3d(self, input_param): in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): out, z_mu, z_logvar = net(in_label, in_image) - self.assertEqual( - False, - True in torch.isnan(out) - or True in torch.isinf(out) - or True in torch.isinf(z_mu) - or True in torch.isnan(z_mu) - or True in torch.isinf(z_logvar) - or True in torch.isnan(z_logvar), - ) + self.assertTrue(torch.all(torch.isfinite(out))) + self.assertTrue(torch.all(torch.isfinite(z_mu))) + self.assertTrue(torch.all(torch.isfinite(z_logvar))) self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) def test_shape_wrong(self): From d69023cefa31fc43f881415dfad2fde4f2508d5f Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Mon, 27 May 2024 16:45:28 +0100 Subject: [PATCH 09/10] - Modification of the spade_network to include suggested changes: more informative errors when z_dim is None and network is not GAN, addition of docstrings in forward method and __all__. - Modification of flag is_gan in the Decoder, to change it to opposite flag is_vae, to make it match with the flag of the SPADE network. - Modification of functionality on NOT is_vae mode. Initially, a random Gaussian noise vector was drawn and passed to an input Linear layer. Nonetheless, the original SPADE code starts from an interpolated version of the semantic map (deterministically) and passes it to a conv layer. self.fc is changed to a Conv layer when is_vae is False. Because mypy does not allow for self.fc to be Linear or Convolution under different attribute values, self.fc > self.conv_init when is_vae is False. - Modification of tests to incorporate suggestions: name of the tests were unsuitable, there were missing scenarios for when is_vae = False. Signed-off-by: virginiafdez --- monai/networks/nets/spade_network.py | 47 +++++++++++++++++++++------- tests/test_spade_vaegan.py | 41 +++++++++++++++++------- 2 files changed, 64 insertions(+), 24 deletions(-) diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py index d16ef0c96d..85a2b3bb47 100644 --- a/monai/networks/nets/spade_network.py +++ b/monai/networks/nets/spade_network.py @@ -24,6 +24,8 @@ from monai.networks.layers.utils import get_act_layer from monai.utils.enums import StrEnum +__all__ = ["SPADENet"] + class UpsamplingModes(StrEnum): bicubic = "bicubic" @@ -222,7 +224,7 @@ class SPADEDecoder(nn.Module): input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers channels: number of output after each downsampling block z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used) - is_gan: whether the decoder is going to be coupled to an autoencoder or not (true: not, false: yes) + is_vae: whether the decoder is going to be coupled to an autoencoder or not (true: yes, false: no) spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks norm: base normalisation type act: activation layer type @@ -239,7 +241,7 @@ def __init__( input_shape: Sequence[int], channels: list[int], z_dim: int | None = None, - is_gan: bool = False, + is_vae: bool = True, spade_intermediate_channels: int = 128, norm: str | tuple = "INSTANCE", act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}), @@ -248,7 +250,7 @@ def __init__( upsampling_mode: str = UpsamplingModes.nearest.value, ): super().__init__() - self.is_gan = is_gan + self.is_vae = is_vae self.out_channels = out_channels self.label_nc = label_nc self.num_channels = channels @@ -262,12 +264,19 @@ def __init__( ) self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape] - if self.is_gan: - self.fc = nn.Linear(label_nc, np.prod(self.latent_spatial_shape) * channels[0]) + if not self.is_vae: + self.conv_init = Convolution( + spatial_dims=spatial_dims, in_channels=label_nc, out_channels=channels[0], kernel_size=kernel_size + ) + elif self.is_vae and z_dim is None: + raise ValueError( + "If the network is used in VAE-GAN mode, parameter z_dim " + "(number of latent channels in the VAE) must be populated." + ) else: - assert z_dim is not None self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * channels[0]) + self.z_dim = z_dim blocks = [] channels.append(self.out_channels) self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode) @@ -297,12 +306,23 @@ def __init__( ) def forward(self, seg, z: torch.Tensor | None = None): - if self.is_gan: + """ + Args: + seg: input BxCxHxW[xD] semantic map on which the output is conditioned on + z: latent vector output by the encoder if self.is_vae is True. When is_vae is + False, z is a random noise vector. + + Returns: + + """ + if not self.is_vae: x = F.interpolate(seg, size=tuple(self.latent_spatial_shape)) - x = self.fc(x) + x = self.conv_init(x) else: - if z is None: - z = torch.randn(seg.size(0), self.opt.z_dim, dtype=torch.float32, device=seg.get_device()) + if ( + z is None and self.z_dim is not None + ): # is_vae is Truee, but we can use the VAE-GAN as GAN in this function. + z = torch.randn(seg.size(0), self.z_dim, dtype=torch.float32, device=seg.get_device()) x = self.fc(z) x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape) @@ -387,7 +407,7 @@ def __init__( input_shape=input_shape, channels=decoder_channels, z_dim=z_dim, - is_gan=not is_vae, + is_vae=is_vae, spade_intermediate_channels=spade_intermediate_channels, norm=norm, act=act, @@ -406,7 +426,10 @@ def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None): return (self.decoder(seg, z),) def encode(self, x: torch.Tensor): - return self.encoder.encode(x) + if self.is_vae: + return self.encoder.encode(x) + else: + return None def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None): return self.decoder(seg, z) diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py index dae40510e5..3fdb9b74cb 100644 --- a/tests/test_spade_vaegan.py +++ b/tests/test_spade_vaegan.py @@ -20,9 +20,14 @@ from monai.networks import eval_mode from monai.networks.nets import SPADENet -CASE_2D = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] -CASE_2D_BIS = [[[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]]] -CASE_3D = [[[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]]] +CASE_2D = [ + [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]], + [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], None, False]], +] +CASE_3D = [ + [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]], + [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], None, False]], +] def create_semantic_data(shape: list, semantic_regions: int): @@ -69,7 +74,7 @@ def create_semantic_data(shape: list, semantic_regions: int): return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0) -class TestDiffusionModelUNet2D(unittest.TestCase): +class TestSpadeNet(unittest.TestCase): @parameterized.expand(CASE_2D) def test_forward_2d(self, input_param): """ @@ -78,13 +83,18 @@ def test_forward_2d(self, input_param): net = SPADENet(*input_param) in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): - out, z_mu, z_logvar = net(in_label, in_image) + if not net.is_vae: + out = net(in_label, in_image) + out = out[0] + else: + out, z_mu, z_logvar = net(in_label, in_image) + self.assertTrue(torch.all(torch.isfinite(z_mu))) + self.assertTrue(torch.all(torch.isfinite(z_logvar))) + self.assertTrue(torch.all(torch.isfinite(out))) - self.assertTrue(torch.all(torch.isfinite(z_mu))) - self.assertTrue(torch.all(torch.isfinite(z_logvar))) self.assertEqual(list(out.shape), [1, 1, 64, 64]) - @parameterized.expand(CASE_2D_BIS) + @parameterized.expand(CASE_2D) def test_encoder_decoder(self, input_param): """ Check that forward method is called correctly and output shape matches. @@ -93,7 +103,10 @@ def test_encoder_decoder(self, input_param): in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): out_z = net.encode(in_image) - self.assertEqual(list(out_z.shape), [1, 16]) + if net.is_vae: + self.assertEqual(list(out_z.shape), [1, 16]) + else: + self.assertEqual(out_z, None) out_i = net.decode(in_label, out_z) self.assertEqual(list(out_i.shape), [1, 1, 64, 64]) @@ -105,10 +118,14 @@ def test_forward_3d(self, input_param): net = SPADENet(*input_param) in_label, in_image = create_semantic_data(input_param[4], input_param[3]) with eval_mode(net): - out, z_mu, z_logvar = net(in_label, in_image) + if net.is_vae: + out, z_mu, z_logvar = net(in_label, in_image) + self.assertTrue(torch.all(torch.isfinite(z_mu))) + self.assertTrue(torch.all(torch.isfinite(z_logvar))) + else: + out = net(in_label, in_image) + out = out[0] self.assertTrue(torch.all(torch.isfinite(out))) - self.assertTrue(torch.all(torch.isfinite(z_mu))) - self.assertTrue(torch.all(torch.isfinite(z_logvar))) self.assertEqual(list(out.shape), [1, 1, 64, 64, 64]) def test_shape_wrong(self): From e6b56c6e7b431ed802e4fcf1bb7d52788566ccbc Mon Sep 17 00:00:00 2001 From: virginiafdez Date: Sun, 2 Jun 2024 20:52:50 +0100 Subject: [PATCH 10/10] Clarification of functionality in SPADE network decoder. Signed-off-by: virginiafdez --- monai/networks/nets/spade_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py index 85a2b3bb47..9164541f27 100644 --- a/monai/networks/nets/spade_network.py +++ b/monai/networks/nets/spade_network.py @@ -321,7 +321,7 @@ def forward(self, seg, z: torch.Tensor | None = None): else: if ( z is None and self.z_dim is not None - ): # is_vae is Truee, but we can use the VAE-GAN as GAN in this function. + ): # Even though this network is a VAE (self.is_vae), you should be able to sample from noise as well. z = torch.randn(seg.size(0), self.z_dim, dtype=torch.float32, device=seg.get_device()) x = self.fc(z) x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape)