Skip to content

Commit fe559e5

Browse files
authored
Normalize intensity (#2831)
* all close Signed-off-by: Richard Brown <[email protected]> * assert_allclose Signed-off-by: Richard Brown <[email protected]> * NormalizeIntensity Signed-off-by: Richard Brown <[email protected]>
1 parent b4def1a commit fe559e5

File tree

4 files changed

+189
-95
lines changed

4 files changed

+189
-95
lines changed

monai/transforms/intensity/array.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -539,10 +539,12 @@ class NormalizeIntensity(Transform):
539539
dtype: output data type, defaults to float32.
540540
"""
541541

542+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
543+
542544
def __init__(
543545
self,
544-
subtrahend: Union[Sequence, np.ndarray, None] = None,
545-
divisor: Union[Sequence, np.ndarray, None] = None,
546+
subtrahend: Union[Sequence, NdarrayOrTensor, None] = None,
547+
divisor: Union[Sequence, NdarrayOrTensor, None] = None,
546548
nonzero: bool = False,
547549
channel_wise: bool = False,
548550
dtype: DtypeLike = np.float32,
@@ -553,26 +555,51 @@ def __init__(
553555
self.channel_wise = channel_wise
554556
self.dtype = dtype
555557

556-
def _normalize(self, img: np.ndarray, sub=None, div=None) -> np.ndarray:
557-
slices = (img != 0) if self.nonzero else np.ones(img.shape, dtype=bool)
558-
if not np.any(slices):
558+
@staticmethod
559+
def _mean(x):
560+
if isinstance(x, np.ndarray):
561+
return np.mean(x)
562+
x = torch.mean(x.float())
563+
return x.item() if x.numel() == 1 else x
564+
565+
@staticmethod
566+
def _std(x):
567+
if isinstance(x, np.ndarray):
568+
return np.std(x)
569+
x = torch.std(x.float(), unbiased=False)
570+
return x.item() if x.numel() == 1 else x
571+
572+
def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> NdarrayOrTensor:
573+
img, *_ = convert_data_type(img, dtype=torch.float32)
574+
575+
if self.nonzero:
576+
slices = img != 0
577+
else:
578+
if isinstance(img, np.ndarray):
579+
slices = np.ones_like(img, dtype=bool)
580+
else:
581+
slices = torch.ones_like(img, dtype=torch.bool)
582+
if not slices.any():
559583
return img
560584

561-
_sub = sub if sub is not None else np.mean(img[slices])
562-
if isinstance(_sub, np.ndarray):
585+
_sub = sub if sub is not None else self._mean(img[slices])
586+
if isinstance(_sub, (torch.Tensor, np.ndarray)):
587+
_sub, *_ = convert_to_dst_type(_sub, img)
563588
_sub = _sub[slices]
564589

565-
_div = div if div is not None else np.std(img[slices])
590+
_div = div if div is not None else self._std(img[slices])
566591
if np.isscalar(_div):
567592
if _div == 0.0:
568593
_div = 1.0
569-
elif isinstance(_div, np.ndarray):
594+
elif isinstance(_div, (torch.Tensor, np.ndarray)):
595+
_div, *_ = convert_to_dst_type(_div, img)
570596
_div = _div[slices]
571597
_div[_div == 0.0] = 1.0
598+
572599
img[slices] = (img[slices] - _sub) / _div
573600
return img
574601

575-
def __call__(self, img: np.ndarray) -> np.ndarray:
602+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
576603
"""
577604
Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True,
578605
"""
@@ -583,15 +610,16 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
583610
raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")
584611

585612
for i, d in enumerate(img):
586-
img[i] = self._normalize(
613+
img[i] = self._normalize( # type: ignore
587614
d,
588615
sub=self.subtrahend[i] if self.subtrahend is not None else None,
589616
div=self.divisor[i] if self.divisor is not None else None,
590617
)
591618
else:
592619
img = self._normalize(img, self.subtrahend, self.divisor)
593620

594-
return img.astype(self.dtype)
621+
out, *_ = convert_data_type(img, dtype=self.dtype)
622+
return out
595623

596624

597625
class ThresholdIntensity(Transform):

monai/transforms/intensity/dictionary.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -612,11 +612,13 @@ class NormalizeIntensityd(MapTransform):
612612
allow_missing_keys: don't raise exception if key is missing.
613613
"""
614614

615+
backend = NormalizeIntensity.backend
616+
615617
def __init__(
616618
self,
617619
keys: KeysCollection,
618-
subtrahend: Optional[np.ndarray] = None,
619-
divisor: Optional[np.ndarray] = None,
620+
subtrahend: Optional[NdarrayOrTensor] = None,
621+
divisor: Optional[NdarrayOrTensor] = None,
620622
nonzero: bool = False,
621623
channel_wise: bool = False,
622624
dtype: DtypeLike = np.float32,
@@ -625,7 +627,7 @@ def __init__(
625627
super().__init__(keys, allow_missing_keys)
626628
self.normalizer = NormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype)
627629

628-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
630+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
629631
d = dict(data)
630632
for key in self.key_iterator(d):
631633
d[key] = self.normalizer(d[key])

tests/test_normalize_intensity.py

Lines changed: 90 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,70 +12,111 @@
1212
import unittest
1313

1414
import numpy as np
15+
import torch
1516
from parameterized import parameterized
1617

1718
from monai.transforms import NormalizeIntensity
18-
from tests.utils import NumpyImageTestCase2D
19+
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
1920

20-
TEST_CASES = [
21-
[{"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])],
22-
[
23-
{"subtrahend": np.array([3.5, 3.5, 3.5, 3.5]), "divisor": np.array([0.5, 0.5, 0.5, 0.5]), "nonzero": True},
24-
np.array([0.0, 3.0, 0.0, 4.0]),
25-
np.array([0.0, -1.0, 0.0, 1.0]),
26-
],
27-
[{"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])],
28-
[{"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])],
29-
[{"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])],
30-
[
31-
{"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]},
32-
np.ones((3, 2, 2)),
33-
np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]),
34-
],
35-
[
36-
{"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]},
37-
np.ones((3, 2, 2)),
38-
np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]),
39-
],
40-
[
41-
{"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0},
42-
np.ones((3, 2, 2)),
43-
np.ones((3, 2, 2)) * -1.0,
44-
],
45-
[
46-
{"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0},
47-
np.ones((3, 2, 2)),
48-
np.ones((3, 2, 2)) * 0.5,
49-
],
50-
[
51-
{"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]},
52-
np.ones((3, 2, 2)),
53-
np.ones((3, 2, 2)) * 0.5,
54-
],
55-
]
21+
TESTS = []
22+
for p in TEST_NDARRAYS:
23+
TESTS.append([p, {"nonzero": True}, np.array([0.0, 3.0, 0.0, 4.0]), np.array([0.0, -1.0, 0.0, 1.0])])
24+
for q in TEST_NDARRAYS:
25+
for u in TEST_NDARRAYS:
26+
TESTS.append(
27+
[
28+
p,
29+
{
30+
"subtrahend": q(np.array([3.5, 3.5, 3.5, 3.5])),
31+
"divisor": u(np.array([0.5, 0.5, 0.5, 0.5])),
32+
"nonzero": True,
33+
},
34+
np.array([0.0, 3.0, 0.0, 4.0]),
35+
np.array([0.0, -1.0, 0.0, 1.0]),
36+
]
37+
)
38+
TESTS.append([p, {"nonzero": True}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])])
39+
TESTS.append([p, {"nonzero": False}, np.array([0.0, 0.0, 0.0, 0.0]), np.array([0.0, 0.0, 0.0, 0.0])])
40+
TESTS.append([p, {"nonzero": False}, np.array([1, 1, 1, 1]), np.array([0.0, 0.0, 0.0, 0.0])])
41+
TESTS.append(
42+
[
43+
p,
44+
{"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]},
45+
np.ones((3, 2, 2)),
46+
np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]]),
47+
]
48+
)
49+
TESTS.append(
50+
[
51+
p,
52+
{"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]},
53+
np.ones((3, 2, 2)),
54+
np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]]),
55+
]
56+
)
57+
TESTS.append(
58+
[
59+
p,
60+
{"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0},
61+
np.ones((3, 2, 2)),
62+
np.ones((3, 2, 2)) * -1.0,
63+
]
64+
)
65+
TESTS.append(
66+
[
67+
p,
68+
{"nonzero": True, "channel_wise": False, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": 0},
69+
np.ones((3, 2, 2)),
70+
np.ones((3, 2, 2)) * 0.5,
71+
]
72+
)
73+
TESTS.append(
74+
[
75+
p,
76+
{"nonzero": True, "channel_wise": True, "subtrahend": np.ones((3, 2, 2)) * 0.5, "divisor": [0, 1, 0]},
77+
np.ones((3, 2, 2)),
78+
np.ones((3, 2, 2)) * 0.5,
79+
]
80+
)
5681

5782

5883
class TestNormalizeIntensity(NumpyImageTestCase2D):
59-
def test_default(self):
84+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
85+
def test_default(self, im_type):
86+
im = im_type(self.imt.copy())
6087
normalizer = NormalizeIntensity()
61-
normalized = normalizer(self.imt.copy())
62-
self.assertTrue(normalized.dtype == np.float32)
88+
normalized = normalizer(im)
89+
self.assertEqual(type(im), type(normalized))
90+
if isinstance(normalized, torch.Tensor):
91+
self.assertEqual(im.device, normalized.device)
92+
self.assertTrue(normalized.dtype in (np.float32, torch.float32))
6393
expected = (self.imt - np.mean(self.imt)) / np.std(self.imt)
64-
np.testing.assert_allclose(normalized, expected, rtol=1e-3)
94+
assert_allclose(expected, normalized, rtol=1e-3)
6595

66-
@parameterized.expand(TEST_CASES)
67-
def test_nonzero(self, input_param, input_data, expected_data):
96+
@parameterized.expand(TESTS)
97+
def test_nonzero(self, in_type, input_param, input_data, expected_data):
6898
normalizer = NormalizeIntensity(**input_param)
69-
np.testing.assert_allclose(expected_data, normalizer(input_data))
99+
im = in_type(input_data)
100+
normalized = normalizer(im)
101+
self.assertEqual(type(im), type(normalized))
102+
if isinstance(normalized, torch.Tensor):
103+
self.assertEqual(im.device, normalized.device)
104+
assert_allclose(expected_data, normalized)
70105

71-
def test_channel_wise(self):
106+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
107+
def test_channel_wise(self, im_type):
72108
normalizer = NormalizeIntensity(nonzero=True, channel_wise=True)
73-
input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])
109+
input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))
74110
expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]])
75-
np.testing.assert_allclose(expected, normalizer(input_data))
111+
normalized = normalizer(input_data)
112+
self.assertEqual(type(input_data), type(normalized))
113+
if isinstance(normalized, torch.Tensor):
114+
self.assertEqual(input_data.device, normalized.device)
115+
assert_allclose(expected, normalized)
76116

