Skip to content

Commit a25b733

Browse files
authored
DataStats, LabelToMask, Lambda, RandLambda, SqueezeDim, is_module_ver_at_least (#2859)
* DataStats, LabelToMask, Lambda, RandLambda, SqueezeDim, is_module_ver_at_least Signed-off-by: Richard Brown <[email protected]>
1 parent 895592e commit a25b733

19 files changed

+332
-218
lines changed

monai/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -518,4 +518,4 @@
518518
weighted_patch_samples,
519519
zero_margins,
520520
)
521-
from .utils_pytorch_numpy_unification import moveaxis
521+
from .utils_pytorch_numpy_unification import in1d, moveaxis

monai/transforms/utility/array.py

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import numpy as np
2323
import torch
2424

25-
from monai.config import DtypeLike, NdarrayTensor
25+
from monai.config import DtypeLike
2626
from monai.config.type_definitions import NdarrayOrTensor
2727
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
2828
from monai.transforms.utils import (
@@ -31,9 +31,10 @@
3131
map_binary_to_indices,
3232
map_classes_to_indices,
3333
)
34-
from monai.transforms.utils_pytorch_numpy_unification import moveaxis
34+
from monai.transforms.utils_pytorch_numpy_unification import in1d, moveaxis
3535
from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import
3636
from monai.utils.enums import TransformBackends
37+
from monai.utils.misc import is_module_ver_at_least
3738
from monai.utils.type_conversion import convert_data_type
3839

3940
PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
@@ -445,6 +446,8 @@ class SqueezeDim(Transform):
445446
Squeeze a unitary dimension.
446447
"""
447448

449+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
450+
448451
def __init__(self, dim: Optional[int] = 0) -> None:
449452
"""
450453
Args:
@@ -459,12 +462,17 @@ def __init__(self, dim: Optional[int] = 0) -> None:
459462
raise TypeError(f"dim must be None or a int but is {type(dim).__name__}.")
460463
self.dim = dim
461464

462-
def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
465+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
463466
"""
464467
Args:
465468
img: numpy arrays with required dimension `dim` removed
466469
"""
467-
return img.squeeze(self.dim) # type: ignore
470+
if self.dim is None:
471+
return img.squeeze()
472+
# for pytorch/numpy unification
473+
if img.shape[self.dim] != 1:
474+
raise ValueError("Can only squeeze singleton dimension")
475+
return img.squeeze(self.dim)
468476

469477

470478
class DataStats(Transform):
@@ -475,6 +483,8 @@ class DataStats(Transform):
475483
so it can be used in pre-processing and post-processing.
476484
"""
477485

486+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
487+
478488
def __init__(
479489
self,
480490
prefix: str = "Data",
@@ -523,14 +533,14 @@ def __init__(
523533

524534
def __call__(
525535
self,
526-
img: NdarrayTensor,
536+
img: NdarrayOrTensor,
527537
prefix: Optional[str] = None,
528538
data_type: Optional[bool] = None,
529539
data_shape: Optional[bool] = None,
530540
value_range: Optional[bool] = None,
531541
data_value: Optional[bool] = None,
532542
additional_info: Optional[Callable] = None,
533-
) -> NdarrayTensor:
543+
) -> NdarrayOrTensor:
534544
"""
535545
Apply the transform to `img`, optionally take arguments similar to the class constructor.
536546
"""
@@ -570,6 +580,8 @@ class SimulateDelay(Transform):
570580
to sub-optimal design choices.
571581
"""
572582

