22
22
import numpy as np
23
23
import torch
24
24
25
- from monai .config import DtypeLike , NdarrayTensor
25
+ from monai .config import DtypeLike
26
26
from monai .config .type_definitions import NdarrayOrTensor
27
27
from monai .transforms .transform import Randomizable , RandomizableTransform , Transform
28
28
from monai .transforms .utils import (
31
31
map_binary_to_indices ,
32
32
map_classes_to_indices ,
33
33
)
34
- from monai .transforms .utils_pytorch_numpy_unification import moveaxis
34
+ from monai .transforms .utils_pytorch_numpy_unification import in1d , moveaxis
35
35
from monai .utils import convert_to_numpy , convert_to_tensor , ensure_tuple , look_up_option , min_version , optional_import
36
36
from monai .utils .enums import TransformBackends
37
+ from monai .utils .misc import is_module_ver_at_least
37
38
from monai .utils .type_conversion import convert_data_type
38
39
39
40
PILImageImage , has_pil = optional_import ("PIL.Image" , name = "Image" )
@@ -445,6 +446,8 @@ class SqueezeDim(Transform):
445
446
Squeeze a unitary dimension.
446
447
"""
447
448
449
+ backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
450
+
448
451
def __init__ (self , dim : Optional [int ] = 0 ) -> None :
449
452
"""
450
453
Args:
@@ -459,12 +462,17 @@ def __init__(self, dim: Optional[int] = 0) -> None:
459
462
raise TypeError (f"dim must be None or a int but is { type (dim ).__name__ } ." )
460
463
self .dim = dim
461
464
462
- def __call__ (self , img : NdarrayTensor ) -> NdarrayTensor :
465
+ def __call__ (self , img : NdarrayOrTensor ) -> NdarrayOrTensor :
463
466
"""
464
467
Args:
465
468
img: numpy arrays with required dimension `dim` removed
466
469
"""
467
- return img .squeeze (self .dim ) # type: ignore
470
+ if self .dim is None :
471
+ return img .squeeze ()
472
+ # for pytorch/numpy unification
473
+ if img .shape [self .dim ] != 1 :
474
+ raise ValueError ("Can only squeeze singleton dimension" )
475
+ return img .squeeze (self .dim )
468
476
469
477
470
478
class DataStats (Transform ):
@@ -475,6 +483,8 @@ class DataStats(Transform):
475
483
so it can be used in pre-processing and post-processing.
476
484
"""
477
485
486
+ backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
487
+
478
488
def __init__ (
479
489
self ,
480
490
prefix : str = "Data" ,
@@ -523,14 +533,14 @@ def __init__(
523
533
524
534
def __call__ (
525
535
self ,
526
- img : NdarrayTensor ,
536
+ img : NdarrayOrTensor ,
527
537
prefix : Optional [str ] = None ,
528
538
data_type : Optional [bool ] = None ,
529
539
data_shape : Optional [bool ] = None ,
530
540
value_range : Optional [bool ] = None ,
531
541
data_value : Optional [bool ] = None ,
532
542
additional_info : Optional [Callable ] = None ,
533
- ) -> NdarrayTensor :
543
+ ) -> NdarrayOrTensor :
534
544
"""
535
545
Apply the transform to `img`, optionally take arguments similar to the class constructor.
536
546
"""
@@ -570,6 +580,8 @@ class SimulateDelay(Transform):
570
580
to sub-optimal design choices.
571
581
"""
572
582
583
+ backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
584
+
573
585
def __init__ (self , delay_time : float = 0.0 ) -> None :
574
586
"""
575
587
Args:
@@ -579,7 +591,7 @@ def __init__(self, delay_time: float = 0.0) -> None:
579
591
super ().__init__ ()
580
592
self .delay_time : float = delay_time
581
593
582
- def __call__ (self , img : NdarrayTensor , delay_time : Optional [float ] = None ) -> NdarrayTensor :
594
+ def __call__ (self , img : NdarrayOrTensor , delay_time : Optional [float ] = None ) -> NdarrayOrTensor :
583
595
"""
584
596
Args:
585
597
img: data remain unchanged throughout this transform.
@@ -612,12 +624,14 @@ class Lambda(Transform):
612
624
613
625
"""
614
626
627
+ backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
628
+
615
629
def __init__ (self , func : Optional [Callable ] = None ) -> None :
616
630
if func is not None and not callable (func ):
617
631
raise TypeError (f"func must be None or callable but is { type (func ).__name__ } ." )
618
632
self .func = func
619
633
620
- def __call__ (self , img : Union [ np . ndarray , torch . Tensor ] , func : Optional [Callable ] = None ):
634
+ def __call__ (self , img : NdarrayOrTensor , func : Optional [Callable ] = None ):
621
635
"""
622
636
Apply `self.func` to `img`.
623
637
@@ -648,14 +662,15 @@ class RandLambda(Lambda, RandomizableTransform):
648
662
prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
649
663
650
664
For more details, please check :py:class:`monai.transforms.Lambda`.
651
-
652
665
"""
653
666
667
+ backend = Lambda .backend
668
+
654
669
def __init__ (self , func : Optional [Callable ] = None , prob : float = 1.0 ) -> None :
655
670
Lambda .__init__ (self = self , func = func )
656
671
RandomizableTransform .__init__ (self = self , prob = prob )
657
672
658
- def __call__ (self , img : Union [ np . ndarray , torch . Tensor ] , func : Optional [Callable ] = None ):
673
+ def __call__ (self , img : NdarrayOrTensor , func : Optional [Callable ] = None ):
659
674
self .randomize (img )
660
675
return super ().__call__ (img = img , func = func ) if self ._do_transform else img
661
676
@@ -679,6 +694,8 @@ class LabelToMask(Transform):
679
694
680
695
"""
681
696
697
+ backend = [TransformBackends .TORCH , TransformBackends .NUMPY ]
698
+
682
699
def __init__ ( # pytype: disable=annotation-type-mismatch
683
700
self ,
684
701
select_labels : Union [Sequence [int ], int ],
@@ -688,8 +705,11 @@ def __init__( # pytype: disable=annotation-type-mismatch
688
705
self .merge_channels = merge_channels
689
706
690
707
def __call__ (
691
- self , img : np .ndarray , select_labels : Optional [Union [Sequence [int ], int ]] = None , merge_channels : bool = False
692
- ):
708
+ self ,
709
+ img : NdarrayOrTensor ,
710
+ select_labels : Optional [Union [Sequence [int ], int ]] = None ,
711
+ merge_channels : bool = False ,
712
+ ) -> NdarrayOrTensor :
693
713
"""
694
714
Args:
695
715
select_labels: labels to generate mask from. for 1 channel label, the `select_labels`
@@ -706,26 +726,40 @@ def __call__(
706
726
if img .shape [0 ] > 1 :
707
727
data = img [[* select_labels ]]
708
728
else :
709
- data = np .where (np .in1d (img , select_labels ), True , False ).reshape (img .shape )
729
+ where = np .where if isinstance (img , np .ndarray ) else torch .where
730
+ if isinstance (img , np .ndarray ) or is_module_ver_at_least (torch , (1 , 8 , 0 )):
731
+ data = where (in1d (img , select_labels ), True , False ).reshape (img .shape )
732
+ # pre pytorch 1.8.0, need to use 1/0 instead of True/False
733
+ else :
734
+ data = where (
735
+ in1d (img , select_labels ), torch .tensor (1 , device = img .device ), torch .tensor (0 , device = img .device )
736
+ ).reshape (img .shape )
710
737
711
- return np .any (data , axis = 0 , keepdims = True ) if (merge_channels or self .merge_channels ) else data
738
+ if merge_channels or self .merge_channels :
739
+ if isinstance (img , np .ndarray ) or is_module_ver_at_least (torch , (1 , 8 , 0 )):
740
+ return data .any (0 )[None ]
741
+ # pre pytorch 1.8.0 compatibility
742
+ return data .to (torch .uint8 ).any (0 )[None ].to (bool ) # type: ignore
743
+
744
+ return data
712
745
713
746
714
747
class FgBgToIndices (Transform ):
715
- def __init__ (self , image_threshold : float = 0.0 , output_shape : Optional [Sequence [int ]] = None ) -> None :
716
- """
717
- Compute foreground and background of the input label data, return the indices.
718
- If no output_shape specified, output data will be 1 dim indices after flattening.
719
- This transform can help pre-compute foreground and background regions for other transforms.
720
- A typical usage is to randomly select foreground and background to crop.
721
- The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`.
748
+ """
749
+ Compute foreground and background of the input label data, return the indices.
750
+ If no output_shape specified, output data will be 1 dim indices after flattening.
751
+ This transform can help pre-compute foreground and background regions for other transforms.
752
+ A typical usage is to randomly select foreground and background to crop.
753
+ The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`.
722
754
723
- Args:
724
- image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to
725
- determine the valid image content area and select background only in this area.
726
- output_shape: expected shape of output indices. if not None, unravel indices to specified shape.
755
+ Args:
756
+ image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to
757
+ determine the valid image content area and select background only in this area.
758
+ output_shape: expected shape of output indices. if not None, unravel indices to specified shape.
727
759
728
- """
760
+ """
761
+
762
+ def __init__ (self , image_threshold : float = 0.0 , output_shape : Optional [Sequence [int ]] = None ) -> None :
729
763
self .image_threshold = image_threshold
730
764
self .output_shape = output_shape
731
765
0 commit comments