77-
def test_value_errors(self):
78-
input_data = np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])
117+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
118+
def test_value_errors(self, im_type):
119+
input_data = im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))
79120
normalizer = NormalizeIntensity(nonzero=True, channel_wise=True, subtrahend=[1])
80121
with self.assertRaises(ValueError):
81122
normalizer(input_data)

tests/test_normalize_intensityd.py

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,54 +12,77 @@
1212
import unittest
1313

1414
import numpy as np
15+
import torch
1516
from parameterized import parameterized
1617

1718
from monai.transforms import NormalizeIntensityd
18-
from tests.utils import NumpyImageTestCase2D
19+
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
1920

20-
TEST_CASE_1 = [
21-
{"keys": ["img"], "nonzero": True},
22-
{"img": np.array([0.0, 3.0, 0.0, 4.0])},
23-
np.array([0.0, -1.0, 0.0, 1.0]),
24-
]
25-
26-
TEST_CASE_2 = [
27-
{
28-
"keys": ["img"],
29-
"subtrahend": np.array([3.5, 3.5, 3.5, 3.5]),
30-
"divisor": np.array([0.5, 0.5, 0.5, 0.5]),
31-
"nonzero": True,
32-
},
33-
{"img": np.array([0.0, 3.0, 0.0, 4.0])},
34-
np.array([0.0, -1.0, 0.0, 1.0]),
35-
]
36-
37-
TEST_CASE_3 = [
38-
{"keys": ["img"], "nonzero": True},
39-
{"img": np.array([0.0, 0.0, 0.0, 0.0])},
40-
np.array([0.0, 0.0, 0.0, 0.0]),
41-
]
21+
TESTS = []
22+
for p in TEST_NDARRAYS:
23+
for q in TEST_NDARRAYS:
24+
TESTS.append(
25+
[
26+
{"keys": ["img"], "nonzero": True},
27+
{"img": p(np.array([0.0, 3.0, 0.0, 4.0]))},
28+
np.array([0.0, -1.0, 0.0, 1.0]),
29+
]
30+
)
31+
TESTS.append(
32+
[
33+
{
34+
"keys": ["img"],
35+
"subtrahend": q(np.array([3.5, 3.5, 3.5, 3.5])),
36+
"divisor": q(np.array([0.5, 0.5, 0.5, 0.5])),
37+
"nonzero": True,
38+
},
39+
{"img": p(np.array([0.0, 3.0, 0.0, 4.0]))},
40+
np.array([0.0, -1.0, 0.0, 1.0]),
41+
]
42+
)
43+
TESTS.append(
44+
[
45+
{"keys": ["img"], "nonzero": True},
46+
{"img": p(np.array([0.0, 0.0, 0.0, 0.0]))},
47+
np.array([0.0, 0.0, 0.0, 0.0]),
48+
]
49+
)
4250

