
Commit 77304dd

Authored by: heyufan1995, pre-commit-ci[bot], yiheng-wang-nv, KumoLiu
Add vista network (#7987)
Fixes # .

### Description

Add VISTA3D model architecture to MONAI core

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: heyufan1995 <[email protected]>
Signed-off-by: Yufan He <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Signed-off-by: Yiheng Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: Yiheng Wang <[email protected]>
Co-authored-by: YunLiu <[email protected]>
1 parent e85580a commit 77304dd

File tree: 7 files changed, +1189 additions, −35 deletions


docs/source/networks.rst

Lines changed: 10 additions & 0 deletions
@@ -481,6 +481,11 @@ Nets
 .. autoclass:: SegResNetDS
     :members:
 
+`SegResNetDS2`
+~~~~~~~~~~~~~~
+.. autoclass:: SegResNetDS2
+    :members:
+
 `SegResNetVAE`
 ~~~~~~~~~~~~~~
 .. autoclass:: SegResNetVAE
@@ -556,6 +561,11 @@ Nets
 .. autoclass:: UNETR
     :members:
 
+`VISTA3D`
+~~~~~~~~~
+.. autoclass:: VISTA3D
+    :members:
+
 `SwinUNETR`
 ~~~~~~~~~~~
 .. autoclass:: SwinUNETR

monai/networks/nets/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -76,7 +76,7 @@
     resnet200,
 )
 from .segresnet import SegResNet, SegResNetVAE
-from .segresnet_ds import SegResNetDS
+from .segresnet_ds import SegResNetDS, SegResNetDS2
 from .senet import (
     SENet,
     SEnet,
@@ -118,6 +118,7 @@
 from .unet import UNet, Unet
 from .unetr import UNETR
 from .varautoencoder import VarAutoEncoder
+from .vista3d import VISTA3D, vista3d132
 from .vit import ViT
 from .vitautoenc import ViTAutoEnc
 from .vnet import VNet
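
With the two export lines above, both the new encoder and the VISTA3D model become importable from the public `monai.networks.nets` namespace. A minimal import sketch follows; the `vista3d132` factory lives in the new `vista3d` module added by this commit but not shown in this excerpt, so calling it with no arguments assumes its defaults are usable as-is:

from monai.networks.nets import VISTA3D, SegResNetDS2, vista3d132

# vista3d132 is the convenience constructor exported alongside the VISTA3D class.
# Its signature is defined in monai/networks/nets/vista3d.py (not part of this excerpt),
# so relying on its defaults here is an assumption.
model = vista3d132()
print(type(model).__name__)  # expected: VISTA3D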

monai/networks/nets/segresnet_ds.py

Lines changed: 127 additions & 1 deletion
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+import copy
 from collections.abc import Callable
 from typing import Union
 
@@ -23,7 +24,7 @@
 from monai.networks.layers.utils import get_act_layer, get_norm_layer
 from monai.utils import UpsampleMode, has_option
 
-__all__ = ["SegResNetDS"]
+__all__ = ["SegResNetDS", "SegResNetDS2"]
 
 
 def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None):
@@ -425,3 +426,128 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
 
     def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
         return self._forward(x)
+
+
+class SegResNetDS2(SegResNetDS):
+    """
+    SegResNetDS2 adds an additional decoder branch to SegResNetDS and is the image encoder of `VISTA3D
+    <https://arxiv.org/abs/2406.05285>`_.
+
+    Args:
+        spatial_dims: spatial dimension of the input data. Defaults to 3.
+        init_filters: number of output channels for initial convolution layer. Defaults to 32.
+        in_channels: number of input channels for the network. Defaults to 1.
+        out_channels: number of output channels for the network. Defaults to 2.
+        act: activation type and arguments. Defaults to ``RELU``.
+        norm: feature normalization type and arguments. Defaults to ``BATCH``.
+        blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
+        blocks_up: number of upsample blocks (optional).
+        dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
+            At dsdepth==1, only a single output is returned.
+        preprocess: optional callable function to apply before the model's forward pass.
+        resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring
+            image spacing into an approximately isotropic space.
+            Otherwise, by default, the kernel size and downsampling are always isotropic.
+
+    """
+
+    def __init__(
+        self,
+        spatial_dims: int = 3,
+        init_filters: int = 32,
+        in_channels: int = 1,
+        out_channels: int = 2,
+        act: tuple | str = "relu",
+        norm: tuple | str = "batch",
+        blocks_down: tuple = (1, 2, 2, 4),
+        blocks_up: tuple | None = None,
+        dsdepth: int = 1,
+        preprocess: nn.Module | Callable | None = None,
+        upsample_mode: UpsampleMode | str = "deconv",
+        resolution: tuple | None = None,
+    ):
+        super().__init__(
+            spatial_dims=spatial_dims,
+            init_filters=init_filters,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            act=act,
+            norm=norm,
+            blocks_down=blocks_down,
+            blocks_up=blocks_up,
+            dsdepth=dsdepth,
+            preprocess=preprocess,
+            upsample_mode=upsample_mode,
+            resolution=resolution,
+        )
+
+        self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers])
+
+    def forward(  # type: ignore
+        self, x: torch.Tensor, with_point: bool = True, with_label: bool = True
+    ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]:
+        """
+        Args:
+            x: input tensor.
+            with_point: if true, return the point branch output.
+            with_label: if true, return the label branch output.
+        """
+        if self.preprocess is not None:
+            x = self.preprocess(x)
+
+        if not self.is_valid_shape(x):
+            raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}")
+
+        x_down = self.encoder(x)
+
+        x_down.reverse()
+        x = x_down.pop(0)
+
+        if len(x_down) == 0:
+            x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]
+
+        outputs: list[torch.Tensor] = []
+        outputs_auto: list[torch.Tensor] = []
+        x_ = x.clone()
+        if with_point:
+            i = 0
+            for level in self.up_layers:
+                x = level["upsample"](x)
+                x = x + x_down[i]
+                x = level["blocks"](x)
+
+                if len(self.up_layers) - i <= self.dsdepth:
+                    outputs.append(level["head"](x))
+                i = i + 1
+
+            outputs.reverse()
+        x = x_
+        if with_label:
+            i = 0
+            for level in self.up_layers_auto:
+                x = level["upsample"](x)
+                x = x + x_down[i]
+                x = level["blocks"](x)
+
+                if len(self.up_layers) - i <= self.dsdepth:
+                    outputs_auto.append(level["head"](x))
+                i = i + 1
+
+            outputs_auto.reverse()
+
+        return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto
+
+    def set_auto_grad(self, auto_freeze=False, point_freeze=False):
+        """
+        Args:
+            auto_freeze: if true, freeze the image encoder and the auto-branch.
+            point_freeze: if true, freeze the image encoder and the point-branch.
+        """
+        for param in self.encoder.parameters():
+            param.requires_grad = (not auto_freeze) and (not point_freeze)
+
+        for param in self.up_layers_auto.parameters():
+            param.requires_grad = not auto_freeze
+
+        for param in self.up_layers.parameters():
+            param.requires_grad = not point_freeze
0 commit comments

Comments
 (0)