Skip to content

Commit acf4a9f

Browse files
authored
backends -> backend (#2838)
* backends -> backend Signed-off-by: Richard Brown <[email protected]> * code format Signed-off-by: Richard Brown <[email protected]> * code format2 Signed-off-by: Richard Brown <[email protected]>
1 parent fa5bc15 commit acf4a9f

File tree

5 files changed

+76
-41
lines changed

5 files changed

+76
-41
lines changed

monai/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@
500500
get_extreme_points,
501501
get_largest_connected_component_mask,
502502
get_number_image_type_conversions,
503+
get_transform_backends,
503504
img_bounds,
504505
in_bounds,
505506
is_empty,

monai/transforms/intensity/array.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class RandRicianNoise(RandomizableTransform):
131131
uniformly from 0 to std.
132132
"""
133133

134-
backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
134+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
135135

136136
def __init__(
137137
self,
@@ -197,7 +197,7 @@ class ShiftIntensity(Transform):
197197
offset: offset value to shift the intensity of image.
198198
"""
199199

200-
backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
200+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
201201

202202
def __init__(self, offset: float) -> None:
203203
self.offset = offset
@@ -219,7 +219,7 @@ class RandShiftIntensity(RandomizableTransform):
219219
Randomly shift intensity with randomly picked offset.
220220
"""
221221

222-
backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
222+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
223223

224224
def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1) -> None:
225225
"""
@@ -273,7 +273,7 @@ class StdShiftIntensity(Transform):
273273
dtype: output data type, defaults to float32.
274274
"""
275275

276-
backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
276+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
277277

278278
def __init__(
279279
self, factor: float, nonzero: bool = False, channel_wise: bool = False, dtype: DtypeLike = np.float32
@@ -318,7 +318,7 @@ class RandStdShiftIntensity(RandomizableTransform):
318318
by: ``v = v + factor * std(v)`` where the `factor` is randomly picked.
319319
"""
320320

321-
backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
321+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
322322

323323
def __init__(
324324
self,

monai/transforms/utility/array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ class CastToType(Transform):
291291
specified PyTorch data type.
292292
"""
293293

294-
backends = [TransformBackends.TORCH, TransformBackends.NUMPY]
294+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
295295

296296
def __init__(self, dtype=np.float32) -> None:
297297
"""

monai/transforms/utils.py

Lines changed: 63 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
min_version,
3939
optional_import,
4040
)
41+
from monai.utils.enums import TransformBackends
4142
from monai.utils.type_conversion import convert_data_type
4243

4344
measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
@@ -81,6 +82,7 @@
8182
"zero_margins",
8283
"equalize_hist",
8384
"get_number_image_type_conversions",
85+
"get_transform_backends",
8486
"print_transform_backends",
8587
]
8688

@@ -1158,22 +1160,17 @@ def _get_data(obj, key):
11581160
return num_conversions
11591161

11601162