4351

4452
class TestNormalizeIntensityd(NumpyImageTestCase2D):
45-
def test_image_normalize_intensityd(self):
53+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
54+
def test_image_normalize_intensityd(self, im_type):
4655
key = "img"
56+
im = im_type(self.imt)
4757
normalizer = NormalizeIntensityd(keys=[key])
48-
normalized = normalizer({key: self.imt})
58+
normalized = normalizer({key: im})[key]
4959
expected = (self.imt - np.mean(self.imt)) / np.std(self.imt)
50-
np.testing.assert_allclose(normalized[key], expected, rtol=1e-3)
60+
self.assertEqual(type(im), type(normalized))
61+
if isinstance(normalized, torch.Tensor):
62+
self.assertEqual(im.device, normalized.device)
63+
assert_allclose(normalized, expected, rtol=1e-3)
5164

52-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
65+
@parameterized.expand(TESTS)
5366
def test_nonzero(self, input_param, input_data, expected_data):
67+
key = "img"
5468
normalizer = NormalizeIntensityd(**input_param)
55-
np.testing.assert_allclose(expected_data, normalizer(input_data)["img"])
69+
normalized = normalizer(input_data)[key]
70+
self.assertEqual(type(input_data[key]), type(normalized))
71+
if isinstance(normalized, torch.Tensor):
72+
self.assertEqual(input_data[key].device, normalized.device)
73+
assert_allclose(normalized, expected_data)
5674

57-
def test_channel_wise(self):
75+
@parameterized.expand([[p] for p in TEST_NDARRAYS])
76+
def test_channel_wise(self, im_type):
5877
key = "img"
5978
normalizer = NormalizeIntensityd(keys=key, nonzero=True, channel_wise=True)
60-
input_data = {key: np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]])}
79+
input_data = {key: im_type(np.array([[0.0, 3.0, 0.0, 4.0], [0.0, 4.0, 0.0, 5.0]]))}
80+
normalized = normalizer(input_data)[key]
81+
self.assertEqual(type(input_data[key]), type(normalized))
82+
if isinstance(normalized, torch.Tensor):
83+
self.assertEqual(input_data[key].device, normalized.device)
6184
expected = np.array([[0.0, -1.0, 0.0, 1.0], [0.0, -1.0, 0.0, 1.0]])
62-
np.testing.assert_allclose(expected, normalizer(input_data)[key])
85+
assert_allclose(normalized, expected)
6386

6487

6588
if __name__ == "__main__":

0 commit comments

Comments
 (0)