24
24
import torch
25
25
26
26
from monai .config import DtypeLike , KeysCollection , NdarrayTensor
27
+ from monai .data .utils import no_collation
27
28
from monai .transforms .inverse import InvertibleTransform
28
- from monai .transforms .transform import MapTransform , Randomizable
29
+ from monai .transforms .transform import MapTransform , Randomizable , RandomizableTransform
29
30
from monai .transforms .utility .array import (
30
31
AddChannel ,
31
32
AsChannelFirst ,
@@ -833,7 +834,7 @@ def __call__(self, data):
833
834
return d
834
835
835
836
836
- class Lambdad (MapTransform ):
837
+ class Lambdad (MapTransform , InvertibleTransform ):
837
838
"""
838
839
Dictionary-based wrapper of :py:class:`monai.transforms.Lambda`.
839
840
@@ -852,51 +853,110 @@ class Lambdad(MapTransform):
852
853
See also: :py:class:`monai.transforms.compose.MapTransform`
853
854
func: Lambda/function to be applied. It also can be a sequence of Callable,
854
855
each element corresponds to a key in ``keys``.
856
+ inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
857
+ It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
855
858
overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
856
859
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
857
860
allow_missing_keys: don't raise exception if key is missing.
861
+
862
+ Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the
863
+ image's original size. If need these complicated information, please write a new InvertibleTransform directly.
864
+
858
865
"""
859
866
860
867
def __init__ (
861
868
self ,
862
869
keys : KeysCollection ,
863
870
func : Union [Sequence [Callable ], Callable ],
871
+ inv_func : Union [Sequence [Callable ], Callable ] = no_collation ,
864
872
overwrite : Union [Sequence [bool ], bool ] = True ,
865
873
allow_missing_keys : bool = False ,
866
874
) -> None :
867
875
super ().__init__ (keys , allow_missing_keys )
868
876
self .func = ensure_tuple_rep (func , len (self .keys ))
877
+ self .inv_func = ensure_tuple_rep (inv_func , len (self .keys ))
869
878
self .overwrite = ensure_tuple_rep (overwrite , len (self .keys ))
870
879
self ._lambd = Lambda ()
871
880
881
+ def _transform (self , data : Any , func : Callable ):
882
+ return self ._lambd (data , func = func )
883
+
872
884
def __call__ (self , data ):
873
885
d = dict (data )
874
886
for key , func , overwrite in self .key_iterator (d , self .func , self .overwrite ):
875
- ret = self ._lambd (d [key ], func = func )
887
+ ret = self ._transform (data = d [key ], func = func )
888
+ if overwrite :
889
+ d [key ] = ret
890
+ self .push_transform (d , key )
891
+ return d
892
+
893
+ def _inverse_transform (self , transform_info : Dict , data : Any , func : Callable ):
894
+ return self ._lambd (data , func = func )
895
+
896
+ def inverse (self , data ):
897
+ d = deepcopy (dict (data ))
898
+ for key , inv_func , overwrite in self .key_iterator (d , self .inv_func , self .overwrite ):
899
+ transform = self .get_most_recent_transform (d , key )
900
+ ret = self ._inverse_transform (transform_info = transform , data = d [key ], func = inv_func )
876
901
if overwrite :
877
902
d [key ] = ret
903
+ self .pop_transform (d , key )
878
904
return d
879
905
880
906
881
- class RandLambdad (Lambdad , Randomizable ):
907
+ class RandLambdad (Lambdad , RandomizableTransform ):
882
908
"""
883
- Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic.
884
- It's a randomizable transform so `CacheDataset` will not execute it and cache the results.
909
+ Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` may contain random logic,
910
+ or randomly execute the function based on `prob`. so `CacheDataset` will not execute it and cache the results.
885
911
886
912
Args:
887
913
keys: keys of the corresponding items to be transformed.
888
914
See also: :py:class:`monai.transforms.compose.MapTransform`
889
915
func: Lambda/function to be applied. It also can be a sequence of Callable,
890
916
each element corresponds to a key in ``keys``.
917
+ inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
918
+ It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
891
919
overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
892
920
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
921
+ prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
922
+ note that all the data specified by `keys` will share the same random probability to execute or not.
923
+ allow_missing_keys: don't raise exception if key is missing.
893
924
894
925
For more details, please check :py:class:`monai.transforms.Lambdad`.
895
926
927
+ Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the
928
+ image's original size. If need these complicated information, please write a new InvertibleTransform directly.
929
+
896
930
"""
897
931
898
- def randomize (self , data : Any ) -> None :
899
- pass
932
+ def __init__ (
933
+ self ,
934
+ keys : KeysCollection ,
935
+ func : Union [Sequence [Callable ], Callable ],
936
+ inv_func : Union [Sequence [Callable ], Callable ] = no_collation ,
937
+ overwrite : Union [Sequence [bool ], bool ] = True ,
938
+ prob : float = 1.0 ,
939
+ allow_missing_keys : bool = False ,
940
+ ) -> None :
941
+ Lambdad .__init__ (
942
+ self = self ,
943
+ keys = keys ,
944
+ func = func ,
945
+ inv_func = inv_func ,
946
+ overwrite = overwrite ,
947
+ allow_missing_keys = allow_missing_keys ,
948
+ )
949
+ RandomizableTransform .__init__ (self = self , prob = prob , do_transform = True )
950
+
951
+ def _transform (self , data : Any , func : Callable ):
952
+ return self ._lambd (data , func = func ) if self ._do_transform else data
953
+
954
+ def __call__ (self , data ):
955
+ self .randomize (data )
956
+ return super ().__call__ (data )
957
+
958
+ def _inverse_transform (self , transform_info : Dict , data : Any , func : Callable ):
959
+ return self ._lambd (data , func = func ) if transform_info [InverseKeys .DO_TRANSFORM ] else data
900
960
901
961
902
962
class LabelToMaskd (MapTransform ):
0 commit comments