583+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
584+
573585
def __init__(self, delay_time: float = 0.0) -> None:
574586
"""
575587
Args:
@@ -579,7 +591,7 @@ def __init__(self, delay_time: float = 0.0) -> None:
579591
super().__init__()
580592
self.delay_time: float = delay_time
581593

582-
def __call__(self, img: NdarrayTensor, delay_time: Optional[float] = None) -> NdarrayTensor:
594+
def __call__(self, img: NdarrayOrTensor, delay_time: Optional[float] = None) -> NdarrayOrTensor:
583595
"""
584596
Args:
585597
img: data remain unchanged throughout this transform.
@@ -612,12 +624,14 @@ class Lambda(Transform):
612624
613625
"""
614626

627+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
628+
615629
def __init__(self, func: Optional[Callable] = None) -> None:
616630
if func is not None and not callable(func):
617631
raise TypeError(f"func must be None or callable but is {type(func).__name__}.")
618632
self.func = func
619633

620-
def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None):
634+
def __call__(self, img: NdarrayOrTensor, func: Optional[Callable] = None):
621635
"""
622636
Apply `self.func` to `img`.
623637
@@ -648,14 +662,15 @@ class RandLambda(Lambda, RandomizableTransform):
648662
prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
649663
650664
For more details, please check :py:class:`monai.transforms.Lambda`.
651-
652665
"""
653666

667+
backend = Lambda.backend
668+
654669
def __init__(self, func: Optional[Callable] = None, prob: float = 1.0) -> None:
655670
Lambda.__init__(self=self, func=func)
656671
RandomizableTransform.__init__(self=self, prob=prob)
657672

658-
def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None):
673+
def __call__(self, img: NdarrayOrTensor, func: Optional[Callable] = None):
659674
self.randomize(img)
660675
return super().__call__(img=img, func=func) if self._do_transform else img
661676

@@ -679,6 +694,8 @@ class LabelToMask(Transform):
679694
680695
"""
681696

697+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
698+
682699
def __init__( # pytype: disable=annotation-type-mismatch
683700
self,
684701
select_labels: Union[Sequence[int], int],
@@ -688,8 +705,11 @@ def __init__( # pytype: disable=annotation-type-mismatch
688705
self.merge_channels = merge_channels
689706

690707
def __call__(
691-
self, img: np.ndarray, select_labels: Optional[Union[Sequence[int], int]] = None, merge_channels: bool = False
692-
):
708+
self,
709+
img: NdarrayOrTensor,
710+
select_labels: Optional[Union[Sequence[int], int]] = None,
711+
merge_channels: bool = False,
712+
) -> NdarrayOrTensor:
693713
"""
694714
Args:
695715
select_labels: labels to generate mask from. for 1 channel label, the `select_labels`
@@ -706,26 +726,40 @@ def __call__(
706726
if img.shape[0] > 1:
707727
data = img[[*select_labels]]
708728
else:
709-
data = np.where(np.in1d(img, select_labels), True, False).reshape(img.shape)
729+
where = np.where if isinstance(img, np.ndarray) else torch.where
730+
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
731+
data = where(in1d(img, select_labels), True, False).reshape(img.shape)
732+
# pre pytorch 1.8.0, need to use 1/0 instead of True/False
733+
else:
734+
data = where(
735+
in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device)
736+
).reshape(img.shape)
710737

711-
return np.any(data, axis=0, keepdims=True) if (merge_channels or self.merge_channels) else data
738+
if merge_channels or self.merge_channels:
739+
if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)):
740+
return data.any(0)[None]
741+
# pre pytorch 1.8.0 compatibility
742+
return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore
743+
744+
return data
712745

713746

714747
class FgBgToIndices(Transform):
715-
def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None:
716-
"""
717-
Compute foreground and background of the input label data, return the indices.
718-
If no output_shape specified, output data will be 1 dim indices after flattening.
719-
This transform can help pre-compute foreground and background regions for other transforms.
720-
A typical usage is to randomly select foreground and background to crop.
721-
The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`.
748+
"""
749+
Compute foreground and background of the input label data, return the indices.
750+
If no output_shape specified, output data will be 1 dim indices after flattening.
751+
This transform can help pre-compute foreground and background regions for other transforms.
752+
A typical usage is to randomly select foreground and background to crop.
753+
The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`.
722754
723-
Args:
724-
image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to
725-
determine the valid image content area and select background only in this area.
726-
output_shape: expected shape of output indices. if not None, unravel indices to specified shape.
755+
Args:
756+
image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to
757+
determine the valid image content area and select background only in this area.
758+
output_shape: expected shape of output indices. if not None, unravel indices to specified shape.
727759
728-
"""
760+
"""
761+
762+
def __init__(self, image_threshold: float = 0.0, output_shape: Optional[Sequence[int]] = None) -> None:
729763
self.image_threshold = image_threshold
730764
self.output_shape = output_shape
731765

0 commit comments

Comments
 (0)