Skip to content

Commit befb5f6

Browse files
bwittmannKumoLiu
andauthored
Forced Fourier class to output contiguous tensors. (#7969)
Forced `Fourier` class to output contiguous tensors, which potentially fixes a performance bottleneck. ### Description Some transforms, such as `RandKSpaceSpikeNoise`, rely on the `Fourier` class. In its current state, the `Fourier` class returns non-contiguous tensors, which potentially limits performance. For example, when followed by `RandHistogramShift`, the following warning occurs: ``` <path_to_monai>/monai/transforms/intensity/array.py:1852: UserWarning: torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value tensor if possible. This message will only appear once per program. (Triggered internally at /opt/conda/conda-bld/pytorch_1716905975447/work/aten/src/ATen/native/BucketizationUtils.h:32.) indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1 ``` A straightforward fix is to force the `Fourier` class to output contiguous tensors (see commit). To reproduce, please run: ``` from monai.transforms import RandKSpaceSpikeNoise from monai.transforms.utils import Fourier import numpy as np ### TEST WITH TRANSFORMS ### t = RandKSpaceSpikeNoise(prob=1) # for torch tensors a_torch = torch.rand(1, 128, 128, 128) print(a_torch.is_contiguous()) a_torch_mod = t(a_torch) print(a_torch_mod.is_contiguous()) # for np arrays a_np = np.random.rand(1, 128, 128, 128) print(a_np.flags['C_CONTIGUOUS']) a_np_mod = t(a_np) # automatically transformed to torch.tensor print(a_np_mod.is_contiguous()) ### TEST DIRECTLY WITH FOURIER ### f = Fourier() # inv_shift_fourier # for torch tensors real_torch = torch.randn(1, 128, 128, 128) im_torch = torch.randn(1, 128, 128, 128) k_torch = torch.complex(real_torch, im_torch) print(k_torch.is_contiguous()) out_torch = f.inv_shift_fourier(k_torch, spatial_dims=3) print(out_torch.is_contiguous()) # for np arrays real_np = np.random.randn(1, 100, 100, 100) im_np = np.random.randn(1, 100, 100, 100) k_np = real_np + 1j * im_np print(k_np.flags['C_CONTIGUOUS']) out_np = f.inv_shift_fourier(k_np, spatial_dims=3) print(out_np.flags['C_CONTIGUOUS']) # shift_fourier # for torch tensors a_torch = torch.rand(1, 128, 128, 128) print(a_torch.is_contiguous()) out_torch = f.shift_fourier(a_torch, spatial_dims=3) print(out_torch.is_contiguous()) # for np arrays a_np = np.random.rand(1, 128, 128, 128) print(a_np.flags['C_CONTIGUOUS']) out_np = f.shift_fourier(a_np, spatial_dims=3) print(out_np.flags['C_CONTIGUOUS']) ``` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Bastian Wittmann <[email protected]>. Signed-off-by: Bastian Wittmann <[email protected]> Co-authored-by: YunLiu <[email protected]>
1 parent dbfe418 commit befb5f6

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

monai/transforms/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,14 +1863,15 @@ class Fourier:
18631863
"""
18641864

18651865
@staticmethod
1866-
def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
1866+
def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor:
18671867
"""
18681868
Applies fourier transform and shifts the zero-frequency component to the
18691869
center of the spectrum. Only the spatial dimensions get transformed.
18701870
18711871
Args:
18721872
x: Image to transform.
18731873
spatial_dims: Number of spatial dimensions.
1874+
as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
18741875
18751876
Returns
18761877
k: K-space data.
@@ -1885,17 +1886,20 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
18851886
k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims)
18861887
else:
18871888
k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims)
1888-
return k
1889+
return ascontiguousarray(k) if as_contiguous else k
18891890

18901891
@staticmethod
1891-
def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None) -> NdarrayOrTensor:
1892+
def inv_shift_fourier(
1893+
k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None, as_contiguous: bool = False
1894+
) -> NdarrayOrTensor:
18921895
"""
18931896
Applies inverse shift and fourier transform. Only the spatial
18941897
dimensions are transformed.
18951898
18961899
Args:
18971900
k: K-space data.
18981901
spatial_dims: Number of spatial dimensions.
1902+
as_contiguous: Whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
18991903
19001904
Returns:
19011905
x: Tensor in image space.
@@ -1910,7 +1914,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None
19101914
out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real
19111915
else:
19121916
out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real
1913-
return out
1917+
return ascontiguousarray(out) if as_contiguous else out
19141918

19151919

19161920
def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int:

0 commit comments

Comments
 (0)