Skip to content

Commit 19cc6f0

Browse files
slicepastepre-commit-ci[bot]KumoLiuericspodsurajpaib
authored
Make MetaTensor optional printed in DataStats and DataStatsd #5905 (#7814)
Fixes #5905 ### Description We simply add one argument for DataStats and DataStatsd to make MetaTensor optional printed. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wei_Chuan, Chiang <[email protected]> Signed-off-by: YunLiu <[email protected]> Signed-off-by: Suraj Pai <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]> Co-authored-by: Suraj Pai <[email protected]> Co-authored-by: Ben Murray <[email protected]>
1 parent 4e70bf6 commit 19cc6f0

File tree

4 files changed

+123
-7
lines changed

4 files changed

+123
-7
lines changed

monai/transforms/utility/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,7 @@ def __init__(
656656
data_shape: bool = True,
657657
value_range: bool = True,
658658
data_value: bool = False,
659+
meta_info: bool = False,
659660
additional_info: Callable | None = None,
660661
name: str = "DataStats",
661662
) -> None:
@@ -667,6 +668,7 @@ def __init__(
667668
value_range: whether to show the value range of input data.
668669
data_value: whether to show the raw value of input data.
669670
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
671+
meta_info: whether to show the data of MetaTensor.
670672
additional_info: user can define callable function to extract additional info from input data.
671673
name: identifier of `logging.logger` to use, defaulting to "DataStats".
672674
@@ -681,6 +683,7 @@ def __init__(
681683
self.data_shape = data_shape
682684
self.value_range = value_range
683685
self.data_value = data_value
686+
self.meta_info = meta_info
684687
if additional_info is not None and not callable(additional_info):
685688
raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.")
686689
self.additional_info = additional_info
@@ -707,6 +710,7 @@ def __call__(
707710
data_shape: bool | None = None,
708711
value_range: bool | None = None,
709712
data_value: bool | None = None,
713+
meta_info: bool | None = None,
710714
additional_info: Callable | None = None,
711715
) -> NdarrayOrTensor:
712716
"""
@@ -727,6 +731,9 @@ def __call__(
727731
lines.append(f"Value range: (not a PyTorch or Numpy array, type: {type(img)})")
728732
if self.data_value if data_value is None else data_value:
729733
lines.append(f"Value: {img}")
734+
if self.meta_info if meta_info is None else meta_info:
735+
metadata = getattr(img, "meta", "(input is not a MetaTensor)")
736+
lines.append(f"Meta info: {repr(metadata)}")
730737
additional_info = self.additional_info if additional_info is None else additional_info
731738
if additional_info is not None:
732739
lines.append(f"Additional info: {additional_info(img)}")

monai/transforms/utility/dictionary.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,7 @@ def __init__(
793793
data_shape: Sequence[bool] | bool = True,
794794
value_range: Sequence[bool] | bool = True,
795795
data_value: Sequence[bool] | bool = False,
796+
meta_info: Sequence[bool] | bool = False,
796797
additional_info: Sequence[Callable] | Callable | None = None,
797798
name: str = "DataStats",
798799
allow_missing_keys: bool = False,
@@ -812,6 +813,8 @@ def __init__(
812813
data_value: whether to show the raw value of input data.
813814
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
814815
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
816+
meta_info: whether to show the data of MetaTensor.
817+
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
815818
additional_info: user can define callable function to extract
816819
additional info from input data. it also can be a sequence of string, each element
817820
corresponds to a key in ``keys``.
@@ -825,15 +828,34 @@ def __init__(
825828
self.data_shape = ensure_tuple_rep(data_shape, len(self.keys))
826829
self.value_range = ensure_tuple_rep(value_range, len(self.keys))
827830
self.data_value = ensure_tuple_rep(data_value, len(self.keys))
831+
self.meta_info = ensure_tuple_rep(meta_info, len(self.keys))
828832
self.additional_info = ensure_tuple_rep(additional_info, len(self.keys))
829833
self.printer = DataStats(name=name)
830834

831835
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
832836
d = dict(data)
833-
for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator(
834-
d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info
837+
for (
838+
key,
839+
prefix,
840+
data_type,
841+
data_shape,
842+
value_range,
843+
data_value,
844+
meta_info,
845+
additional_info,
846+
) in self.key_iterator(
847+
d,
848+
self.prefix,
849+
self.data_type,
850+
self.data_shape,
851+
self.value_range,
852+
self.data_value,
853+
self.meta_info,
854+
self.additional_info,
835855
):
836-
d[key] = self.printer(d[key], prefix, data_type, data_shape, value_range, data_value, additional_info)
856+
d[key] = self.printer(
857+
d[key], prefix, data_type, data_shape, value_range, data_value, meta_info, additional_info
858+
)
837859
return d
838860

839861

tests/test_data_stats.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import torch
2424
from parameterized import parameterized
2525

26+
from monai.data.meta_tensor import MetaTensor
2627
from monai.transforms import DataStats
2728

2829
TEST_CASE_1 = [
@@ -130,20 +131,55 @@
130131
]
131132

132133
TEST_CASE_8 = [
134+
{
135+
"prefix": "test data",
136+
"data_type": True,
137+
"data_shape": True,
138+
"value_range": True,
139+
"data_value": True,
140+
"additional_info": np.mean,
141+
"name": "DataStats",
142+
},
133143
np.array([[0, 1], [1, 2]]),
134144
"test data statistics:\nType: <class 'numpy.ndarray'> int64\nShape: (2, 2)\nValue range: (0, 2)\n"
135145
"Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n",
136146
]
137147

148+
TEST_CASE_9 = [
149+
np.array([[0, 1], [1, 2]]),
150+
"test data statistics:\nType: <class 'numpy.ndarray'> int64\nShape: (2, 2)\nValue range: (0, 2)\n"
151+
"Value: [[0 1]\n [1 2]]\n"
152+
"Meta info: '(input is not a MetaTensor)'\n"
153+
"Additional info: 1.0\n",
154+
]
155+
156+
TEST_CASE_10 = [
157+
MetaTensor(
158+
torch.tensor([[0, 1], [1, 2]]),
159+
affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64),
160+
meta={"some": "info"},
161+
),
162+
"test data statistics:\nType: <class 'monai.data.meta_tensor.MetaTensor'> torch.int64\n"
163+
"Shape: torch.Size([2, 2])\nValue range: (0, 2)\n"
164+
"Value: tensor([[0, 1],\n [1, 2]])\n"
165+
"Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n"
166+
" [0., 2., 0., 0.],\n"
167+
" [0., 0., 2., 0.],\n"
168+
" [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n"
169+
"Additional info: 1.0\n",
170+
]
171+
138172

139173
class TestDataStats(unittest.TestCase):
140174

141-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
175+
@parameterized.expand(
176+
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
177+
)
142178
def test_value(self, input_param, input_data, expected_print):
143179
transform = DataStats(**input_param)
144180
_ = transform(input_data)
145181

146-
@parameterized.expand([TEST_CASE_8])
182+
@parameterized.expand([TEST_CASE_9, TEST_CASE_10])
147183
def test_file(self, input_data, expected_print):
148184
with tempfile.TemporaryDirectory() as tempdir:
149185
filename = os.path.join(tempdir, "test_data_stats.log")
@@ -158,6 +194,7 @@ def test_file(self, input_data, expected_print):
158194
"data_shape": True,
159195
"value_range": True,
160196
"data_value": True,
197+
"meta_info": True,
161198
"additional_info": np.mean,
162199
"name": name,
163200
}

tests/test_data_statsd.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
from parameterized import parameterized
2323

24+
from monai.data.meta_tensor import MetaTensor
2425
from monai.transforms import DataStatsd
2526

2627
TEST_CASE_1 = [
@@ -150,22 +151,70 @@
150151
]
151152

152153
TEST_CASE_9 = [
154+
{
155+
"keys": "img",
156+
"prefix": "test data",
157+
"data_shape": True,
158+
"value_range": True,
159+
"data_value": True,
160+
"meta_info": False,
161+
"additional_info": np.mean,
162+
"name": "DataStats",
163+
},
153164
{"img": np.array([[0, 1], [1, 2]])},
154165
"test data statistics:\nType: <class 'numpy.ndarray'> int64\nShape: (2, 2)\nValue range: (0, 2)\n"
155166
"Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n",
156167
]
157168

169+
TEST_CASE_10 = [
170+
{"img": np.array([[0, 1], [1, 2]])},
171+
"test data statistics:\nType: <class 'numpy.ndarray'> int64\nShape: (2, 2)\nValue range: (0, 2)\n"
172+
"Value: [[0 1]\n [1 2]]\n"
173+
"Meta info: '(input is not a MetaTensor)'\n"
174+
"Additional info: 1.0\n",
175+
]
176+
177+
TEST_CASE_11 = [
178+
{
179+
"img": (
180+
MetaTensor(
181+
torch.tensor([[0, 1], [1, 2]]),
182+
affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64),
183+
meta={"some": "info"},
184+
)
185+
)
186+
},
187+
"test data statistics:\nType: <class 'monai.data.meta_tensor.MetaTensor'> torch.int64\n"
188+
"Shape: torch.Size([2, 2])\nValue range: (0, 2)\n"
189+
"Value: tensor([[0, 1],\n [1, 2]])\n"
190+
"Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n"
191+
" [0., 2., 0., 0.],\n"
192+
" [0., 0., 2., 0.],\n"
193+
" [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n"
194+
"Additional info: 1.0\n",
195+
]
196+
158197

159198
class TestDataStatsd(unittest.TestCase):
160199

161200
@parameterized.expand(
162-
[TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
201+
[
202+
TEST_CASE_1,
203+
TEST_CASE_2,
204+
TEST_CASE_3,
205+
TEST_CASE_4,
206+
TEST_CASE_5,
207+
TEST_CASE_6,
208+
TEST_CASE_7,
209+
TEST_CASE_8,
210+
TEST_CASE_9,
211+
]
163212
)
164213
def test_value(self, input_param, input_data, expected_print):
165214
transform = DataStatsd(**input_param)
166215
_ = transform(input_data)
167216

168-
@parameterized.expand([TEST_CASE_9])
217+
@parameterized.expand([TEST_CASE_10, TEST_CASE_11])
169218
def test_file(self, input_data, expected_print):
170219
with tempfile.TemporaryDirectory() as tempdir:
171220
filename = os.path.join(tempdir, "test_stats.log")
@@ -180,6 +229,7 @@ def test_file(self, input_data, expected_print):
180229
"data_shape": True,
181230
"value_range": True,
182231
"data_value": True,
232+
"meta_info": True,
183233
"additional_info": np.mean,
184234
"name": name,
185235
}

0 commit comments

Comments
 (0)