Skip to content

Commit 607a31a

Browse files
authored
2648 Enhance RandLambda to execute deterministic transforms for random part of dataset (#2667)
* [DLMED] add RandCompose Signed-off-by: Nic Ma <[email protected]> * [DLMED] add unit tests Signed-off-by: Nic Ma <[email protected]> * [DLMED] change to enhance RandLambda Signed-off-by: Nic Ma <[email protected]> * [DLMED] remove RandCompose Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix format Signed-off-by: Nic Ma <[email protected]> * [DLMED] enhance doc Signed-off-by: Nic Ma <[email protected]> * [DLMED] add inverse operation Signed-off-by: Nic Ma <[email protected]> * [DLMED] add more tests Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix subprogress issue Signed-off-by: Nic Ma <[email protected]>
1 parent 390fe7f commit 607a31a

File tree

7 files changed

+173
-9
lines changed

7 files changed

+173
-9
lines changed

docs/source/transforms.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,12 @@ Utility
604604
:members:
605605
:special-members: __call__
606606

607+
`RandLambda`
608+
""""""""""""
609+
.. autoclass:: RandLambda
610+
:members:
611+
:special-members: __call__
612+
607613
`LabelToMask`
608614
"""""""""""""
609615
.. autoclass:: LabelToMask

monai/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,7 @@
323323
LabelToMask,
324324
Lambda,
325325
MapLabelValue,
326+
RandLambda,
326327
RemoveRepeatedChannel,
327328
RepeatChannel,
328329
SimulateDelay,

monai/transforms/utility/array.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import torch
2424

2525
from monai.config import DtypeLike, NdarrayTensor
26-
from monai.transforms.transform import Randomizable, Transform
26+
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
2727
from monai.transforms.utils import (
2828
convert_to_numpy,
2929
convert_to_tensor,
@@ -58,6 +58,7 @@
5858
"DataStats",
5959
"SimulateDelay",
6060
"Lambda",
61+
"RandLambda",
6162
"LabelToMask",
6263
"FgBgToIndices",
6364
"ClassesToIndices",
@@ -617,6 +618,28 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable
617618
raise ValueError("Incompatible values: func=None and self.func=None.")
618619

619620

621+
class RandLambda(Lambda, RandomizableTransform):
622+
"""
623+
Randomizable version :py:class:`monai.transforms.Lambda`, the input `func` may contain random logic,
624+
or randomly execute the function based on `prob`.
625+
626+
Args:
627+
func: Lambda/function to be applied.
628+
prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
629+
630+
For more details, please check :py:class:`monai.transforms.Lambda`.
631+
632+
"""
633+
634+
def __init__(self, func: Optional[Callable] = None, prob: float = 1.0) -> None:
635+
Lambda.__init__(self=self, func=func)
636+
RandomizableTransform.__init__(self=self, prob=prob)
637+
638+
def __call__(self, img: Union[np.ndarray, torch.Tensor], func: Optional[Callable] = None):
639+
self.randomize(img)
640+
return super().__call__(img=img, func=func) if self._do_transform else img
641+
642+
620643
class LabelToMask(Transform):
621644
"""
622645
Convert labels to mask for other tasks. A typical usage is to convert segmentation labels

monai/transforms/utility/dictionary.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
import torch
2525

2626
from monai.config import DtypeLike, KeysCollection, NdarrayTensor
27+
from monai.data.utils import no_collation
2728
from monai.transforms.inverse import InvertibleTransform
28-
from monai.transforms.transform import MapTransform, Randomizable
29+
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
2930
from monai.transforms.utility.array import (
3031
AddChannel,
3132
AsChannelFirst,
@@ -833,7 +834,7 @@ def __call__(self, data):
833834
return d
834835

835836

836-
class Lambdad(MapTransform):
837+
class Lambdad(MapTransform, InvertibleTransform):
837838
"""
838839
Dictionary-based wrapper of :py:class:`monai.transforms.Lambda`.
839840
@@ -852,51 +853,110 @@ class Lambdad(MapTransform):
852853
See also: :py:class:`monai.transforms.compose.MapTransform`
853854
func: Lambda/function to be applied. It also can be a sequence of Callable,
854855
each element corresponds to a key in ``keys``.
856+
inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
857+
It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
855858
overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
856859
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
857860
allow_missing_keys: don't raise exception if key is missing.
861+
862+
Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the
863+
image's original size. If need these complicated information, please write a new InvertibleTransform directly.
864+
858865
"""
859866

860867
def __init__(
861868
self,
862869
keys: KeysCollection,
863870
func: Union[Sequence[Callable], Callable],
871+
inv_func: Union[Sequence[Callable], Callable] = no_collation,
864872
overwrite: Union[Sequence[bool], bool] = True,
865873
allow_missing_keys: bool = False,
866874
) -> None:
867875
super().__init__(keys, allow_missing_keys)
868876
self.func = ensure_tuple_rep(func, len(self.keys))
877+
self.inv_func = ensure_tuple_rep(inv_func, len(self.keys))
869878
self.overwrite = ensure_tuple_rep(overwrite, len(self.keys))
870879
self._lambd = Lambda()
871880

881+
def _transform(self, data: Any, func: Callable):
882+
return self._lambd(data, func=func)
883+
872884
def __call__(self, data):
873885
d = dict(data)
874886
for key, func, overwrite in self.key_iterator(d, self.func, self.overwrite):
875-
ret = self._lambd(d[key], func=func)
887+
ret = self._transform(data=d[key], func=func)
888+
if overwrite:
889+
d[key] = ret
890+
self.push_transform(d, key)
891+
return d
892+
893+
def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable):
894+
return self._lambd(data, func=func)
895+
896+
def inverse(self, data):
897+
d = deepcopy(dict(data))
898+
for key, inv_func, overwrite in self.key_iterator(d, self.inv_func, self.overwrite):
899+
transform = self.get_most_recent_transform(d, key)
900+
ret = self._inverse_transform(transform_info=transform, data=d[key], func=inv_func)
876901
if overwrite:
877902
d[key] = ret
903+
self.pop_transform(d, key)
878904
return d
879905

880906

881-
class RandLambdad(Lambdad, Randomizable):
907+
class RandLambdad(Lambdad, RandomizableTransform):
882908
"""
883-
Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic.
884-
It's a randomizable transform so `CacheDataset` will not execute it and cache the results.
909+
Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` may contain random logic,
910+
or randomly execute the function based on `prob`. so `CacheDataset` will not execute it and cache the results.
885911
886912
Args:
887913
keys: keys of the corresponding items to be transformed.
888914
See also: :py:class:`monai.transforms.compose.MapTransform`
889915
func: Lambda/function to be applied. It also can be a sequence of Callable,
890916
each element corresponds to a key in ``keys``.
917+
inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
918+
It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
891919
overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
892920
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
921+
prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
922+
note that all the data specified by `keys` will share the same random probability to execute or not.
923+
allow_missing_keys: don't raise exception if key is missing.
893924
894925
For more details, please check :py:class:`monai.transforms.Lambdad`.
895926
927+
Note: The inverse operation doesn't allow to define `extra_info` or access other information, such as the
928+
image's original size. If need these complicated information, please write a new InvertibleTransform directly.
929+
896930
"""
897931

898-
def randomize(self, data: Any) -> None:
899-
pass
932+
def __init__(
933+
self,
934+
keys: KeysCollection,
935+
func: Union[Sequence[Callable], Callable],
936+
inv_func: Union[Sequence[Callable], Callable] = no_collation,
937+
overwrite: Union[Sequence[bool], bool] = True,
938+
prob: float = 1.0,
939+
allow_missing_keys: bool = False,
940+
) -> None:
941+
Lambdad.__init__(
942+
self=self,
943+
keys=keys,
944+
func=func,
945+
inv_func=inv_func,
946+
overwrite=overwrite,
947+
allow_missing_keys=allow_missing_keys,
948+
)
949+
RandomizableTransform.__init__(self=self, prob=prob, do_transform=True)
950+
951+
def _transform(self, data: Any, func: Callable):
952+
return self._lambd(data, func=func) if self._do_transform else data
953+
954+
def __call__(self, data):
955+
self.randomize(data)
956+
return super().__call__(data)
957+
958+
def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable):
959+
return self._lambd(data, func=func) if transform_info[InverseKeys.DO_TRANSFORM] else data
900960

901961

902962
class LabelToMaskd(MapTransform):

tests/test_inverse.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@
3535
DivisiblePadd,
3636
Flipd,
3737
InvertibleTransform,
38+
Lambdad,
3839
LoadImaged,
3940
Orientationd,
4041
RandAffined,
4142
RandAxisFlipd,
4243
RandCropByLabelClassesd,
4344
RandCropByPosNegLabeld,
4445
RandFlipd,
46+
RandLambdad,
4547
Randomizable,
4648
RandRotate90d,
4749
RandRotated,
@@ -314,6 +316,16 @@
314316

315317
TESTS.append(("Resized longest 3d", "3D", 5e-2, Resized(KEYS, 201, "longest", "trilinear", True)))
316318

319+
TESTS.append(("Lambdad 2d", "2D", 5e-2, Lambdad(KEYS, func=lambda x: x + 5, inv_func=lambda x: x - 5, overwrite=True)))
320+
321+
TESTS.append(
322+
(
323+
"RandLambdad 3d",
324+
"3D",
325+
5e-2,
326+
RandLambdad(KEYS, func=lambda x: x * 10, inv_func=lambda x: x / 10, overwrite=True, prob=0.5),
327+
)
328+
)
317329

318330
TESTS.append(
319331
(

tests/test_rand_lambda.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
16+
from monai.transforms.transform import Randomizable
17+
from monai.transforms.utility.array import RandLambda
18+
19+
20+
class RandTest(Randomizable):
21+
"""
22+
randomisable transform for testing.
23+
"""
24+
25+
def randomize(self, data=None):
26+
self._a = self.R.random()
27+
28+
def __call__(self, data):
29+
self.randomize()
30+
return data + self._a
31+
32+
33+
class TestRandLambda(unittest.TestCase):
34+
def test_rand_lambdad_identity(self):
35+
img = np.zeros((10, 10))
36+
37+
test_func = RandTest()
38+
test_func.set_random_state(seed=134)
39+
expected = test_func(img)
40+
test_func.set_random_state(seed=134)
41+
ret = RandLambda(func=test_func)(img)
42+
np.testing.assert_allclose(expected, ret)
43+
ret = RandLambda(func=test_func, prob=0.0)(img)
44+
np.testing.assert_allclose(img, ret)
45+
46+
trans = RandLambda(func=test_func, prob=0.5)
47+
trans.set_random_state(seed=123)
48+
ret = trans(img)
49+
np.testing.assert_allclose(img, ret)
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

tests/test_rand_lambdad.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ def test_rand_lambdad_identity(self):
4242
ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data)
4343
np.testing.assert_allclose(expected["img"], ret["img"])
4444
np.testing.assert_allclose(expected["prop"], ret["prop"])
45+
ret = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.0)(data)
46+
np.testing.assert_allclose(data["img"], ret["img"])
47+
np.testing.assert_allclose(data["prop"], ret["prop"])
48+
49+
trans = RandLambdad(keys=["img", "prop"], func=test_func, prob=0.5)
50+
trans.set_random_state(seed=123)
51+
ret = trans(data)
52+
np.testing.assert_allclose(data["img"], ret["img"])
53+
np.testing.assert_allclose(data["prop"], ret["prop"])
4554

4655

4756
if __name__ == "__main__":

0 commit comments

Comments
 (0)