Skip to content

Commit 4f7371a

Browse files
authored
Updated RandKSpaceSpikeNoised. Collected Fourier mappings. (#2665)
* Moved fourier functions to their own class. Modified RandKSpaceSpikeNoised. 1. Allow RandKSpaceSpikeNoised to work with arbitrary keys. 2. Introduced Fourier transform to keep the forward/backward fourier mappings. Signed-off-by: Yaniel Cabrera <[email protected]> * removed old code Signed-off-by: Yaniel Cabrera <[email protected]> * Ignore torch.fft tests if not present Ignore tests with versions of Pytorch which lack the module fft. Signed-off-by: Yaniel Cabrera <[email protected]> * update Signed-off-by: Yaniel Cabrera <[email protected]> * typing update Signed-off-by: Yaniel Cabrera <[email protected]> * added unit test for Fourier Signed-off-by: Yaniel Cabrera <[email protected]> * added unit test for Fourier Signed-off-by: Yaniel Cabrera <[email protected]> * fixing black Signed-off-by: Yaniel Cabrera <[email protected]>
1 parent 7f42efe commit 4f7371a

14 files changed

+249
-151
lines changed

docs/source/transforms.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ Generic Interfaces
5353
.. autoclass:: Decollated
5454
:members:
5555

56+
`Fourier`
57+
^^^^^^^^^^^^^
58+
.. autoclass:: Fourier
59+
:members:
5660

5761
Vanilla Transforms
5862
------------------

monai/transforms/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,15 @@
306306
ZoomD,
307307
ZoomDict,
308308
)
309-
from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform
309+
from .transform import (
310+
Fourier,
311+
MapTransform,
312+
Randomizable,
313+
RandomizableTransform,
314+
ThreadUnsafe,
315+
Transform,
316+
apply_transform,
317+
)
310318
from .utility.array import (
311319
AddChannel,
312320
AddExtremePointsChannel,

monai/transforms/intensity/array.py

Lines changed: 55 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from monai.config import DtypeLike
2424
from monai.data.utils import get_random_patch, get_valid_patch_size
2525
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
26-
from monai.transforms.transform import RandomizableTransform, Transform
26+
from monai.transforms.transform import Fourier, RandomizableTransform, Transform
2727
from monai.transforms.utils import rescale_array
2828
from monai.utils import (
2929
PT_BEFORE_1_7,
@@ -1196,23 +1196,25 @@ def _randomize(self, _: Any) -> None:
11961196
self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1])
11971197

11981198

1199-
class GibbsNoise(Transform):
1199+
class GibbsNoise(Transform, Fourier):
12001200
"""
12011201
The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts
12021202
are one of the common type of type artifacts appearing in MRI scans.
12031203
12041204
The transform is applied to all the channels in the data.
12051205
12061206
For general information on Gibbs artifacts, please refer to:
1207-
https://pubs.rsna.org/doi/full/10.1148/rg.313105115
1208-
https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949
12091207
1208+
`An Image-based Approach to Understanding the Physics of MR Artifacts
1209+
<https://pubs.rsna.org/doi/full/10.1148/rg.313105115>`_.
1210+
1211+
`The AAPM/RSNA Physics Tutorial for Residents
1212+
<https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949>`_
12101213
12111214
Args:
1212-
alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes
1215+
alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes
12131216
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
1214-
as_tensor_output: if true return torch.Tensor, else return np.array. default: True.
1215-
1217+
as_tensor_output: if true return torch.Tensor, else return np.array. Default: True.
12161218
"""
12171219

12181220
def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None:
@@ -1221,47 +1223,22 @@ def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None:
12211223
raise AssertionError("alpha must take values in the interval [0,1].")
12221224
self.alpha = alpha
12231225
self.as_tensor_output = as_tensor_output
1224-
self._device = torch.device("cpu")
12251226

12261227
def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]:
12271228
n_dims = len(img.shape[1:])
12281229

1229-
# convert to ndarray to work with np.fft
1230-
_device = None
1231-
if isinstance(img, torch.Tensor):
1232-
_device = img.device
1233-
img = img.cpu().detach().numpy()
1234-
1230+
if isinstance(img, np.ndarray):
1231+
img = torch.Tensor(img)
12351232
# FT
1236-
k = self._shift_fourier(img, n_dims)
1233+
k = self.shift_fourier(img, n_dims)
12371234
# build and apply mask
12381235
k = self._apply_mask(k)
12391236
# map back
1240-
img = self._inv_shift_fourier(k, n_dims)
1241-
return torch.Tensor(img).to(_device or self._device) if self.as_tensor_output else img
1242-
1243-
def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray:
1244-
"""
1245-
Applies fourier transform and shifts its output.
1246-
Only the spatial dimensions get transformed.
1237+
img = self.inv_shift_fourier(k, n_dims)
12471238

1248-
Args:
1249-
x (np.ndarray): tensor to fourier transform.
1250-
"""
1251-
out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
1252-
return out
1239+
return img if self.as_tensor_output else img.cpu().detach().numpy()
12531240

