|
11 | 11 |
|
12 | 12 | from __future__ import annotations
|
13 | 13 |
|
| 14 | +import copy |
14 | 15 | from collections.abc import Callable
|
15 | 16 | from typing import Union
|
16 | 17 |
|
|
23 | 24 | from monai.networks.layers.utils import get_act_layer, get_norm_layer
|
24 | 25 | from monai.utils import UpsampleMode, has_option
|
25 | 26 |
|
26 |
| -__all__ = ["SegResNetDS"] |
| 27 | +__all__ = ["SegResNetDS", "SegResNetDS2"] |
27 | 28 |
|
28 | 29 |
|
29 | 30 | def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None):
|
@@ -425,3 +426,128 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens
|
425 | 426 |
|
426 | 427 | def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
|
427 | 428 | return self._forward(x)
|
| 429 | + |
| 430 | + |
| 431 | +class SegResNetDS2(SegResNetDS): |
| 432 | + """ |
| 433 | + SegResNetDS2 adds an additional decorder branch to SegResNetDS and is the image encoder of VISTA3D |
| 434 | + <https://arxiv.org/abs/2406.05285>`_. |
| 435 | +
|
| 436 | + Args: |
| 437 | + spatial_dims: spatial dimension of the input data. Defaults to 3. |
| 438 | + init_filters: number of output channels for initial convolution layer. Defaults to 32. |
| 439 | + in_channels: number of input channels for the network. Defaults to 1. |
| 440 | + out_channels: number of output channels for the network. Defaults to 2. |
| 441 | + act: activation type and arguments. Defaults to ``RELU``. |
| 442 | + norm: feature normalization type and arguments. Defaults to ``BATCH``. |
| 443 | + blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``. |
| 444 | + blocks_up: number of upsample blocks (optional). |
| 445 | + dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level. |
| 446 | + At dsdepth==1,only a single output is returned. |
| 447 | + preprocess: optional callable function to apply before the model's forward pass |
| 448 | + resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring |
| 449 | + image spacing into an approximately isotropic space. |
| 450 | + Otherwise, by default, the kernel size and downsampling is always isotropic. |
| 451 | +
|
| 452 | + """ |
| 453 | + |
| 454 | + def __init__( |
| 455 | + self, |
| 456 | + spatial_dims: int = 3, |
| 457 | + init_filters: int = 32, |
| 458 | + in_channels: int = 1, |
| 459 | + out_channels: int = 2, |
| 460 | + act: tuple | str = "relu", |
| 461 | + norm: tuple | str = "batch", |
| 462 | + blocks_down: tuple = (1, 2, 2, 4), |
| 463 | + blocks_up: tuple | None = None, |
| 464 | + dsdepth: int = 1, |
| 465 | + preprocess: nn.Module | Callable | None = None, |
| 466 | + upsample_mode: UpsampleMode | str = "deconv", |
| 467 | + resolution: tuple | None = None, |
| 468 | + ): |
| 469 | + super().__init__( |
| 470 | + spatial_dims=spatial_dims, |
| 471 | + init_filters=init_filters, |
| 472 | + in_channels=in_channels, |
| 473 | + out_channels=out_channels, |
| 474 | + act=act, |
| 475 | + norm=norm, |
| 476 | + blocks_down=blocks_down, |
| 477 | + blocks_up=blocks_up, |
| 478 | + dsdepth=dsdepth, |
| 479 | + preprocess=preprocess, |
| 480 | + upsample_mode=upsample_mode, |
| 481 | + resolution=resolution, |
| 482 | + ) |
| 483 | + |
| 484 | + self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers]) |
| 485 | + |
| 486 | + def forward( # type: ignore |
| 487 | + self, x: torch.Tensor, with_point: bool = True, with_label: bool = True |
| 488 | + ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]: |
| 489 | + """ |
| 490 | + Args: |
| 491 | + x: input tensor. |
| 492 | + with_point: if true, return the point branch output. |
| 493 | + with_label: if true, return the label branch output. |
| 494 | + """ |
| 495 | + if self.preprocess is not None: |
| 496 | + x = self.preprocess(x) |
| 497 | + |
| 498 | + if not self.is_valid_shape(x): |
| 499 | + raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}") |
| 500 | + |
| 501 | + x_down = self.encoder(x) |
| 502 | + |
| 503 | + x_down.reverse() |
| 504 | + x = x_down.pop(0) |
| 505 | + |
| 506 | + if len(x_down) == 0: |
| 507 | + x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)] |
| 508 | + |
| 509 | + outputs: list[torch.Tensor] = [] |
| 510 | + outputs_auto: list[torch.Tensor] = [] |
| 511 | + x_ = x.clone() |
| 512 | + if with_point: |
| 513 | + i = 0 |
| 514 | + for level in self.up_layers: |
| 515 | + x = level["upsample"](x) |
| 516 | + x = x + x_down[i] |
| 517 | + x = level["blocks"](x) |
| 518 | + |
| 519 | + if len(self.up_layers) - i <= self.dsdepth: |
| 520 | + outputs.append(level["head"](x)) |
| 521 | + i = i + 1 |
| 522 | + |
| 523 | + outputs.reverse() |
| 524 | + x = x_ |
| 525 | + if with_label: |
| 526 | + i = 0 |
| 527 | + for level in self.up_layers_auto: |
| 528 | + x = level["upsample"](x) |
| 529 | + x = x + x_down[i] |
| 530 | + x = level["blocks"](x) |
| 531 | + |
| 532 | + if len(self.up_layers) - i <= self.dsdepth: |
| 533 | + outputs_auto.append(level["head"](x)) |
| 534 | + i = i + 1 |
| 535 | + |
| 536 | + outputs_auto.reverse() |
| 537 | + |
| 538 | + return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto |
| 539 | + |
| 540 | + def set_auto_grad(self, auto_freeze=False, point_freeze=False): |
| 541 | + """ |
| 542 | + Args: |
| 543 | + auto_freeze: if true, freeze the image encoder and the auto-branch. |
| 544 | + point_freeze: if true, freeze the image encoder and the point-branch. |
| 545 | + """ |
| 546 | + for param in self.encoder.parameters(): |
| 547 | + param.requires_grad = (not auto_freeze) and (not point_freeze) |
| 548 | + |
| 549 | + for param in self.up_layers_auto.parameters(): |
| 550 | + param.requires_grad = not auto_freeze |
| 551 | + |
| 552 | + for param in self.up_layers.parameters(): |
| 553 | + param.requires_grad = not point_freeze |
0 commit comments