Skip to content

Commit c99bd41

Browse files
authored
EnsureType, RemoveRepeatedChannel, SplitChannel, ToCupy, ToNumpy, ToPil, ToTensor, Transpose (#2850)
* backends -> backend Signed-off-by: Richard Brown <[email protected]> * code format Signed-off-by: Richard Brown <[email protected]> * code format2 Signed-off-by: Richard Brown <[email protected]> * AddChannel, AsChannelFirst, AsChannelLast, EnsureChannelFirst, Identity, RepeatChannel Signed-off-by: Richard Brown <[email protected]> * moveaxis backwards compatible Signed-off-by: Richard Brown <[email protected]> * code format Signed-off-by: Richard Brown <[email protected]> * EnsureType, RemoveRepeatedChannel, SplitChannel, ToCupy, ToNumpy, ToPil, ToTensor, Transpose Signed-off-by: Richard Brown <[email protected]> * trigger ci/cd Signed-off-by: Richard Brown <[email protected]> * permute requires positive indices Signed-off-by: Richard Brown <[email protected]> * correct permute Signed-off-by: Richard Brown <[email protected]> * correct permute Signed-off-by: Richard Brown <[email protected]> * has_pil Signed-off-by: Richard Brown <[email protected]>
1 parent fe6ac0a commit c99bd41

18 files changed

+317
-213
lines changed

monai/transforms/utility/array.py

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,7 @@
3232
map_classes_to_indices,
3333
)
3434
from monai.transforms.utils_pytorch_numpy_unification import moveaxis
35-
from monai.utils import (
36-
convert_to_numpy,
37-
convert_to_tensor,
38-
ensure_tuple,
39-
issequenceiterable,
40-
look_up_option,
41-
min_version,
42-
optional_import,
43-
)
35+
from monai.utils import convert_to_numpy, convert_to_tensor, ensure_tuple, look_up_option, min_version, optional_import
4436
from monai.utils.enums import TransformBackends
4537
from monai.utils.type_conversion import convert_data_type
4638

@@ -255,20 +247,22 @@ class RemoveRepeatedChannel(Transform):
255247
repeats: the number of repetitions to be deleted for each element.
256248
"""
257249

250+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
251+
258252
def __init__(self, repeats: int) -> None:
259253
if repeats <= 0:
260254
raise AssertionError("repeats count must be greater than 0.")
261255

262256
self.repeats = repeats
263257

264-
def __call__(self, img: np.ndarray) -> np.ndarray:
258+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
265259
"""
266260
Apply the transform to `img`, assuming `img` is a "channel-first" array.
267261
"""
268-
if np.shape(img)[0] < 2:
262+
if img.shape[0] < 2:
269263
raise AssertionError("Image must have more than one channel")
270264

271-
return np.array(img[:: self.repeats, :])
265+
return img[:: self.repeats, :]
272266

273267

274268
class SplitChannel(Transform):
@@ -281,10 +275,12 @@ class SplitChannel(Transform):
281275
282276
"""
283277

278+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
279+
284280
def __init__(self, channel_dim: int = 0) -> None:
285281
self.channel_dim = channel_dim
286282

287-
def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> List[Union[np.ndarray, torch.Tensor]]:
283+
def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]:
288284
n_classes = img.shape[self.channel_dim]
289285
if n_classes <= 1:
290286
raise RuntimeError("input image does not contain multiple channels.")
@@ -335,18 +331,13 @@ class ToTensor(Transform):
335331
Converts the input image to a tensor without applying any other transformations.
336332
"""
337333

338-
def __call__(self, img) -> torch.Tensor:
334+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
335+
336+
def __call__(self, img: NdarrayOrTensor) -> torch.Tensor:
339337
"""
340338
Apply the transform to `img` and make it contiguous.
341339
"""
342-
if isinstance(img, torch.Tensor):
343-
return img.contiguous()
344-
if issequenceiterable(img):
345-
# numpy array with 0 dims is also sequence iterable
346-
if not (isinstance(img, np.ndarray) and img.ndim == 0):
347-
# `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims
348-
img = np.ascontiguousarray(img)
349-
return torch.as_tensor(img)
340+
return convert_to_tensor(img, wrap_sequence=True) # type: ignore
350341

351342

352343
class EnsureType(Transform):
@@ -361,14 +352,16 @@ class EnsureType(Transform):
361352
362353
"""
363354

355+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
356+
364357
def __init__(self, data_type: str = "tensor") -> None:
365358
data_type = data_type.lower()
366359
if data_type not in ("tensor", "numpy"):
367360
raise ValueError("`data type` must be 'tensor' or 'numpy'.")
368361

369362
self.data_type = data_type
370363

371-
def __call__(self, data):
364+
def __call__(self, data: NdarrayOrTensor) -> NdarrayOrTensor:
372365
"""
373366
Args:
374367
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
@@ -377,46 +370,46 @@ def __call__(self, data):
377370
if applicable.
378371
379372
"""
380-
return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data)
373+
return convert_to_tensor(data) if self.data_type == "tensor" else convert_to_numpy(data) # type: ignore
381374

