Skip to content

Commit fa5bc15

Browse files
authored
Scale intensity (#2832)
* all close Signed-off-by: Richard Brown <[email protected]> * assert_allclose Signed-off-by: Richard Brown <[email protected]> * ScaleIntensity Signed-off-by: Richard Brown <[email protected]>
1 parent 38ecaef commit fa5bc15

File tree

9 files changed

+74
-54
lines changed

9 files changed

+74
-54
lines changed

monai/data/synthetic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def create_test_image_2d(
7676
labels = np.ceil(image).astype(np.int32)
7777

7878
norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape)
79-
noisyimage = rescale_array(np.maximum(image, norm))
79+
noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore
8080

8181
if channel_dim is not None:
8282
if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 2)):
@@ -151,7 +151,7 @@ def create_test_image_3d(
151151
labels = np.ceil(image).astype(np.int32)
152152

153153
norm = rs.uniform(0, num_seg_classes * noise_max, size=image.shape)
154-
noisyimage = rescale_array(np.maximum(image, norm))
154+
noisyimage: np.ndarray = rescale_array(np.maximum(image, norm)) # type: ignore
155155

156156
if channel_dim is not None:
157157
if not (isinstance(channel_dim, int) and channel_dim in (-1, 0, 3)):

monai/transforms/intensity/array.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,8 @@ class ScaleIntensity(Transform):
373373
If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.
374374
"""
375375

376+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
377+
376378
def __init__(
377379
self, minv: Optional[float] = 0.0, maxv: Optional[float] = 1.0, factor: Optional[float] = None
378380
) -> None:
@@ -387,7 +389,7 @@ def __init__(
387389
self.maxv = maxv
388390
self.factor = factor
389391

390-
def __call__(self, img: np.ndarray) -> np.ndarray:
392+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
391393
"""
392394
Apply the transform to `img`.
393395
@@ -396,9 +398,11 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
396398
397399
"""
398400
if self.minv is not None and self.maxv is not None:
399-
return np.asarray(rescale_array(img, self.minv, self.maxv, img.dtype))
401+
return rescale_array(img, self.minv, self.maxv, img.dtype)
400402
if self.factor is not None:
401-
return np.asarray(img * (1 + self.factor), dtype=img.dtype)
403+
out = img * (1 + self.factor)
404+
out, *_ = convert_data_type(out, dtype=img.dtype)
405+
return out
402406
raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.")
403407

404408

@@ -408,6 +412,8 @@ class RandScaleIntensity(RandomizableTransform):
408412
is randomly picked.
409413
"""
410414

415+
backend = ScaleIntensity.backend
416+
411417
def __init__(self, factors: Union[Tuple[float, float], float], prob: float = 0.1) -> None:
412418
"""
413419
Args:
@@ -429,7 +435,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
429435
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
430436
super().randomize(None)
431437

432-
def __call__(self, img: np.ndarray) -> np.ndarray:
438+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
433439
"""
434440
Apply the transform to `img`.
435441
"""

monai/transforms/intensity/dictionary.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,8 @@ class ScaleIntensityd(MapTransform):
472472
If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.
473473
"""
474474

475+
backend = ScaleIntensity.backend
476+
475477
def __init__(
476478
self,
477479
keys: KeysCollection,
@@ -494,7 +496,7 @@ def __init__(
494496
super().__init__(keys, allow_missing_keys)
495497
self.scaler = ScaleIntensity(minv, maxv, factor)
496498

497-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
499+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
498500
d = dict(data)
499501
for key in self.key_iterator(d):
500502
d[key] = self.scaler(d[key])
@@ -506,6 +508,8 @@ class RandScaleIntensityd(RandomizableTransform, MapTransform):
506508
Dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`.
507509
"""
508510

511+
backend = ScaleIntensity.backend
512+
509513
def __init__(
510514
self,
511515
keys: KeysCollection,
@@ -539,7 +543,7 @@ def randomize(self, data: Optional[Any] = None) -> None:
539543
self.factor = self.R.uniform(low=self.factors[0], high=self.factors[1])
540544
super().randomize(None)
541545

542-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
546+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
543547
d = dict(data)
544548
self.randomize()
545549
if not self._do_transform:

monai/transforms/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import monai
2323
import monai.transforms.transform
2424
from monai.config import DtypeLike, IndexSelection
25+
from monai.config.type_definitions import NdarrayOrTensor
2526
from monai.networks.layers import GaussianFilter
2627
from monai.transforms.compose import Compose, OneOf
2728
from monai.transforms.transform import MapTransform, Transform
@@ -37,6 +38,7 @@
3738
min_version,
3839
optional_import,
3940
)
41+
from monai.utils.type_conversion import convert_data_type
4042

4143
measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
4244
ndimage, _ = optional_import("scipy.ndimage")
@@ -130,15 +132,17 @@ def zero_margins(img: np.ndarray, margin: int) -> bool:
130132
return not np.any(img[:, :margin, :]) and not np.any(img[:, -margin:, :])
131133

132134

133-
def rescale_array(arr: np.ndarray, minv: float = 0.0, maxv: float = 1.0, dtype: DtypeLike = np.float32):
135+
def rescale_array(
136+
arr: NdarrayOrTensor, minv: float = 0.0, maxv: float = 1.0, dtype: Union[DtypeLike, torch.dtype] = np.float32
137+
) -> NdarrayOrTensor:
134138
"""
135139
Rescale the values of numpy array `arr` to be from `minv` to `maxv`.
136140
"""
137141
if dtype is not None:
138-
arr = arr.astype(dtype)
142+
arr, *_ = convert_data_type(arr, dtype=dtype)
139143

140-
mina = np.min(arr)
141-
maxa = np.max(arr)
144+
mina = arr.min()
145+
maxa = arr.max()
142146

143147
if mina == maxa:
144148
return arr * minv

monai/visualize/img2tensorboard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def plot_2d_or_3d_image(
188188
d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index
189189

190190
if d.ndim == 2:
191-
d = rescale_array(d, 0, 1)
191+
d = rescale_array(d, 0, 1) # type: ignore
192192
dataformats = "HW"
193193
writer.add_image(f"{tag}_{dataformats}", d, step, dataformats=dataformats)
194194
return

tests/test_rand_scale_intensity.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@
1414
import numpy as np
1515

1616
from monai.transforms import RandScaleIntensity
17-
from tests.utils import NumpyImageTestCase2D
17+
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
1818

1919

2020
class TestRandScaleIntensity(NumpyImageTestCase2D):
2121
def test_value(self):
22-
scaler = RandScaleIntensity(factors=0.5, prob=1.0)
23-
scaler.set_random_state(seed=0)
24-
result = scaler(self.imt)
25-
np.random.seed(0)
26-
expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)
27-
np.testing.assert_allclose(result, expected)
22+
for p in TEST_NDARRAYS:
23+
scaler = RandScaleIntensity(factors=0.5, prob=1.0)
24+
scaler.set_random_state(seed=0)
25+
result = scaler(p(self.imt))
26+
np.random.seed(0)
27+
expected = p((self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32))
28+
assert_allclose(result, expected, rtol=1e-7, atol=0)
2829

2930

3031
if __name__ == "__main__":

tests/test_rand_scale_intensityd.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@
1414
import numpy as np
1515

1616
from monai.transforms import RandScaleIntensityd
17-
from tests.utils import NumpyImageTestCase2D
17+
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
1818

1919

2020
class TestRandScaleIntensityd(NumpyImageTestCase2D):
2121
def test_value(self):
22-
key = "img"
23-
scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0)
24-
scaler.set_random_state(seed=0)
25-
result = scaler({key: self.imt})
26-
np.random.seed(0)
27-
expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)
28-
np.testing.assert_allclose(result[key], expected)
22+
for p in TEST_NDARRAYS:
23+
key = "img"
24+
scaler = RandScaleIntensityd(keys=[key], factors=0.5, prob=1.0)
25+
scaler.set_random_state(seed=0)
26+
result = scaler({key: p(self.imt)})
27+
np.random.seed(0)
28+
expected = (self.imt * (1 + np.random.uniform(low=-0.5, high=0.5))).astype(np.float32)
29+
assert_allclose(result[key], expected)
2930

3031

3132
if __name__ == "__main__":

tests/test_scale_intensity.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,26 @@
1414
import numpy as np
1515

1616
from monai.transforms import ScaleIntensity
17-
from tests.utils import NumpyImageTestCase2D
17+
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
1818

1919

2020
class TestScaleIntensity(NumpyImageTestCase2D):
2121
def test_range_scale(self):
22-
scaler = ScaleIntensity(minv=1.0, maxv=2.0)
23-
result = scaler(self.imt)
24-
mina = np.min(self.imt)
25-
maxa = np.max(self.imt)
26-
norm = (self.imt - mina) / (maxa - mina)
27-
expected = (norm * (2.0 - 1.0)) + 1.0
28-
np.testing.assert_allclose(result, expected)
22+
for p in TEST_NDARRAYS:
23+
scaler = ScaleIntensity(minv=1.0, maxv=2.0)
24+
result = scaler(p(self.imt))
25+
mina = self.imt.min()
26+
maxa = self.imt.max()
27+
norm = (self.imt - mina) / (maxa - mina)
28+
expected = p((norm * (2.0 - 1.0)) + 1.0)
29+
assert_allclose(result, expected, rtol=1e-7, atol=0)
2930

3031
def test_factor_scale(self):
31-
scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)
32-
result = scaler(self.imt)
33-
expected = (self.imt * (1 + 0.1)).astype(np.float32)
34-
np.testing.assert_allclose(result, expected)
32+
for p in TEST_NDARRAYS:
33+
scaler = ScaleIntensity(minv=None, maxv=None, factor=0.1)
34+
result = scaler(p(self.imt))
35+
expected = p((self.imt * (1 + 0.1)).astype(np.float32))
36+
assert_allclose(result, expected, rtol=1e-7, atol=0)
3537

3638

3739
if __name__ == "__main__":

tests/test_scale_intensityd.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,26 +14,28 @@
1414
import numpy as np
1515

1616
from monai.transforms import ScaleIntensityd
17-
from tests.utils import NumpyImageTestCase2D
17+
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
1818

1919

2020
class TestScaleIntensityd(NumpyImageTestCase2D):
2121
def test_range_scale(self):
22-
key = "img"
23-
scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0)
24-
result = scaler({key: self.imt})
25-
mina = np.min(self.imt)
26-
maxa = np.max(self.imt)
27-
norm = (self.imt - mina) / (maxa - mina)
28-
expected = (norm * (2.0 - 1.0)) + 1.0
29-
np.testing.assert_allclose(result[key], expected)
22+
for p in TEST_NDARRAYS:
23+
key = "img"
24+
scaler = ScaleIntensityd(keys=[key], minv=1.0, maxv=2.0)
25+
result = scaler({key: p(self.imt)})
26+
mina = np.min(self.imt)
27+
maxa = np.max(self.imt)
28+
norm = (self.imt - mina) / (maxa - mina)
29+
expected = (norm * (2.0 - 1.0)) + 1.0
30+
assert_allclose(result[key], expected)
3031

3132
def test_factor_scale(self):
32-
key = "img"
33-
scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1)
34-
result = scaler({key: self.imt})
35-
expected = (self.imt * (1 + 0.1)).astype(np.float32)
36-
np.testing.assert_allclose(result[key], expected)
33+
for p in TEST_NDARRAYS:
34+
key = "img"
35+
scaler = ScaleIntensityd(keys=[key], minv=None, maxv=None, factor=0.1)
36+
result = scaler({key: p(self.imt)})
37+
expected = (self.imt * (1 + 0.1)).astype(np.float32)
38+
assert_allclose(result[key], expected)
3739

3840

3941
if __name__ == "__main__":

0 commit comments

Comments
 (0)