1161-
def print_transform_backends():
1162-
"""Prints a list of backends of all MONAI transforms."""
1163-
1164-
class Colours:
1165-
red = "91"
1166-
green = "92"
1167-
yellow = "93"
1168-
1169-
def print_colour(t, colour):
1170-
print(f"\033[{colour}m{t}\033[00m")
1163+
def get_transform_backends():
1164+
"""Get the backends of all MONAI transforms.
11711165
1172-
tr_total = 0
1173-
tr_t_or_np = 0
1174-
tr_t = 0
1175-
tr_np = 0
1176-
tr_uncategorised = 0
1166+
Returns:
1167+
Dictionary, where each key is a transform, and its
1168+
corresponding values are a boolean list, stating
1169+
whether that transform supports (1) `torch.Tensor`,
1170+
and (2) `np.ndarray` as input without needing to
1171+
convert.
1172+
"""
1173+
backends = {}
11771174
unique_transforms = []
11781175
for n, obj in getmembers(monai.transforms):
11791176
# skip aliases
@@ -1194,21 +1191,54 @@ def print_colour(t, colour):
11941191
"InverteD",
11951192
]:
11961193
continue
1197-
tr_total += 1
1198-
if obj.backend == ["torch", "numpy"]:
1199-
tr_t_or_np += 1
1200-
print_colour(f"TorchOrNumpy: {n}", Colours.green)
1201-
elif obj.backend == ["torch"]:
1202-
tr_t += 1
1203-
print_colour(f"Torch: {n}", Colours.green)
1204-
elif obj.backend == ["numpy"]:
1205-
tr_np += 1
1206-
print_colour(f"Numpy: {n}", Colours.yellow)
1207-
else:
1208-
tr_uncategorised += 1
1209-
print_colour(f"Uncategorised: {n}", Colours.red)
1210-
print("Total number of transforms:", tr_total)
1211-
print_colour(f"Number transforms allowing both torch and numpy: {tr_t_or_np}", Colours.green)
1212-
print_colour(f"Number of TorchTransform: {tr_t}", Colours.green)
1213-
print_colour(f"Number of NumpyTransform: {tr_np}", Colours.yellow)
1214-
print_colour(f"Number of uncategorised: {tr_uncategorised}", Colours.red)
1194+
1195+
backends[n] = [
1196+
TransformBackends.TORCH in obj.backend,
1197+
TransformBackends.NUMPY in obj.backend,
1198+
]
1199+
return backends
1200+
1201+
1202+
def print_transform_backends():
1203+
"""Prints a list of backends of all MONAI transforms."""
1204+
1205+
class Colors:
1206+
none = ""
1207+
red = "91"
1208+
green = "92"
1209+
yellow = "93"
1210+
1211+
def print_color(t, color):
1212+
print(f"\033[{color}m{t}\033[00m")
1213+
1214+
def print_table_column(name, torch, numpy, color=Colors.none):
1215+
print_color("{:<50} {:<8} {:<8}".format(name, torch, numpy), color)
1216+
1217+
backends = get_transform_backends()
1218+
n_total = len(backends)
1219+
n_t_or_np, n_t, n_np, n_uncategorized = 0, 0, 0, 0
1220+
print_table_column("Transform", "Torch?", "Numpy?")
1221+
for k, v in backends.items():
1222+
if all(v):
1223+
color = Colors.green
1224+
n_t_or_np += 1
1225+
elif v[0]:
1226+
color = Colors.green
1227+
n_t += 1
1228+
elif v[1]:
1229+
color = Colors.yellow
1230+
n_np += 1
1231+
else:
1232+
color = Colors.red
1233+
n_uncategorized += 1
1234+
print_table_column(k, *v, color)
1235+
1236+
print("Total number of transforms:", n_total)
1237+
print_color(f"Number transforms allowing both torch and numpy: {n_t_or_np}", Colors.green)
1238+
print_color(f"Number of TorchTransform: {n_t}", Colors.green)
1239+
print_color(f"Number of NumpyTransform: {n_np}", Colors.yellow)
1240+
print_color(f"Number of uncategorised: {n_uncategorized}", Colors.red)
1241+
1242+
1243+
if __name__ == "__main__":
1244+
print_transform_backends()

tests/test_print_transform_backends.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,17 @@
1111

1212
import unittest
1313

14-
from monai.transforms.utils import print_transform_backends
14+
from monai.transforms.utils import get_transform_backends, print_transform_backends
1515

1616

1717
class TestPrintTransformBackends(unittest.TestCase):
1818
def test_get_number_of_conversions(self):
19+
tr_t_or_np, *_ = get_transform_backends()
20+
self.assertGreater(len(tr_t_or_np), 0)
1921
print_transform_backends()
2022

2123

2224
if __name__ == "__main__":
23-
unittest.main()
25+
# unittest.main()
26+
a = TestPrintTransformBackends()
27+
a.test_get_number_of_conversions()

0 commit comments

Comments
 (0)