Skip to content

Commit 80be1c3

Browse files
authored
Track applied operations in image filter (#7395)
Fixes #7394 ### Description When ImageFilter is in the transformation sequence it didn't pass the applied_operations. Now it is passed when present. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: axel.vlaminck <[email protected]>
1 parent 78295c7 commit 80be1c3

File tree

2 files changed

+24
-3
lines changed

2 files changed

+24
-3
lines changed

monai/transforms/utility/array.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1562,17 +1562,22 @@ def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int |
15621562
self.filter_size = filter_size
15631563
self.additional_args_for_filter = kwargs
15641564

1565-
def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> NdarrayOrTensor:
1565+
def __call__(
1566+
self, img: NdarrayOrTensor, meta_dict: dict | None = None, applied_operations: list | None = None
1567+
) -> NdarrayOrTensor:
15661568
"""
15671569
Args:
15681570
img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]]
15691571
meta_dict: An optional dictionary with metadata
1572+
applied_operations: An optional list of operations that have been applied to the data
15701573
15711574
Returns:
15721575
A MetaTensor with the same shape as `img` and identical metadata
15731576
"""
15741577
if isinstance(img, MetaTensor):
15751578
meta_dict = img.meta
1579+
applied_operations = img.applied_operations
1580+
15761581
img_, prev_type, device = convert_data_type(img, torch.Tensor)
15771582
ndim = img_.ndim - 1 # assumes channel first format
15781583

@@ -1582,8 +1587,8 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: dict | None = None) -> Ndarr
15821587
self.filter = ApplyFilter(self.filter)
15831588

15841589
img_ = self._apply_filter(img_)
1585-
if meta_dict:
1586-
img_ = MetaTensor(img_, meta=meta_dict)
1590+
if meta_dict is not None or applied_operations is not None:
1591+
img_ = MetaTensor(img_, meta=meta_dict, applied_operations=applied_operations)
15871592
else:
15881593
img_, *_ = convert_data_type(img_, prev_type, device)
15891594
return img_

tests/test_image_filter.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import torch
1818
from parameterized import parameterized
1919

20+
from monai.data.meta_tensor import MetaTensor
2021
from monai.networks.layers.simplelayers import GaussianFilter
2122
from monai.transforms import ImageFilter, ImageFilterd, RandImageFilter, RandImageFilterd
2223

@@ -115,6 +116,21 @@ def test_call_3d(self, filter_name):
115116
out_tensor = filter(SAMPLE_IMAGE_3D)
116117
self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_3D.shape[1:])
117118

119+
def test_pass_applied_operations(self):
120+
"Test that applied operations are passed through"
121+
applied_operations = ["op1", "op2"]
122+
image = MetaTensor(SAMPLE_IMAGE_2D, applied_operations=applied_operations)
123+
filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS)
124+
out_tensor = filter(image)
125+
self.assertEqual(out_tensor.applied_operations, applied_operations)
126+
127+
def test_pass_empty_metadata_dict(self):
128+
"Test that applied operations are passed through"
129+
image = MetaTensor(SAMPLE_IMAGE_2D, meta={})
130+
filter = ImageFilter(SUPPORTED_FILTERS[0], 3, **ADDITIONAL_ARGUMENTS)
131+
out_tensor = filter(image)
132+
self.assertTrue(isinstance(out_tensor, MetaTensor))
133+
118134

119135
class TestImageFilterDict(unittest.TestCase):
120136
@parameterized.expand(SUPPORTED_FILTERS)

0 commit comments

Comments
 (0)