382375

383376
class ToNumpy(Transform):
384377
"""
385378
Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor.
386379
"""
387380

388-
def __call__(self, img) -> np.ndarray:
381+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
382+
383+
def __call__(self, img: NdarrayOrTensor) -> np.ndarray:
389384
"""
390385
Apply the transform to `img` and make it contiguous.
391386
"""
392-
if isinstance(img, torch.Tensor):
393-
img = img.detach().cpu().numpy()
394-
elif has_cp and isinstance(img, cp_ndarray):
395-
img = cp.asnumpy(img)
396-
397-
array: np.ndarray = np.asarray(img)
398-
return np.ascontiguousarray(array) if array.ndim > 0 else array
387+
return convert_to_numpy(img) # type: ignore
399388

400389

401390
class ToCupy(Transform):
402391
"""
403392
Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor.
404393
"""
405394

406-
def __call__(self, img):
395+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
396+
397+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
407398
"""
408399
Apply the transform to `img` and make it contiguous.
409400
"""
410401
if isinstance(img, torch.Tensor):
411402
img = img.detach().cpu().numpy()
412-
return cp.ascontiguousarray(cp.asarray(img))
403+
return cp.ascontiguousarray(cp.asarray(img)) # type: ignore
413404

414405

415406
class ToPIL(Transform):
416407
"""
417408
Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image
418409
"""
419410

411+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
412+
420413
def __call__(self, img):
421414
"""
422415
Apply the transform to `img`.
@@ -433,13 +426,17 @@ class Transpose(Transform):
433426
Transposes the input image based on the given `indices` dimension ordering.
434427
"""
435428

429+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
430+
436431
def __init__(self, indices: Optional[Sequence[int]]) -> None:
437432
self.indices = None if indices is None else tuple(indices)
438433

439-
def __call__(self, img: np.ndarray) -> np.ndarray:
434+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
440435
"""
441436
Apply the transform to `img`.
442437
"""
438+
if isinstance(img, torch.Tensor):
439+
return img.permute(self.indices or tuple(range(img.ndim)[::-1]))
443440
return img.transpose(self.indices) # type: ignore
444441

445442

