You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Forced Fourier class to output contiguous tensors. (#7969)
Forced `Fourier` class to output contiguous tensors, which potentially
fixes a performance bottleneck.
### Description
Some transforms, such as `RandKSpaceSpikeNoise`, rely on the `Fourier`
class.
In its current state, the `Fourier` class returns non-contiguous
tensors, which potentially limits performance.
For example, when followed by `RandHistogramShift`, the following
warning occurs:
```
<path_to_monai>/monai/transforms/intensity/array.py:1852: UserWarning: torch.searchsorted(): input value tensor is non-contiguous, this will lower the performance due to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value tensor if possible. This message will only appear once per program. (Triggered internally at /opt/conda/conda-bld/pytorch_1716905975447/work/aten/src/ATen/native/BucketizationUtils.h:32.)
indices = ns.searchsorted(xp.reshape(-1), x.reshape(-1)) - 1
```
A straightforward fix is to force the `Fourier` class to output
contiguous tensors (see commit).
To reproduce, please run:
```
from monai.transforms import RandKSpaceSpikeNoise
from monai.transforms.utils import Fourier
import numpy as np
### TEST WITH TRANSFORMS ###
t = RandKSpaceSpikeNoise(prob=1)
# for torch tensors
a_torch = torch.rand(1, 128, 128, 128)
print(a_torch.is_contiguous())
a_torch_mod = t(a_torch)
print(a_torch_mod.is_contiguous())
# for np arrays
a_np = np.random.rand(1, 128, 128, 128)
print(a_np.flags['C_CONTIGUOUS'])
a_np_mod = t(a_np) # automatically transformed to torch.tensor
print(a_np_mod.is_contiguous())
### TEST DIRECTLY WITH FOURIER ###
f = Fourier()
# inv_shift_fourier
# for torch tensors
real_torch = torch.randn(1, 128, 128, 128)
im_torch = torch.randn(1, 128, 128, 128)
k_torch = torch.complex(real_torch, im_torch)
print(k_torch.is_contiguous())
out_torch = f.inv_shift_fourier(k_torch, spatial_dims=3)
print(out_torch.is_contiguous())
# for np arrays
real_np = np.random.randn(1, 100, 100, 100)
im_np = np.random.randn(1, 100, 100, 100)
k_np = real_np + 1j * im_np
print(k_np.flags['C_CONTIGUOUS'])
out_np = f.inv_shift_fourier(k_np, spatial_dims=3)
print(out_np.flags['C_CONTIGUOUS'])
# shift_fourier
# for torch tensors
a_torch = torch.rand(1, 128, 128, 128)
print(a_torch.is_contiguous())
out_torch = f.shift_fourier(a_torch, spatial_dims=3)
print(out_torch.is_contiguous())
# for np arrays
a_np = np.random.rand(1, 128, 128, 128)
print(a_np.flags['C_CONTIGUOUS'])
out_np = f.shift_fourier(a_np, spatial_dims=3)
print(out_np.flags['C_CONTIGUOUS'])
```
### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.
---------
Signed-off-by: Bastian Wittmann <[email protected]>.
Signed-off-by: Bastian Wittmann <[email protected]>
Co-authored-by: YunLiu <[email protected]>
0 commit comments