1254-
def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray:
1255-
"""
1256-
Applies inverse shift and fourier transform. Only the spatial
1257-
dimensions are transformed.
1258-
"""
1259-
out: np.ndarray = np.fft.ifftn(
1260-
np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))
1261-
).real
1262-
return out
1263-
1264-
def _apply_mask(self, k: np.ndarray) -> np.ndarray:
1241+
def _apply_mask(self, k: torch.Tensor) -> torch.Tensor:
12651242
"""Builds and applies a mask on the spatial dimensions.
12661243
12671244
Args:
@@ -1287,11 +1264,11 @@ def _apply_mask(self, k: np.ndarray) -> np.ndarray:
12871264
mask = np.repeat(mask[None], k.shape[0], axis=0)
12881265

12891266
# apply binary mask
1290-
k_masked: np.ndarray = k * mask
1267+
k_masked = k * torch.tensor(mask, device=k.device)
12911268
return k_masked
12921269

12931270

1294-
class KSpaceSpikeNoise(Transform):
1271+
class KSpaceSpikeNoise(Transform, Fourier):
12951272
"""
12961273
Apply localized spikes in `k`-space at the given locations and intensities.
12971274
Spike (Herringbone) artifact is a type of data acquisition artifact which
@@ -1354,7 +1331,7 @@ def __init__(
13541331
def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]:
13551332
"""
13561333
Args:
1357-
img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D)
1334+
img: image with dimensions (C, H, W) or (C, H, W, D)
13581335
"""
13591336
# checking that tuples in loc are consistent with img size
13601337
self._check_indices(img)
@@ -1368,22 +1345,17 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
13681345

13691346
n_dims = len(img.shape[1:])
13701347

1371-
# convert to ndarray to work with np.fft
1372-
if isinstance(img, torch.Tensor):
1373-
device = img.device
1374-
img = img.cpu().detach().numpy()
1375-
else:
1376-
device = torch.device("cpu")
1377-
1348+
if isinstance(img, np.ndarray):
1349+
img = torch.Tensor(img)
13781350
# FT
1379-
k = self._shift_fourier(img, n_dims)
1380-
log_abs = np.log(np.absolute(k) + 1e-10)
1381-
phase = np.angle(k)
1351+
k = self.shift_fourier(img, n_dims)
1352+
log_abs = torch.log(torch.absolute(k) + 1e-10)
1353+
phase = torch.angle(k)
13821354

13831355
k_intensity = self.k_intensity
13841356
# default log intensity
13851357
if k_intensity is None:
1386-
k_intensity = tuple(np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5)
1358+
k_intensity = tuple(torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5)
13871359

13881360
# highlight
13891361
if isinstance(self.loc[0], Sequence):
@@ -1392,9 +1364,10 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
13921364
else:
13931365
self._set_spike(log_abs, self.loc, k_intensity)
13941366
# map back
1395-
k = np.exp(log_abs) * np.exp(1j * phase)
1396-
img = self._inv_shift_fourier(k, n_dims)
1397-
return torch.Tensor(img, device=device) if self.as_tensor_output else img
1367+
k = torch.exp(log_abs) * torch.exp(1j * phase)
1368+
img = self.inv_shift_fourier(k, n_dims)
1369+
1370+
return img if self.as_tensor_output else img.cpu().detach().numpy()
13981371