monai/transforms/utility/dictionary.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,8 @@ class RemoveRepeatedChanneld(MapTransform):
334334
Dictionary-based wrapper of :py:class:`monai.transforms.RemoveRepeatedChannel`.
335335
"""
336336

337+
backend = RemoveRepeatedChannel.backend
338+
337339
def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool = False) -> None:
338340
"""
339341
Args:
@@ -345,7 +347,7 @@ def __init__(self, keys: KeysCollection, repeats: int, allow_missing_keys: bool
345347
super().__init__(keys, allow_missing_keys)
346348
self.repeater = RemoveRepeatedChannel(repeats)
347349

348-
def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
350+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
349351
d = dict(data)
350352
for key in self.key_iterator(d):
351353
d[key] = self.repeater(d[key])
@@ -356,9 +358,10 @@ class SplitChanneld(MapTransform):
356358
"""
357359
Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`.
358360
All the input specified by `keys` should be split into same count of data.
359-
360361
"""
361362

363+
backend = SplitChannel.backend
364+
362365
def __init__(
363366
self,
364367
keys: KeysCollection,
@@ -382,9 +385,7 @@ def __init__(
382385
self.output_postfixes = output_postfixes
383386
self.splitter = SplitChannel(channel_dim=channel_dim)
384387

385-
def __call__(
386-
self, data: Mapping[Hashable, Union[np.ndarray, torch.Tensor]]
387-
) -> Dict[Hashable, Union[np.ndarray, torch.Tensor]]:
388+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
388389
d = dict(data)
389390
for key in self.key_iterator(d):
390391
rets = self.splitter(d[key])
@@ -439,6 +440,8 @@ class ToTensord(MapTransform, InvertibleTransform):
439440
Dictionary-based wrapper of :py:class:`monai.transforms.ToTensor`.
440441
"""
441442

443+
backend = ToTensor.backend
444+
442445
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
443446
"""
444447
Args:
@@ -449,14 +452,14 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
449452
super().__init__(keys, allow_missing_keys)
450453
self.converter = ToTensor()
451454

452-
def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
455+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
453456
d = dict(data)
454457
for key in self.key_iterator(d):
455458
self.push_transform(d, key)
456459
d[key] = self.converter(d[key])
457460
return d
458461

459-
def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
462+
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
460463
d = deepcopy(dict(data))
461464
for key in self.key_iterator(d):
462465
# Create inverse transform
@@ -481,6 +484,8 @@ class EnsureTyped(MapTransform, InvertibleTransform):
481484
482485
"""
483486

487+
backend = EnsureType.backend
488+
484489
def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missing_keys: bool = False) -> None:
485490
"""
486491
Args:
@@ -492,7 +497,7 @@ def __init__(self, keys: KeysCollection, data_type: str = "tensor", allow_missin
492497
super().__init__(keys, allow_missing_keys)
493498
self.converter = EnsureType(data_type=data_type)
494499

495-
def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
500+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
496501
d = dict(data)
497502
for key in self.key_iterator(d):
498503
self.push_transform(d, key)
@@ -515,6 +520,8 @@ class ToNumpyd(MapTransform):
515520
Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`.
516521
"""
517522

523+
backend = ToNumpy.backend
524+
518525
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
519526
"""
520527
Args:
@@ -537,6 +544,8 @@ class ToCupyd(MapTransform):
537544
Dictionary-based wrapper of :py:class:`monai.transforms.ToCupy`.
538545
"""
539546

547+
backend = ToCupy.backend
548+
540549
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
541550
"""
542551
Args:
@@ -547,7 +556,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
547556
super().__init__(keys, allow_missing_keys)
548557
self.converter = ToCupy()
549558

550-
def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
559+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
551560
d = dict(data)
552561
for key in self.key_iterator(d):
553562
d[key] = self.converter(d[key])
@@ -559,6 +568,8 @@ class ToPILd(MapTransform):
559568
Dictionary-based wrapper of :py:class:`monai.transforms.ToNumpy`.
560569
"""
561570

571+
backend = ToPIL.backend
572+
562573
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
563574
"""
564575
Args:
@@ -581,13 +592,15 @@ class Transposed(MapTransform, InvertibleTransform):
581592
Dictionary-based wrapper of :py:class:`monai.transforms.Transpose`.
582593
"""
583594

595+
backend = Transpose.backend
596+
584597
def __init__(
585598
self, keys: KeysCollection, indices: Optional[Sequence[int]], allow_missing_keys: bool = False
586599
) -> None:
587600
super().__init__(keys, allow_missing_keys)
588601
self.transform = Transpose(indices)
589602

590-
def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
603+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
591604
d = dict(data)
592605
for key in self.key_iterator(d):
593606
d[key] = self.transform(d[key])

monai/utils/type_conversion.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def get_dtype(data: Any):
8383
return type(data)
8484

8585

86-
def convert_to_tensor(data):
86+
def convert_to_tensor(data, wrap_sequence: bool = False):
8787
"""
8888
Utility to convert the input data to a PyTorch Tensor. If passing a dictionary, list or tuple,
8989
recursively check every item and convert it to PyTorch Tensor.
@@ -92,6 +92,8 @@ def convert_to_tensor(data):
9292
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
9393
will convert Tensor, Numpy array, float, int, bool to Tensors, strings and objects keep the original.
9494
for dictionary, list or tuple, convert every item to a Tensor if applicable.
95+
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`.
96+
If `True`, then `[1, 2]` -> `tensor([1, 2])`.
9597
9698
"""
9799
if isinstance(data, torch.Tensor):
@@ -105,17 +107,19 @@ def convert_to_tensor(data):
105107
return torch.as_tensor(data if data.ndim == 0 else np.ascontiguousarray(data))
106108
elif isinstance(data, (float, int, bool)):
107109
return torch.as_tensor(data)
108-
elif isinstance(data, dict):
109-
return {k: convert_to_tensor(v) for k, v in data.items()}
110+
elif isinstance(data, Sequence) and wrap_sequence:
111+
return torch.as_tensor(data)
110112
elif isinstance(data, list):
111113
return [convert_to_tensor(i) for i in data]
112114
elif isinstance(data, tuple):
113115
return tuple(convert_to_tensor(i) for i in data)
116+
elif isinstance(data, dict):
117+
return {k: convert_to_tensor(v) for k, v in data.items()}
114118

115119
return data
116120

117121

118-
def convert_to_numpy(data):
122+
def convert_to_numpy(data, wrap_sequence: bool = False):
119123
"""
120124
Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple,
121125
recursively check every item and convert it to numpy array.
@@ -124,20 +128,23 @@ def convert_to_numpy(data):
124128
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
125129
will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original.
126130
for dictionary, list or tuple, convert every item to a numpy array if applicable.
127-
131+
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
132+
If `True`, then `[1, 2]` -> `array([1, 2])`.
128133
"""
129134
if isinstance(data, torch.Tensor):
130135
data = data.detach().cpu().numpy()
131136
elif has_cp and isinstance(data, cp_ndarray):
132137
data = cp.asnumpy(data)
133138
elif isinstance(data, (float, int, bool)):
134139
data = np.asarray(data)
135-
elif isinstance(data, dict):
136-
return {k: convert_to_numpy(v) for k, v in data.items()}
140+
elif isinstance(data, Sequence) and wrap_sequence:
141+
return np.asarray(data)
137142
elif isinstance(data, list):
138143
return [convert_to_numpy(i) for i in data]
139144
elif isinstance(data, tuple):
140145
return tuple(convert_to_numpy(i) for i in data)
146+
elif isinstance(data, dict):
147+
return {k: convert_to_numpy(v) for k, v in data.items()}
141148

142149
if isinstance(data, np.ndarray) and data.ndim > 0:
143150
data = np.ascontiguousarray(data)

0 commit comments

Comments
 (0)