Skip to content

Commit fe6ac0a

Browse files
authored
AddChannel, AsChannelFirst, AsChannelLast, EnsureChannelFirst, Identity, RepeatChannel (#2840)
1 parent aa5fa1d commit fe6ac0a

15 files changed

+180
-89
lines changed

monai/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -518,3 +518,4 @@
518518
weighted_patch_samples,
519519
zero_margins,
520520
)
521+
from .utils_pytorch_numpy_unification import moveaxis

monai/transforms/utility/array.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
map_binary_to_indices,
3232
map_classes_to_indices,
3333
)
34+
from monai.transforms.utils_pytorch_numpy_unification import moveaxis
3435
from monai.utils import (
3536
convert_to_numpy,
3637
convert_to_tensor,
@@ -82,17 +83,18 @@
8283

8384
class Identity(Transform):
8485
"""
85-
Convert the input to an np.ndarray, if input data is np.ndarray or subclasses, return unchanged data.
86+
Do nothing to the data.
8687
As the output value is same as input, it can be used as a testing tool to verify the transform chain,
8788
Compose or transform adaptor, etc.
88-
8989
"""
9090

91-
def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
91+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
92+
93+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
9294
"""
9395
Apply the transform to `img`.
9496
"""
95-
return np.asanyarray(img)
97+
return img
9698

9799

98100
class AsChannelFirst(Transform):
@@ -111,16 +113,18 @@ class AsChannelFirst(Transform):
111113
channel_dim: which dimension of input image is the channel, default is the last dimension.
112114
"""
113115

116+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
117+
114118
def __init__(self, channel_dim: int = -1) -> None:
115119
if not (isinstance(channel_dim, int) and channel_dim >= -1):
116120
raise AssertionError("invalid channel dimension.")
117121
self.channel_dim = channel_dim
118122

119-
def __call__(self, img: np.ndarray) -> np.ndarray:
123+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
120124
"""
121125
Apply the transform to `img`.
122126
"""
123-
return np.moveaxis(img, self.channel_dim, 0)
127+
return moveaxis(img, self.channel_dim, 0)
124128

125129

126130
class AsChannelLast(Transform):
@@ -138,16 +142,18 @@ class AsChannelLast(Transform):
138142
channel_dim: which dimension of input image is the channel, default is the first dimension.
139143
"""
140144

145+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
146+
141147
def __init__(self, channel_dim: int = 0) -> None:
142148
if not (isinstance(channel_dim, int) and channel_dim >= -1):
143149
raise AssertionError("invalid channel dimension.")
144150
self.channel_dim = channel_dim
145151

146-
def __call__(self, img: np.ndarray) -> np.ndarray:
152+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
147153
"""
148154
Apply the transform to `img`.
149155
"""
150-
return np.moveaxis(img, self.channel_dim, -1)
156+
return moveaxis(img, self.channel_dim, -1)
151157

152158

153159
class AddChannel(Transform):
@@ -164,7 +170,9 @@ class AddChannel(Transform):
164170
transforms.
165171
"""
166172

167-
def __call__(self, img: NdarrayTensor):
173+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
174+
175+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
168176
"""
169177
Apply the transform to `img`.
170178
"""
@@ -179,14 +187,16 @@ class EnsureChannelFirst(Transform):
179187
Convert the data to `channel_first` based on the `original_channel_dim` information.
180188
"""
181189

190+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
191+
182192
def __init__(self, strict_check: bool = True):
183193
"""
184194
Args:
185195
strict_check: whether to raise an error when the meta information is insufficient.
186196
"""
187197
self.strict_check = strict_check
188198

189-
def __call__(self, img: np.ndarray, meta_dict: Optional[Mapping] = None):
199+
def __call__(self, img: NdarrayOrTensor, meta_dict: Optional[Mapping] = None) -> NdarrayOrTensor:
190200
"""
191201
Apply the transform to `img`.
192202
"""
@@ -220,16 +230,19 @@ class RepeatChannel(Transform):
220230
repeats: the number of repetitions for each element.
221231
"""
222232

233+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
234+
223235
def __init__(self, repeats: int) -> None:
224236
if repeats <= 0:
225237
raise AssertionError("repeats count must be greater than 0.")
226238
self.repeats = repeats
227239

228-
def __call__(self, img: np.ndarray) -> np.ndarray:
240+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
229241
"""
230242
Apply the transform to `img`, assuming `img` is a "channel-first" array.
231243
"""
232-
return np.repeat(img, self.repeats, 0)
244+
repeeat_fn = torch.repeat_interleave if isinstance(img, torch.Tensor) else np.repeat
245+
return repeeat_fn(img, self.repeats, 0) # type: ignore
233246

234247

235248
class RemoveRepeatedChannel(Transform):

monai/transforms/utility/dictionary.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,8 @@ class Identityd(MapTransform):
169169
Dictionary-based wrapper of :py:class:`monai.transforms.Identity`.
170170
"""
171171

172+
backend = Identity.backend
173+
172174
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
173175
"""
174176
Args:
@@ -180,9 +182,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
180182
super().__init__(keys, allow_missing_keys)
181183
self.identity = Identity()
182184

183-
def __call__(
184-
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
185-
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
185+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
186186
d = dict(data)
187187
for key in self.key_iterator(d):
188188
d[key] = self.identity(d[key])
@@ -194,6 +194,8 @@ class AsChannelFirstd(MapTransform):
194194
Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelFirst`.
195195
"""
196196

197+
backend = AsChannelFirst.backend
198+
197199
def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None:
198200
"""
199201
Args:
@@ -205,7 +207,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_ke
205207
super().__init__(keys, allow_missing_keys)
206208
self.converter = AsChannelFirst(channel_dim=channel_dim)
207209

208-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
210+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
209211
d = dict(data)
210212
for key in self.key_iterator(d):
211213
d[key] = self.converter(d[key])
@@ -217,6 +219,8 @@ class AsChannelLastd(MapTransform):
217219
Dictionary-based wrapper of :py:class:`monai.transforms.AsChannelLast`.
218220
"""
219221

222+
backend = AsChannelLast.backend
223+
220224
def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None:
221225
"""
222226
Args:
@@ -228,7 +232,7 @@ def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_key
228232
super().__init__(keys, allow_missing_keys)
229233
self.converter = AsChannelLast(channel_dim=channel_dim)
230234

231-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
235+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
232236
d = dict(data)
233237
for key in self.key_iterator(d):
234238
d[key] = self.converter(d[key])
@@ -240,6 +244,8 @@ class AddChanneld(MapTransform):
240244
Dictionary-based wrapper of :py:class:`monai.transforms.AddChannel`.
241245
"""
242246

247+
backend = AddChannel.backend
248+
243249
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
244250
"""
245251
Args:
@@ -250,7 +256,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
250256
super().__init__(keys, allow_missing_keys)
251257
self.adder = AddChannel()
252258

253-
def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]:
259+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
254260
d = dict(data)
255261
for key in self.key_iterator(d):
256262
d[key] = self.adder(d[key])
@@ -262,6 +268,8 @@ class EnsureChannelFirstd(MapTransform):
262268
Dictionary-based wrapper of :py:class:`monai.transforms.EnsureChannelFirst`.
263269
"""
264270

271+
backend = EnsureChannelFirst.backend
272+
265273
def __init__(
266274
self,
267275
keys: KeysCollection,
@@ -289,7 +297,7 @@ def __init__(
289297
self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
290298
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
291299

292-
def __call__(self, data) -> Dict[Hashable, np.ndarray]:
300+
def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]:
293301
d = dict(data)
294302
for key, meta_key, meta_key_postfix in zip(self.keys, self.meta_keys, self.meta_key_postfix):
295303
d[key] = self.adjuster(d[key], d[meta_key or f"{key}_{meta_key_postfix}"])
@@ -301,6 +309,8 @@ class RepeatChanneld(MapTransform):
301309
Dictionary-based wrapper of :py:class:`monai.transforms.RepeatChannel`.
302310
"""
303311

312+
backend = RepeatChannel.backend
313+
304314
def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None:
305315
"""
306316
Args:
@@ -312,7 +322,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool
312322
super().__init__(keys, allow_missing_keys)
313323
self.repeater = RepeatChannel(repeats)
314324

315-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
325+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
316326
d = dict(data)
317327
for key in self.key_iterator(d):
318328
d[key] = self.repeater(d[key])
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import numpy as np
13+
import torch
14+
15+
from monai.config.type_definitions import NdarrayOrTensor
16+
17+
__all__ = [
18+
"moveaxis",
19+
]
20+
21+
22+
def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor:
23+
if isinstance(x, torch.Tensor):
24+
if hasattr(torch, "moveaxis"):
25+
return torch.moveaxis(x, src, dst)
26+
# moveaxis only available in pytorch as of 1.8.0
27+
else:
28+
# get original indices
29+
indices = list(range(x.ndim))
30+
# make src and dst positive
31+
if src < 0:
32+
src = len(indices) + src
33+
if dst < 0:
34+
dst = len(indices) + dst
35+
# remove desired index and insert it in new position
36+
indices.pop(src)
37+
indices.insert(dst, src)
38+
return x.permute(indices)
39+
elif isinstance(x, np.ndarray):
40+
return np.moveaxis(x, src, dst)
41+
raise RuntimeError()

tests/test_add_channeld.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,21 @@
1515
from parameterized import parameterized
1616

1717
from monai.transforms import AddChanneld
18+
from tests.utils import TEST_NDARRAYS
1819

19-
TEST_CASE_1 = [
20-
{"keys": ["img", "seg"]},
21-
{"img": np.array([[0, 1], [1, 2]]), "seg": np.array([[0, 1], [1, 2]])},
22-
(1, 2, 2),
23-
]
20+
TESTS = []
21+
for p in TEST_NDARRAYS:
22+
TESTS.append(
23+
[
24+
{"keys": ["img", "seg"]},
25+
{"img": p(np.array([[0, 1], [1, 2]])), "seg": p(np.array([[0, 1], [1, 2]]))},
26+
(1, 2, 2),
27+
]
28+
)
2429

2530

2631
class TestAddChanneld(unittest.TestCase):
27-
@parameterized.expand([TEST_CASE_1])
32+
@parameterized.expand(TESTS)
2833
def test_shape(self, input_param, input_data, expected_shape):
2934
result = AddChanneld(**input_param)(input_data)
3035
self.assertEqual(result["img"].shape, expected_shape)

tests/test_as_channel_first.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,29 @@
1212
import unittest
1313

1414
import numpy as np
15+
import torch
1516
from parameterized import parameterized
1617

1718
from monai.transforms import AsChannelFirst
19+
from tests.utils import TEST_NDARRAYS, assert_allclose
1820

19-
TEST_CASE_1 = [{"channel_dim": -1}, (4, 1, 2, 3)]
20-
21-
TEST_CASE_2 = [{"channel_dim": 3}, (4, 1, 2, 3)]
22-
23-
TEST_CASE_3 = [{"channel_dim": 2}, (3, 1, 2, 4)]
21+
TESTS = []
22+
for p in TEST_NDARRAYS:
23+
TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)])
24+
TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)])
25+
TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)])
2426

