Commit 51c4387

Merge branch 'dev' into restructure_transforms
2 parents 8645afc + 390fe7f

File tree: 5 files changed, +278 −0 lines


docs/source/data.rst

Lines changed: 4 additions & 0 deletions
@@ -182,6 +182,10 @@ DistributedWeightedRandomSampler
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 .. autoclass:: monai.data.DistributedWeightedRandomSampler

+DatasetSummary
+~~~~~~~~~~~~~~
+.. autoclass:: monai.data.DatasetSummary
+
 Decathlon Datalist
 ~~~~~~~~~~~~~~~~~~
 .. autofunction:: monai.data.load_decathlon_datalist

monai/data/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
     SmartCacheDataset,
     ZipDataset,
 )
+from .dataset_summary import DatasetSummary
 from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties
 from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
 from .image_dataset import ImageDataset

monai/data/dataset_summary.py

Lines changed: 182 additions & 0 deletions
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from itertools import chain
from typing import List, Optional

import numpy as np
import torch

from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset


class DatasetSummary:
    """
    This class provides a way to calculate a reasonable output voxel spacing according to
    the input dataset. The computed values can be used to resample the input in 3d segmentation tasks
    (for example, as the `pixdim` parameter of `monai.transforms.Spacingd`).
    In addition, it supports computing the mean, std, min and max intensities of the input,
    and these statistics are helpful for image normalization
    (for example, in `monai.transforms.ScaleIntensityRanged` and `monai.transforms.NormalizeIntensityd`).

    The algorithm for calculation refers to:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.

    """

    def __init__(
        self,
        dataset: Dataset,
        image_key: Optional[str] = "image",
        label_key: Optional[str] = "label",
        meta_key_postfix: str = "meta_dict",
        num_workers: int = 0,
        **kwargs,
    ):
        """
        Args:
            dataset: dataset from which to load the data.
            image_key: key name of images (default: ``image``).
            label_key: key name of labels (default: ``label``).
            meta_key_postfix: use `{image_key}_{meta_key_postfix}` to fetch the meta data from the dict,
                the meta data is a dictionary object (default: ``meta_dict``).
            num_workers: how many subprocesses to use for data loading.
                ``0`` means that the data will be loaded in the main process (default: ``0``).
            kwargs: other parameters (except `batch_size`) for the DataLoader (this class always uses ``batch_size=1``).

        """

        self.data_loader = DataLoader(dataset=dataset, batch_size=1, num_workers=num_workers, **kwargs)

        self.image_key = image_key
        self.label_key = label_key
        # without an image key there is no meta key to derive, so record None explicitly
        self.meta_key = "{}_{}".format(image_key, meta_key_postfix) if image_key else None
        self.all_meta_data: List = []

    def collect_meta_data(self):
        """
        This function is used to collect the meta data for all images of the dataset.
        """
        if not self.meta_key:
            raise ValueError("To collect meta data for the dataset, `meta_key` must exist.")

        for data in self.data_loader:
            self.all_meta_data.append(data[self.meta_key])

    def get_target_spacing(self, spacing_key: str = "pixdim", anisotropic_threshold: int = 3, percentile: float = 10.0):
        """
        Calculate the target spacing according to all spacings.
        If the target spacing is very anisotropic,
        decrease the spacing value of the maximum axis according to percentile.
        So far, this function only supports NIfTI images which store spacings in headers with key "pixdim".
        After loading with `monai.DataLoader`, "pixdim" is in the form of `torch.Tensor` with size `(batch_size, 8)`.

        Args:
            spacing_key: key of spacing in meta data (default: ``pixdim``).
            anisotropic_threshold: threshold to decide if the target spacing is anisotropic (default: ``3``).
            percentile: for anisotropic target spacing, use this percentile of all spacings along the
                anisotropic axis to replace that axis's value.

        """
        if len(self.all_meta_data) == 0:
            self.collect_meta_data()
        if spacing_key not in self.all_meta_data[0]:
            raise ValueError("The provided `spacing_key` is not in self.all_meta_data.")

        all_spacings = torch.cat([data[spacing_key][:, 1:4] for data in self.all_meta_data], dim=0).numpy()

        target_spacing = np.median(all_spacings, axis=0)
        if max(target_spacing) / min(target_spacing) >= anisotropic_threshold:
            largest_axis = np.argmax(target_spacing)
            target_spacing[largest_axis] = np.percentile(all_spacings[:, largest_axis], percentile)

        output = list(target_spacing)

        return tuple(output)

    def calculate_statistics(self, foreground_threshold: int = 0):
        """
        This function is used to calculate the maximum, minimum, mean and standard deviation of intensities of
        the input dataset.

        Args:
            foreground_threshold: the threshold to decide if a voxel belongs to the foreground; this parameter
                is used to select the foreground of images for calculation. Normally, `label > 0` means the
                corresponding voxel belongs to the foreground, thus if you need to calculate the statistics
                for whole images, you can set the threshold to ``-1`` (default: ``0``).

        """
        voxel_sum = torch.as_tensor(0.0)
        voxel_square_sum = torch.as_tensor(0.0)
        voxel_max, voxel_min = [], []
        voxel_ct = 0

        for data in self.data_loader:
            if self.image_key and self.label_key:
                image, label = data[self.image_key], data[self.label_key]
            else:
                image, label = data

            voxel_max.append(image.max().item())
            voxel_min.append(image.min().item())

            image_foreground = image[torch.where(label > foreground_threshold)]
            voxel_ct += len(image_foreground)
            voxel_sum += image_foreground.sum()
            voxel_square_sum += torch.square(image_foreground).sum()

        self.data_max, self.data_min = max(voxel_max), min(voxel_min)
        self.data_mean = (voxel_sum / voxel_ct).item()
        self.data_std = (torch.sqrt(voxel_square_sum / voxel_ct - self.data_mean ** 2)).item()

    def calculate_percentiles(
        self,
        foreground_threshold: int = 0,
        sampling_flag: bool = True,
        interval: int = 10,
        min_percentile: float = 0.5,
        max_percentile: float = 99.5,
    ):
        """
        This function is used to calculate the percentiles of intensities (and median) of the input dataset. To get
        the required values, all voxels need to be accumulated. To reduce the memory used, this function can be set
        to accumulate only a part of the voxels.

        Args:
            foreground_threshold: the threshold to decide if a voxel belongs to the foreground; this parameter
                is used to select the foreground of images for calculation. Normally, `label > 0` means the
                corresponding voxel belongs to the foreground, thus if you need to calculate the statistics
                for whole images, you can set the threshold to ``-1`` (default: ``0``).
            sampling_flag: whether to sample only a part of the voxels (default: ``True``).
            interval: the sampling interval for accumulating voxels (default: ``10``).
            min_percentile: minimal percentile (default: ``0.5``).
            max_percentile: maximal percentile (default: ``99.5``).

        """
        all_intensities = []
        for data in self.data_loader:
            if self.image_key and self.label_key:
                image, label = data[self.image_key], data[self.label_key]
            else:
                image, label = data

            intensities = image[torch.where(label > foreground_threshold)].tolist()
            if sampling_flag:
                intensities = intensities[::interval]
            all_intensities.append(intensities)

        all_intensities = list(chain(*all_intensities))
        self.data_min_percentile, self.data_max_percentile = np.percentile(
            all_intensities, [min_percentile, max_percentile]
        )
        self.data_median = np.median(all_intensities)
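
Taken together, the class is meant to be run once over a raw, load-only dataset, with its outputs plugged into the preprocessing chain afterwards. The following is a minimal usage sketch under that reading: the file paths are hypothetical, and the AddChanneld/Spacingd/NormalizeIntensityd wiring illustrates the docstring's suggestion rather than anything this commit ships.

from monai.data import Dataset, DatasetSummary
from monai.transforms import AddChanneld, Compose, LoadImaged, NormalizeIntensityd, Spacingd

# hypothetical NIfTI image/label pairs on disk
data_dicts = [{"image": f"img{i}.nii.gz", "label": f"seg{i}.nii.gz"} for i in range(5)]

# run the summary over a load-only dataset so it sees the raw spacings and intensities
summary = DatasetSummary(Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"])))
pixdim = summary.get_target_spacing()  # e.g. (1.0, 1.0, 1.0) for isotropic data
summary.calculate_statistics()  # fills data_mean / data_std / data_min / data_max

# feed the estimated values into the actual preprocessing pipeline
preprocess = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Spacingd(keys=["image", "label"], pixdim=pixdim, mode=("bilinear", "nearest")),
        NormalizeIntensityd(keys="image", subtrahend=summary.data_mean, divisor=summary.data_std),
    ]
)

Running the summary before any Spacingd is the point: it must see the original pixdim headers, not resampled ones.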

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ def run_testsuit():
         "test_handler_metrics_saver",
         "test_handler_metrics_saver_dist",
         "test_handler_classification_saver_dist",
+        "test_dataset_summary",
         "test_deepgrow_transforms",
         "test_deepgrow_interaction",
         "test_deepgrow_dataset",

tests/test_dataset_summary.py

Lines changed: 90 additions & 0 deletions
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import os
import tempfile
import unittest

import nibabel as nib
import numpy as np

from monai.data import Dataset, DatasetSummary, create_test_image_3d
from monai.transforms import LoadImaged
from monai.utils import set_determinism


class TestDatasetSummary(unittest.TestCase):
    def test_spacing_intensity(self):
        set_determinism(seed=0)
        with tempfile.TemporaryDirectory() as tempdir:

            for i in range(5):
                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0)
                n = nib.Nifti1Image(im, np.eye(4))
                nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
                n = nib.Nifti1Image(seg, np.eye(4))
                nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

            train_images = sorted(glob.glob(os.path.join(tempdir, "img*.nii.gz")))
            train_labels = sorted(glob.glob(os.path.join(tempdir, "seg*.nii.gz")))
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(train_images, train_labels)
            ]

            dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))

            calculator = DatasetSummary(dataset, num_workers=4)

            target_spacing = calculator.get_target_spacing()
            self.assertEqual(target_spacing, (1.0, 1.0, 1.0))
            calculator.calculate_statistics()
            np.testing.assert_allclose(calculator.data_mean, 0.892599, rtol=1e-5, atol=1e-5)
            np.testing.assert_allclose(calculator.data_std, 0.131731, rtol=1e-5, atol=1e-5)
            calculator.calculate_percentiles(sampling_flag=True, interval=2)
            self.assertEqual(calculator.data_max_percentile, 1.0)
            np.testing.assert_allclose(calculator.data_min_percentile, 0.556411, rtol=1e-5, atol=1e-5)

    def test_anisotropic_spacing(self):
        with tempfile.TemporaryDirectory() as tempdir:

            pixdims = [
                [1.0, 1.0, 5.0],
                [1.0, 1.0, 4.0],
                [1.0, 1.0, 4.5],
                [1.0, 1.0, 2.0],
                [1.0, 1.0, 1.0],
            ]
            for i in range(5):
                im, seg = create_test_image_3d(32, 32, 32, num_seg_classes=1, num_objs=3, rad_max=6, channel_dim=0)
                n = nib.Nifti1Image(im, np.eye(4))
                n.header["pixdim"][1:4] = pixdims[i]
                nib.save(n, os.path.join(tempdir, f"img{i:d}.nii.gz"))
                n = nib.Nifti1Image(seg, np.eye(4))
                n.header["pixdim"][1:4] = pixdims[i]
                nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

            train_images = sorted(glob.glob(os.path.join(tempdir, "img*.nii.gz")))
            train_labels = sorted(glob.glob(os.path.join(tempdir, "seg*.nii.gz")))
            data_dicts = [
                {"image": image_name, "label": label_name}
                for image_name, label_name in zip(train_images, train_labels)
            ]

            dataset = Dataset(data=data_dicts, transform=LoadImaged(keys=["image", "label"]))

            calculator = DatasetSummary(dataset, num_workers=4)

            # the median z-spacing is 4.0 and 4.0 / 1.0 reaches the threshold, so the z axis
            # is replaced by the 20th percentile of [5.0, 4.0, 4.5, 2.0, 1.0], which is 1.8
            target_spacing = calculator.get_target_spacing(anisotropic_threshold=4.0, percentile=20.0)
            np.testing.assert_allclose(target_spacing, (1.0, 1.0, 1.8))


if __name__ == "__main__":
    unittest.main()
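
As a sanity check on the expectation in test_anisotropic_spacing, the value (1.0, 1.0, 1.8) can be reproduced from the pixdims list alone; this standalone snippet just mirrors the arithmetic inside get_target_spacing:

import numpy as np

spacings = np.array([[1.0, 1.0, 5.0], [1.0, 1.0, 4.0], [1.0, 1.0, 4.5], [1.0, 1.0, 2.0], [1.0, 1.0, 1.0]])
target = np.median(spacings, axis=0)  # per-axis median -> [1.0, 1.0, 4.0]
assert max(target) / min(target) >= 4.0  # 4.0 / 1.0 reaches the anisotropic_threshold exactly
largest = int(np.argmax(target))  # axis 2 (z) is the anisotropic axis
# the 20th percentile of [5.0, 4.0, 4.5, 2.0, 1.0] interpolates between 1.0 and 2.0 to give 1.8
target[largest] = np.percentile(spacings[:, largest], 20.0)
print(tuple(target))  # (1.0, 1.0, 1.8)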