13991372
def _check_indices(self, img) -> None:
14001373
"""Helper method to check consistency of self.loc and input image.
@@ -1414,48 +1387,27 @@ def _check_indices(self, img) -> None:
14141387
f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image."
14151388
)
14161389

1417-
def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], float]):
1390+
def _set_spike(self, k: torch.Tensor, idx: Tuple, val: Union[Sequence[float], float]):
14181391
"""
14191392
Helper function to introduce a given intensity at given location.
14201393
14211394
Args:
1422-
k (np.array): intensity array to alter.
1423-
idx (tuple): index of location where to apply change.
1424-
val (float): value of intensity to write in.
1395+
k: intensity array to alter.
1396+
idx: index of location where to apply change.
1397+
val: value of intensity to write in.
14251398
"""
14261399
if len(k.shape) == len(idx):
14271400
if isinstance(val, Sequence):
14281401
k[idx] = val[idx[0]]
14291402
else:
14301403
k[idx] = val
14311404
elif len(k.shape) == 4 and len(idx) == 3:
1432-
k[:, idx[0], idx[1], idx[2]] = val
1405+
k[:, idx[0], idx[1], idx[2]] = val # type: ignore
14331406
elif len(k.shape) == 3 and len(idx) == 2:
1434-
k[:, idx[0], idx[1]] = val
1435-
1436-
def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray:
1437-
"""
1438-
Applies fourier transform and shifts its output.
1439-
Only the spatial dimensions get transformed.
1440-
1441-
Args:
1442-
x (np.ndarray): tensor to fourier transform.
1443-
"""
1444-
out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
1445-
return out
1407+
k[:, idx[0], idx[1]] = val # type: ignore
14461408

1447-
def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray:
1448-
"""
1449-
Applies inverse shift and fourier transform. Only the spatial
1450-
dimensions are transformed.
1451-
"""
1452-
out: np.ndarray = np.fft.ifftn(
1453-
np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))
1454-
).real
1455-
return out
14561409

1457-
1458-
class RandKSpaceSpikeNoise(RandomizableTransform):
1410+
class RandKSpaceSpikeNoise(RandomizableTransform, Fourier):
14591411
"""
14601412
Naturalistic data augmentation via spike artifacts. The transform applies
14611413
localized spikes in `k`-space, and it is the random version of
@@ -1476,7 +1428,7 @@ class RandKSpaceSpikeNoise(RandomizableTransform):
14761428
channels at once, or channel-wise if ``channel_wise = True``.
14771429
intensity_range: pass a tuple
14781430
(a, b) to sample the log-intensity from the interval (a, b)
1479-
uniformly for all channels. Or pass sequence of intervals
1431+
uniformly for all channels. Or pass sequence of intevals
14801432
((a0, b0), (a1, b1), ...) to sample for each respective channel.
14811433
In the second case, the number of 2-tuples must match the number of
14821434
channels.
@@ -1521,7 +1473,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
15211473
Apply transform to `img`. Assumes data is in channel-first form.
15221474
15231475
Args:
1524-
img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D)
1476+
img: image with dimensions (C, H, W) or (C, H, W, D)
15251477
"""
15261478
if self.intensity_range is not None:
15271479
if isinstance(self.intensity_range[0], Sequence) and len(self.intensity_range) != img.shape[0]:
@@ -1532,19 +1484,20 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
15321484
self.sampled_k_intensity = []
15331485
self.sampled_locs = []
15341486

1535-
# convert to ndarray to work with np.fft
1536-
x, device = self._to_numpy(img)
1537-
intensity_range = self._make_sequence(x)
1538-
self._randomize(x, intensity_range)
1487+
if not isinstance(img, torch.Tensor):
1488+
img = torch.Tensor(img)
1489+
1490+
intensity_range = self._make_sequence(img)
1491+
self._randomize(img, intensity_range)
15391492

1540-
# build/apply transform only if there are spike locations
1493+
# build/appy transform only if there are spike locations
15411494
if self.sampled_locs:
15421495
transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output)
1543-
return transform(x)
1496+
return transform(img)
15441497

1545-
return torch.Tensor(x, device=device) if self.as_tensor_output else x
1498+
return img if self.as_tensor_output else img.detach().numpy()
15461499

1547-
def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]) -> None:
1500+
def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float]]) -> None:
15481501
"""
15491502
Helper method to sample both the location and intensity of the spikes.
15501503
When not working channel wise (channel_wise=False) it use the random
@@ -1568,11 +1521,11 @@ def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]
15681521
spatial = tuple(self.R.randint(0, k) for k in img.shape[1:])
15691522
self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])]
15701523
if isinstance(intensity_range[0], Sequence):
1571-
self.sampled_k_intensity = [self.R.uniform(*p) for p in intensity_range] # type: ignore
1524+
self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range]
15721525
else:
1573-
self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img) # type: ignore
1526+
self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) # type: ignore
15741527

1575-
def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]:
1528+
def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]:
15761529
"""
15771530
Formats the sequence of intensities ranges to Sequence[Sequence[float]].
15781531
"""
@@ -1586,27 +1539,21 @@ def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]:
15861539
# set default range if one not provided
15871540
return self._set_default_range(x)
15881541

1589-
def _set_default_range(self, x: np.ndarray) -> Sequence[Sequence[float]]:
1542+
def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]:
15901543
"""
15911544
Sets default intensity ranges to be sampled.
15921545
15931546
Args:
1594-
x (np.ndarray): tensor to fourier transform.
1547+
img: image to transform.
15951548
"""
1596-
n_dims = len(x.shape[1:])
1549+
n_dims = len(img.shape[1:])
15971550

1598-
k = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
1599-
log_abs = np.log(np.absolute(k) + 1e-10)
1600-
shifted_means = np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5
1551+
k = self.shift_fourier(img, n_dims)
1552+
log_abs = torch.log(torch.absolute(k) + 1e-10)
1553+
shifted_means = torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5
16011554
intensity_sequence = tuple((i * 0.95, i * 1.1) for i in shifted_means)
16021555
return intensity_sequence
16031556

1604-
def _to_numpy(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.device]:
1605-
if isinstance(img, torch.Tensor):
1606-
return img.cpu().detach().numpy(), img.device
1607-
else:
1608-
return img, torch.device("cpu")
1609-
16101557

16111558
class RandCoarseDropout(RandomizableTransform):
16121559
"""

0 commit comments

Comments
 (0)