Updated RandKSpaceSpikeNoised. Collected Fourier mappings. #2665

Merged: 10 commits, Aug 2, 2021
4 changes: 4 additions & 0 deletions docs/source/transforms.rst
@@ -53,6 +53,10 @@ Generic Interfaces
.. autoclass:: Decollated
:members:

`Fourier`
^^^^^^^^^^^^^
.. autoclass:: Fourier
:members:

Vanilla Transforms
------------------
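The `Fourier` entry added above documents a small base class that collects the FFT-and-shift helpers previously duplicated as `_shift_fourier`/`_inv_shift_fourier` in the intensity transforms (see the diff of monai/transforms/intensity/array.py below). The class itself lives in monai/transforms/transform.py and is not part of this excerpt; a minimal sketch of what it provides, assuming a torch.fft-based implementation on PyTorch 1.8 or later, could look like this:

import torch


class Fourier:
    """Sketch of the helper mixin (assumed implementation, not the diffed source)."""

    @staticmethod
    def shift_fourier(x: torch.Tensor, n_dims: int) -> torch.Tensor:
        # FFT over the trailing (spatial) dims, then centre the zero-frequency component.
        dims = tuple(range(-n_dims, 0))
        return torch.fft.fftshift(torch.fft.fftn(x, dim=dims), dim=dims)

    @staticmethod
    def inv_shift_fourier(k: torch.Tensor, n_dims: int) -> torch.Tensor:
        # Undo the shift, apply the inverse FFT and keep only the real part.
        dims = tuple(range(-n_dims, 0))
        return torch.fft.ifftn(torch.fft.ifftshift(k, dim=dims), dim=dims).real

The intensity transforms below call these helpers as self.shift_fourier(...) and self.inv_shift_fourier(...), which is why GibbsNoise, KSpaceSpikeNoise and RandKSpaceSpikeNoise now also inherit from Fourier.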
10 changes: 9 additions & 1 deletion monai/transforms/__init__.py
@@ -306,7 +306,15 @@
ZoomD,
ZoomDict,
)
from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform
from .transform import (
Fourier,
MapTransform,
Randomizable,
RandomizableTransform,
ThreadUnsafe,
Transform,
apply_transform,
)
from .utility.array import (
AddChannel,
AddExtremePointsChannel,
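The import list above is truncated in this view; the relevant change is the re-export of Fourier from monai.transforms. With it, user code can mix the shared k-space helpers into a custom transform instead of reimplementing them. A hypothetical example (the KSpaceMagnitude class below is illustrative only, not part of MONAI):

import torch
from monai.transforms import Fourier, Transform


class KSpaceMagnitude(Transform, Fourier):
    """Hypothetical transform: returns the centred k-space magnitude of a channel-first image."""

    def __call__(self, img: torch.Tensor) -> torch.Tensor:
        n_dims = len(img.shape[1:])          # number of spatial dimensions
        k = self.shift_fourier(img, n_dims)  # helper inherited from the Fourier mixin
        return k.abs()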
163 changes: 55 additions & 108 deletions monai/transforms/intensity/array.py
@@ -23,7 +23,7 @@
from monai.config import DtypeLike
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
from monai.transforms.transform import Fourier, RandomizableTransform, Transform
from monai.transforms.utils import rescale_array
from monai.utils import (
PT_BEFORE_1_7,
@@ -1196,23 +1196,25 @@ def _randomize(self, _: Any) -> None:
self.sampled_alpha = self.R.uniform(self.alpha[0], self.alpha[1])


class GibbsNoise(Transform):
class GibbsNoise(Transform, Fourier):
"""
The transform applies Gibbs noise to 2D/3D MRI images. Gibbs artifacts
are one of the common types of artifacts appearing in MRI scans.

The transform is applied to all the channels in the data.

For general information on Gibbs artifacts, please refer to:
https://pubs.rsna.org/doi/full/10.1148/rg.313105115
https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949

`An Image-based Approach to Understanding the Physics of MR Artifacts
<https://pubs.rsna.org/doi/full/10.1148/rg.313105115>`_.

`The AAPM/RSNA Physics Tutorial for Residents
<https://pubs.rsna.org/doi/full/10.1148/radiographics.22.4.g02jl14949>`_

Args:
alpha (float): Parametrizes the intensity of the Gibbs noise filter applied. Takes
alpha: Parametrizes the intensity of the Gibbs noise filter applied. Takes
values in the interval [0,1] with alpha = 0 acting as the identity mapping.
as_tensor_output: if true return torch.Tensor, else return np.array. default: True.

as_tensor_output: if true return torch.Tensor, else return np.array. Default: True.
"""

def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None:
@@ -1221,47 +1223,22 @@ def __init__(self, alpha: float = 0.5, as_tensor_output: bool = True) -> None:
raise AssertionError("alpha must take values in the interval [0,1].")
self.alpha = alpha
self.as_tensor_output = as_tensor_output
self._device = torch.device("cpu")

def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]:
n_dims = len(img.shape[1:])

# convert to ndarray to work with np.fft
_device = None
if isinstance(img, torch.Tensor):
_device = img.device
img = img.cpu().detach().numpy()

if isinstance(img, np.ndarray):
img = torch.Tensor(img)
# FT
k = self._shift_fourier(img, n_dims)
k = self.shift_fourier(img, n_dims)
# build and apply mask
k = self._apply_mask(k)
# map back
img = self._inv_shift_fourier(k, n_dims)
return torch.Tensor(img).to(_device or self._device) if self.as_tensor_output else img

def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray:
"""
Applies fourier transform and shifts its output.
Only the spatial dimensions get transformed.
img = self.inv_shift_fourier(k, n_dims)

Args:
x (np.ndarray): tensor to fourier transform.
"""
out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
return out
return img if self.as_tensor_output else img.cpu().detach().numpy()

def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray:
"""
Applies inverse shift and fourier transform. Only the spatial
dimensions are transformed.
"""
out: np.ndarray = np.fft.ifftn(
np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))
).real
return out

def _apply_mask(self, k: np.ndarray) -> np.ndarray:
def _apply_mask(self, k: torch.Tensor) -> torch.Tensor:
"""Builds and applies a mask on the spatial dimensions.

Args:
@@ -1287,11 +1264,11 @@ def _apply_mask(self, k: np.ndarray) -> np.ndarray:
mask = np.repeat(mask[None], k.shape[0], axis=0)

# apply binary mask
k_masked: np.ndarray = k * mask
k_masked = k * torch.tensor(mask, device=k.device)
return k_masked


class KSpaceSpikeNoise(Transform):
class KSpaceSpikeNoise(Transform, Fourier):
"""
Apply localized spikes in `k`-space at the given locations and intensities.
Spike (Herringbone) artifact is a type of data acquisition artifact which
@@ -1354,7 +1331,7 @@ def __init__(
def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor, np.ndarray]:
"""
Args:
img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D)
img: image with dimensions (C, H, W) or (C, H, W, D)
"""
# checking that tuples in loc are consistent with img size
self._check_indices(img)
@@ -1368,22 +1345,17 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,

n_dims = len(img.shape[1:])

# convert to ndarray to work with np.fft
if isinstance(img, torch.Tensor):
device = img.device
img = img.cpu().detach().numpy()
else:
device = torch.device("cpu")

if isinstance(img, np.ndarray):
img = torch.Tensor(img)
# FT
k = self._shift_fourier(img, n_dims)
log_abs = np.log(np.absolute(k) + 1e-10)
phase = np.angle(k)
k = self.shift_fourier(img, n_dims)
log_abs = torch.log(torch.absolute(k) + 1e-10)
phase = torch.angle(k)

k_intensity = self.k_intensity
# default log intensity
if k_intensity is None:
k_intensity = tuple(np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5)
k_intensity = tuple(torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5)

# highlight
if isinstance(self.loc[0], Sequence):
@@ -1392,9 +1364,10 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
else:
self._set_spike(log_abs, self.loc, k_intensity)
# map back
k = np.exp(log_abs) * np.exp(1j * phase)
img = self._inv_shift_fourier(k, n_dims)
return torch.Tensor(img, device=device) if self.as_tensor_output else img
k = torch.exp(log_abs) * torch.exp(1j * phase)
img = self.inv_shift_fourier(k, n_dims)

return img if self.as_tensor_output else img.cpu().detach().numpy()

def _check_indices(self, img) -> None:
"""Helper method to check consistency of self.loc and input image.
@@ -1414,48 +1387,27 @@ def _check_indices(self, img) -> None:
f"The index value at position {i} of one of the tuples in loc = {self.loc} is out of bounds for current image."
)

def _set_spike(self, k: np.ndarray, idx: Tuple, val: Union[Sequence[float], float]):
def _set_spike(self, k: torch.Tensor, idx: Tuple, val: Union[Sequence[float], float]):
"""
Helper function to introduce a given intensity at a given location.

Args:
k (np.array): intensity array to alter.
idx (tuple): index of location where to apply change.
val (float): value of intensity to write in.
k: intensity array to alter.
idx: index of location where to apply change.
val: value of intensity to write in.
"""
if len(k.shape) == len(idx):
if isinstance(val, Sequence):
k[idx] = val[idx[0]]
else:
k[idx] = val
elif len(k.shape) == 4 and len(idx) == 3:
k[:, idx[0], idx[1], idx[2]] = val
k[:, idx[0], idx[1], idx[2]] = val # type: ignore
elif len(k.shape) == 3 and len(idx) == 2:
k[:, idx[0], idx[1]] = val

def _shift_fourier(self, x: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray:
"""
Applies fourier transform and shifts its output.
Only the spatial dimensions get transformed.

Args:
x (np.ndarray): tensor to fourier transform.
"""
out: np.ndarray = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
return out
k[:, idx[0], idx[1]] = val # type: ignore

def _inv_shift_fourier(self, k: Union[np.ndarray, torch.Tensor], n_dims: int) -> np.ndarray:
"""
Applies inverse shift and fourier transform. Only the spatial
dimensions are transformed.
"""
out: np.ndarray = np.fft.ifftn(
np.fft.ifftshift(k, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0))
).real
return out


class RandKSpaceSpikeNoise(RandomizableTransform):
class RandKSpaceSpikeNoise(RandomizableTransform, Fourier):
"""
Naturalistic data augmentation via spike artifacts. The transform applies
localized spikes in `k`-space, and it is the random version of
@@ -1476,7 +1428,7 @@ class RandKSpaceSpikeNoise(RandomizableTransform):
channels at once, or channel-wise if ``channel_wise = True``.
intensity_range: pass a tuple
(a, b) to sample the log-intensity from the interval (a, b)
uniformly for all channels. Or pass a sequence of intervals
((a0, b0), (a1, b1), ...) to sample for each respective channel.
In the second case, the number of 2-tuples must match the number of
channels.
@@ -1521,7 +1473,7 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
Apply transform to `img`. Assumes data is in channel-first form.

Args:
img (np.array or torch.tensor): image with dimensions (C, H, W) or (C, H, W, D)
img: image with dimensions (C, H, W) or (C, H, W, D)
"""
if self.intensity_range is not None:
if isinstance(self.intensity_range[0], Sequence) and len(self.intensity_range) != img.shape[0]:
@@ -1532,19 +1484,20 @@ def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[torch.Tensor,
self.sampled_k_intensity = []
self.sampled_locs = []

# convert to ndarray to work with np.fft
x, device = self._to_numpy(img)
intensity_range = self._make_sequence(x)
self._randomize(x, intensity_range)
if not isinstance(img, torch.Tensor):
img = torch.Tensor(img)

intensity_range = self._make_sequence(img)
self._randomize(img, intensity_range)

# build/apply transform only if there are spike locations
if self.sampled_locs:
transform = KSpaceSpikeNoise(self.sampled_locs, self.sampled_k_intensity, self.as_tensor_output)
return transform(x)
return transform(img)

return torch.Tensor(x, device=device) if self.as_tensor_output else x
return img if self.as_tensor_output else img.detach().numpy()

def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]) -> None:
def _randomize(self, img: torch.Tensor, intensity_range: Sequence[Sequence[float]]) -> None:
"""
Helper method to sample both the location and intensity of the spikes.
When not working channel-wise (channel_wise=False) it uses the random
@@ -1568,11 +1521,11 @@ def _randomize(self, img: np.ndarray, intensity_range: Sequence[Sequence[float]]) -> None:
spatial = tuple(self.R.randint(0, k) for k in img.shape[1:])
self.sampled_locs = [(i,) + spatial for i in range(img.shape[0])]
if isinstance(intensity_range[0], Sequence):
self.sampled_k_intensity = [self.R.uniform(*p) for p in intensity_range] # type: ignore
self.sampled_k_intensity = [self.R.uniform(p[0], p[1]) for p in intensity_range]
else:
self.sampled_k_intensity = [self.R.uniform(*self.intensity_range)] * len(img) # type: ignore
self.sampled_k_intensity = [self.R.uniform(intensity_range[0], intensity_range[1])] * len(img) # type: ignore

def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]:
def _make_sequence(self, x: torch.Tensor) -> Sequence[Sequence[float]]:
"""
Formats the sequence of intensity ranges to Sequence[Sequence[float]].
"""
@@ -1586,27 +1539,21 @@ def _make_sequence(self, x: np.ndarray) -> Sequence[Sequence[float]]:
# set default range if one not provided
return self._set_default_range(x)

def _set_default_range(self, x: np.ndarray) -> Sequence[Sequence[float]]:
def _set_default_range(self, img: torch.Tensor) -> Sequence[Sequence[float]]:
"""
Sets default intensity ranges to be sampled.

Args:
x (np.ndarray): tensor to fourier transform.
img: image to transform.
"""
n_dims = len(x.shape[1:])
n_dims = len(img.shape[1:])

k = np.fft.fftshift(np.fft.fftn(x, axes=tuple(range(-n_dims, 0))), axes=tuple(range(-n_dims, 0)))
log_abs = np.log(np.absolute(k) + 1e-10)
shifted_means = np.mean(log_abs, axis=tuple(range(-n_dims, 0))) * 2.5
k = self.shift_fourier(img, n_dims)
log_abs = torch.log(torch.absolute(k) + 1e-10)
shifted_means = torch.mean(log_abs, dim=tuple(range(-n_dims, 0))) * 2.5
intensity_sequence = tuple((i * 0.95, i * 1.1) for i in shifted_means)
return intensity_sequence

def _to_numpy(self, img: Union[np.ndarray, torch.Tensor]) -> Tuple[np.ndarray, torch.device]:
if isinstance(img, torch.Tensor):
return img.cpu().detach().numpy(), img.device
else:
return img, torch.device("cpu")


class RandCoarseDropout(RandomizableTransform):
"""
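Taken together, the changes above keep GibbsNoise, KSpaceSpikeNoise and RandKSpaceSpikeNoise on torch tensors end to end instead of round-tripping through np.fft. A minimal usage sketch, assuming a MONAI build that includes this PR (constructor arguments follow the docstrings shown above; the prob argument is assumed from RandomizableTransform):

import torch
from monai.transforms import GibbsNoise, RandKSpaceSpikeNoise

img = torch.rand(1, 64, 64)  # channel-first (C, H, W) image

# Gibbs filtering in k-space; stays a torch.Tensor when as_tensor_output=True.
gibbs = GibbsNoise(alpha=0.6, as_tensor_output=True)
filtered = gibbs(img)

# Random k-space spikes with an explicit log-intensity range, sampled per channel.
rand_spike = RandKSpaceSpikeNoise(prob=1.0, intensity_range=(11.0, 12.0), channel_wise=True)
noisy = rand_spike(img)

print(filtered.shape, noisy.shape)  # both remain (1, 64, 64)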