2527

2628
class TestAsChannelFirst(unittest.TestCase):
27-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
28-
def test_shape(self, input_param, expected_shape):
29-
test_data = np.random.randint(0, 2, size=[1, 2, 3, 4])
29+
@parameterized.expand(TESTS)
30+
def test_value(self, in_type, input_param, expected_shape):
31+
test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4]))
3032
result = AsChannelFirst(**input_param)(test_data)
3133
self.assertTupleEqual(result.shape, expected_shape)
34+
if isinstance(test_data, torch.Tensor):
35+
test_data = test_data.cpu().numpy()
36+
expected = np.moveaxis(test_data, input_param["channel_dim"], 0)
37+
assert_allclose(expected, result)
3238

3339

3440
if __name__ == "__main__":

tests/test_as_channel_firstd.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,22 @@
1515
from parameterized import parameterized
1616

1717
from monai.transforms import AsChannelFirstd
18+
from tests.utils import TEST_NDARRAYS
1819

19-
TEST_CASE_1 = [{"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]
20-
21-
TEST_CASE_2 = [{"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]
22-
23-
TEST_CASE_3 = [{"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]
20+
TESTS = []
21+
for p in TEST_NDARRAYS:
22+
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)])
23+
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)])
24+
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)])
2425

2526

2627
class TestAsChannelFirstd(unittest.TestCase):
27-
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
28-
def test_shape(self, input_param, expected_shape):
28+
@parameterized.expand(TESTS)
29+
def test_shape(self, in_type, input_param, expected_shape):
2930
test_data = {
30-
"image": np.random.randint(0, 2, size=[1, 2, 3, 4]),
31-
"label": np.random.randint(0, 2, size=[1, 2, 3, 4]),
32-
"extra": np.random.randint(0, 2, size=[1, 2, 3, 4]),
31+
"image": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),
32+
"label": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),
33+
"extra": in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])),
3334
}
3435
result = AsChannelFirstd(**input_param)(test_data)
3536
self.assertTupleEqual(result["image"].shape, expected_shape)

0 commit comments

Comments
 (0)