23
23
from monai .config import DtypeLike
24
24
from monai .data .utils import get_random_patch , get_valid_patch_size
25
25
from monai .networks .layers import GaussianFilter , HilbertTransform , SavitzkyGolayFilter
26
- from monai .transforms .transform import RandomizableTransform , Transform
26
+ from monai .transforms .transform import Fourier , RandomizableTransform , Transform
27
27
from monai .transforms .utils import rescale_array
28
28
from monai .utils import (
29
29
PT_BEFORE_1_7 ,
@@ -1196,23 +1196,25 @@ def _randomize(self, _: Any) -> None:
1196
1196
self .sampled_alpha = self .R .uniform (self .alpha [0 ], self .alpha [1 ])
1197
1197
1198
1198
1199
- class GibbsNoise (Transform ):
1199
+ class GibbsNoise (Transform , Fourier ):
1200
1200
"""
1201
1201
The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts
1202
1202
are one of the common type of type artifacts appearing in MRI scans.
1203
1203
1204
1204
The transform is applied to all the channels in the data.
1205
1205
1206
1206
For general information on Gibbs artifacts, please refer to:
1207
- https://pubs.rsna.org/doi/full/10.1148/rg.313105115
1208
- https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949
1209
1207
1208
+ `An Image-based Approach to Understanding the Physics of MR Artifacts
1209
+ <https://pubs.rsna.org/doi/full/10.1148/rg.313105115>`_.
1210
+
1211
+ `The AAPM/RSNA Physics Tutorial for Residents
1212
+ <https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949>`_
1210
1213
1211
1214
Args:
1212
- alpha (float) : Parametrizes the intensity of the Gibbs noise filter applied. Takes
1215
+ alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes
1213
1216
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
1214
- as_tensor_output: if true return torch.Tensor, else return np.array. default: True.
1215
-
1217
+ as_tensor_output: if true return torch.Tensor, else return np.array. Default: True.
1216
1218
"""
1217
1219
1218
1220
def __init__ (self , alpha : float = 0.5 , as_tensor_output : bool = True ) -> None :
@@ -1221,47 +1223,22 @@ def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None:
1221
1223
raise AssertionError ("alpha must take values in the interval [0,1]." )
1222
1224
self .alpha = alpha
1223
1225
self .as_tensor_output = as_tensor_output
1224
- self ._device = torch .device ("cpu" )
1225
1226
1226
1227
def __call__ (self , img : Union [np .ndarray , torch .Tensor ]) -> Union [torch .Tensor , np .ndarray ]:
1227
1228
n_dims = len (img .shape [1 :])
1228
1229
1229
- # convert to ndarray to work with np.fft
1230
- _device = None
1231
- if isinstance (img , torch .Tensor ):
1232
- _device = img .device
1233
- img = img .cpu ().detach ().numpy ()
1234
-
1230
+ if isinstance (img , np .ndarray ):
1231
+ img = torch .Tensor (img )
1235
1232
# FT
1236
- k = self ._shift_fourier (img , n_dims )
1233
+ k = self .shift_fourier (img , n_dims )
1237
1234
# build and apply mask
1238
1235
k = self ._apply_mask (k )
1239
1236
# map back
1240
- img = self ._inv_shift_fourier (k , n_dims )
1241
- return torch .Tensor (img ).to (_device or self ._device ) if self .as_tensor_output else img
1242
-
1243
- def _shift_fourier (self , x : Union [np .ndarray , torch .Tensor ], n_dims : int ) -> np .ndarray :
1244
- """
1245
- Applies fourier transform and shifts its output.
1246
- Only the spatial dimensions get transformed.
1237
+ img = self .inv_shift_fourier (k , n_dims )
1247
1238
1248
- Args:
1249
- x (np.ndarray): tensor to fourier transform.
1250
- """
1251
- out : np .ndarray = np .fft .fftshift (np .fft .fftn (x , axes = tuple (range (- n_dims , 0 ))), axes = tuple (range (- n_dims , 0 )))
1252
- return out
1239
+ return img if self .as_tensor_output else img .cpu ().detach ().numpy ()
1253
1240
1254
- def _inv_shift_fourier (self , k : Union [np .ndarray , torch .Tensor ], n_dims : int ) -> np .ndarray :
1255
- """
1256
- Applies inverse shift and fourier transform. Only the spatial
1257
- dimensions are transformed.
1258
- """
1259
- out : np .ndarray = np .fft .ifftn (
1260
- np .fft .ifftshift (k , axes = tuple (range (- n_dims , 0 ))), axes = tuple (range (- n_dims , 0 ))
1261
- ).real
1262
- return out
1263
-
1264
- def _apply_mask (self , k : np .ndarray ) -> np .ndarray :
1241
+ def _apply_mask (self , k : torch .Tensor ) -> torch .Tensor :
1265
1242
"""Builds and applies a mask on the spatial dimensions.
1266
1243
1267
1244
Args:
@@ -1287,11 +1264,11 @@ def _apply_mask(self, k: np.ndarray) -> np.ndarray:
1287
1264
mask = np .repeat (mask [None ], k .shape [0 ], axis = 0 )
1288
1265
1289
1266
# apply binary mask
1290
- k_masked : np . ndarray = k * mask
1267
+ k_masked = k * torch . tensor ( mask , device = k . device )
1291
1268
return k_masked
1292
1269
1293
1270
1294
- class KSpaceSpikeNoise (Transform ):
1271
+ class KSpaceSpikeNoise (Transform , Fourier ):
1295
1272
"""
1296
1273
Apply localized spikes in `k`-space at the given locations and intensities.
1297
1274
Spike (Herringbone) artifact is a type of data acquisition artifact which
@@ -1354,7 +1331,7 @@ def __init__(
1354
1331
def __call__ (self , img : Union [np .ndarray , torch .Tensor ]) -> Union [torch .Tensor , np .ndarray ]:
1355
1332
"""
1356
1333
Args:
1357
- img (np.array or torch.tensor) : image with dimensions (C, H, W) or (C, H, W, D)
1334
+ img: image with dimensions (C, H, W) or (C, H, W, D)
1358
1335
"""
1359
1336
# checking that tuples in loc are consistent with img size
1360
1337
self ._check_indices (img )
@@ -1368,22 +1345,17 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
1368
1345
1369
1346
n_dims = len (img .shape [1 :])
1370
1347
1371
- # convert to ndarray to work with np.fft
1372
- if isinstance (img , torch .Tensor ):
1373
- device = img .device
1374
- img = img .cpu ().detach ().numpy ()
1375
- else :
1376
- device = torch .device ("cpu" )
1377
-
1348
+ if isinstance (img , np .ndarray ):
1349
+ img = torch .Tensor (img )
1378
1350
# FT
1379
- k = self ._shift_fourier (img , n_dims )
1380
- log_abs = np .log (np .absolute (k ) + 1e-10 )
1381
- phase = np .angle (k )
1351
+ k = self .shift_fourier (img , n_dims )
1352
+ log_abs = torch .log (torch .absolute (k ) + 1e-10 )
1353
+ phase = torch .angle (k )
1382
1354
1383
1355
k_intensity = self .k_intensity
1384
1356
# default log intensity
1385
1357
if k_intensity is None :
1386
- k_intensity = tuple (np .mean (log_abs , axis = tuple (range (- n_dims , 0 ))) * 2.5 )
1358
+ k_intensity = tuple (torch .mean (log_abs , dim = tuple (range (- n_dims , 0 ))) * 2.5 )
1387
1359
1388
1360
# highlight
1389
1361
if isinstance (self .loc [0 ], Sequence ):
@@ -1392,9 +1364,10 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
1392
1364
else :
1393
1365
self ._set_spike (log_abs , self .loc , k_intensity )
1394
1366
# map back
1395
- k = np .exp (log_abs ) * np .exp (1j * phase )
1396
- img = self ._inv_shift_fourier (k , n_dims )
1397
- return torch .Tensor (img , device = device ) if self .as_tensor_output else img
1367
+ k = torch .exp (log_abs ) * torch .exp (1j * phase )
1368
+ img = self .inv_shift_fourier (k , n_dims )
1369
+
1370
+ return img if self .as_tensor_output else img .cpu ().detach ().numpy ()
1398
1371
1399
1372
def _check_indices (self , img ) -> None :
1400
1373
"""Helper method to check consistency of self.loc and input image.
@@ -1414,48 +1387,27 @@ def _check_indices(self, img) -> None:
1414
1387
f"The index value at position { i } of one of the tuples in loc = { self .loc } is out of bounds for current image."
1415
1388
)
1416
1389
1417
- def _set_spike (self , k : np . ndarray , idx : Tuple , val : Union [Sequence [float ], float ]):
1390
+ def _set_spike (self , k : torch . Tensor , idx : Tuple , val : Union [Sequence [float ], float ]):
1418
1391
"""
1419
1392
Helper function to introduce a given intensity at given location.
1420
1393
1421
1394
Args:
1422
- k (np.array) : intensity array to alter.
1423
- idx (tuple) : index of location where to apply change.
1424
- val (float) : value of intensity to write in.
1395
+ k: intensity array to alter.
1396
+ idx: index of location where to apply change.
1397
+ val: value of intensity to write in.
1425
1398
"""
1426
1399
if len (k .shape ) == len (idx ):
1427
1400
if isinstance (val , Sequence ):
1428
1401
k [idx ] = val [idx [0 ]]
1429
1402
else :
1430
1403
k [idx ] = val
1431
1404
elif len (k .shape ) == 4 and len (idx ) == 3 :
1432
- k [:, idx [0 ], idx [1 ], idx [2 ]] = val
1405
+ k [:, idx [0 ], idx [1 ], idx [2 ]] = val # type: ignore
1433
1406
elif len (k .shape ) == 3 and len (idx ) == 2 :
1434
- k [:, idx [0 ], idx [1 ]] = val
1435
-
1436
- def _shift_fourier (self , x : Union [np .ndarray , torch .Tensor ], n_dims : int ) -> np .ndarray :
1437
- """
1438
- Applies fourier transform and shifts its output.
1439
- Only the spatial dimensions get transformed.
1440
-
1441
- Args:
1442
- x (np.ndarray): tensor to fourier transform.
1443
- """
1444
- out : np .ndarray = np .fft .fftshift (np .fft .fftn (x , axes = tuple (range (- n_dims , 0 ))), axes = tuple (range (- n_dims , 0 )))
1445
- return out
1407
+ k [:, idx [0 ], idx [1 ]] = val # type: ignore
1446
1408
1447
- def _inv_shift_fourier (self , k : Union [np .ndarray , torch .Tensor ], n_dims : int ) -> np .ndarray :
1448
- """
1449
- Applies inverse shift and fourier transform. Only the spatial
1450
- dimensions are transformed.
1451
- """
1452
- out : np .ndarray = np .fft .ifftn (
1453
- np .fft .ifftshift (k , axes = tuple (range (- n_dims , 0 ))), axes = tuple (range (- n_dims , 0 ))
1454
- ).real
1455
- return out
1456
1409
1457
-
1458
- class RandKSpaceSpikeNoise (RandomizableTransform ):
1410
+ class RandKSpaceSpikeNoise (RandomizableTransform , Fourier ):
1459
1411
"""
1460
1412
Naturalistic data augmentation via spike artifacts. The transform applies
1461
1413
localized spikes in `k`-space, and it is the random version of
@@ -1476,7 +1428,7 @@ class RandKSpaceSpikeNoise(RandomizableTransform):
1476
1428
channels at once, or channel-wise if ``channel_wise = True``.
1477
1429
intensity_range: pass a tuple
1478
1430
(a, b) to sample the log-intensity from the interval (a, b)
1479
- uniformly for all channels. Or pass sequence of intervals
1431
+ uniformly for all channels. Or pass sequence of intevals
1480
1432
((a0, b0), (a1, b1), ...) to sample for each respective channel.
1481
1433
In the second case, the number of 2-tuples must match the number of
1482
1434
channels.
@@ -1521,7 +1473,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
1521
1473
Apply transform to `img`. Assumes data is in channel-first form.
1522
1474
1523
1475
Args:
1524
- img (np.array or torch.tensor) : image with dimensions (C, H, W) or (C, H, W, D)
1476
+ img: image with dimensions (C, H, W) or (C, H, W, D)
1525
1477
"""
1526
1478
if self .intensity_range is not None :
1527
1479
if isinstance (self .intensity_range [0 ], Sequence ) and len (self .intensity_range ) != img .shape [0 ]:
@@ -1532,19 +1484,20 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
1532
1484
self .sampled_k_intensity = []
1533
1485
self .sampled_locs = []
1534
1486
1535
- # convert to ndarray to work with np.fft
1536
- x , device = self ._to_numpy (img )
1537
- intensity_range = self ._make_sequence (x )
1538
- self ._randomize (x , intensity_range )
1487
+ if not isinstance (img , torch .Tensor ):
1488
+ img = torch .Tensor (img )
1489
+
1490
+ intensity_range = self ._make_sequence (img )
1491
+ self ._randomize (img , intensity_range )
1539
1492
1540
- # build/apply transform only if there are spike locations
1493
+ # build/appy transform only if there are spike locations
1541
1494
if self .sampled_locs :
1542
1495
transform = KSpaceSpikeNoise (self .sampled_locs , self .sampled_k_intensity , self .as_tensor_output )
1543
- return transform (x )
1496
+ return transform (img )
1544
1497
1545
- return torch . Tensor ( x , device = device ) if self .as_tensor_output else x
1498
+ return img if self .as_tensor_output else img . detach (). numpy ()
1546
1499
1547
- def _randomize (self , img : np . ndarray , intensity_range : Sequence [Sequence [float ]]) -> None :
1500
+ def _randomize (self , img : torch . Tensor , intensity_range : Sequence [Sequence [float ]]) -> None :
1548
1501
"""
1549
1502
Helper method to sample both the location and intensity of the spikes.
1550
1503
When not working channel wise (channel_wise=False) it use the random
@@ -1568,11 +1521,11 @@ def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]
1568
1521
spatial = tuple (self .R .randint (0 , k ) for k in img .shape [1 :])
1569
1522
self .sampled_locs = [(i ,) + spatial for i in range (img .shape [0 ])]
1570
1523
if isinstance (intensity_range [0 ], Sequence ):
1571
- self .sampled_k_intensity = [self .R .uniform (* p ) for p in intensity_range ] # type: ignore
1524
+ self .sampled_k_intensity = [self .R .uniform (p [ 0 ], p [ 1 ] ) for p in intensity_range ]
1572
1525
else :
1573
- self .sampled_k_intensity = [self .R .uniform (* self . intensity_range )] * len (img ) # type: ignore
1526
+ self .sampled_k_intensity = [self .R .uniform (intensity_range [ 0 ], intensity_range [ 1 ] )] * len (img ) # type: ignore
1574
1527
1575
- def _make_sequence (self , x : np . ndarray ) -> Sequence [Sequence [float ]]:
1528
+ def _make_sequence (self , x : torch . Tensor ) -> Sequence [Sequence [float ]]:
1576
1529
"""
1577
1530
Formats the sequence of intensities ranges to Sequence[Sequence[float]].
1578
1531
"""
@@ -1586,27 +1539,21 @@ def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]:
1586
1539
# set default range if one not provided
1587
1540
return self ._set_default_range (x )
1588
1541
1589
- def _set_default_range (self , x : np . ndarray ) -> Sequence [Sequence [float ]]:
1542
+ def _set_default_range (self , img : torch . Tensor ) -> Sequence [Sequence [float ]]:
1590
1543
"""
1591
1544
Sets default intensity ranges to be sampled.
1592
1545
1593
1546
Args:
1594
- x (np.ndarray): tensor to fourier transform.
1547
+ img: image to transform.
1595
1548
"""
1596
- n_dims = len (x .shape [1 :])
1549
+ n_dims = len (img .shape [1 :])
1597
1550
1598
- k = np . fft . fftshift ( np . fft . fftn ( x , axes = tuple ( range ( - n_dims , 0 ))), axes = tuple ( range ( - n_dims , 0 )) )
1599
- log_abs = np .log (np .absolute (k ) + 1e-10 )
1600
- shifted_means = np .mean (log_abs , axis = tuple (range (- n_dims , 0 ))) * 2.5
1551
+ k = self . shift_fourier ( img , n_dims )
1552
+ log_abs = torch .log (torch .absolute (k ) + 1e-10 )
1553
+ shifted_means = torch .mean (log_abs , dim = tuple (range (- n_dims , 0 ))) * 2.5
1601
1554
intensity_sequence = tuple ((i * 0.95 , i * 1.1 ) for i in shifted_means )
1602
1555
return intensity_sequence
1603
1556
1604
- def _to_numpy (self , img : Union [np .ndarray , torch .Tensor ]) -> Tuple [np .ndarray , torch .device ]:
1605
- if isinstance (img , torch .Tensor ):
1606
- return img .cpu ().detach ().numpy (), img .device
1607
- else :
1608
- return img , torch .device ("cpu" )
1609
-
1610
1557
1611
1558
class RandCoarseDropout (RandomizableTransform ):
1612
1559
"""
0 commit comments