Skip to content

1866 Add TransformInverter handler #1970

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Apr 9, 2021
5 changes: 5 additions & 0 deletions docs/source/handlers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,8 @@ GarbageCollector handler
------------------------
.. autoclass:: GarbageCollector
:members:

Transform inverter
------------------
.. autoclass:: TransformInverter
:members:
10 changes: 5 additions & 5 deletions monai/data/inverse_batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Any, Callable, Dict, Hashable, Optional, Sequence

import numpy as np
from torch.utils.data.dataloader import DataLoader as TorchDataLoader

from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.utils import decollate_batch, pad_list_data_collate
from monai.data.utils import decollate_batch, no_collation, pad_list_data_collate
from monai.transforms.croppad.batch import PadListDataCollate
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import Transform
Expand All @@ -42,13 +43,12 @@ def _transform(self, index: int) -> Dict[Hashable, np.ndarray]:
if self.pad_collation_used:
data = PadListDataCollate.inverse(data)

if not isinstance(self.invertible_transform, InvertibleTransform):
warnings.warn("transform is not invertible, can't invert transform for the input data.")
return data
return self.invertible_transform.inverse(data)


def no_collation(x):
return x


class BatchInverseTransform(Transform):
"""Perform inverse on a batch of data. This is useful if you have inferred a batch of images and want to invert them all."""

Expand Down
8 changes: 8 additions & 0 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
"sorted_dict",
"decollate_batch",
"pad_list_data_collate",
"no_collation",
]


Expand Down Expand Up @@ -379,6 +380,13 @@ def pad_list_data_collate(
return PadListDataCollate(method, mode)(batch)


def no_collation(x):
"""
No any collation operation.
"""
return x


def worker_init_fn(worker_id: int) -> None:
"""
Callback function for PyTorch DataLoader `worker_init_fn`.
Expand Down
1 change: 1 addition & 0 deletions monai/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .stats_handler import StatsHandler
from .surface_distance import SurfaceDistance
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
from .transform_inverter import TransformInverter
from .utils import (
evenly_divisible_all_gather,
stopping_fn_from_loss,
Expand Down
11 changes: 9 additions & 2 deletions monai/handlers/segmentation_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def __init__(
output_dtype=output_dtype,
squeeze_end_dims=squeeze_end_dims,
data_root_dir=data_root_dir,
save_batch=True,
)
self.batch_transform = batch_transform
self.output_transform = output_transform
Expand Down Expand Up @@ -147,5 +146,13 @@ def __call__(self, engine: Engine) -> None:
"""
meta_data = self.batch_transform(engine.state.batch)
engine_output = self.output_transform(engine.state.output)
self._saver(engine_output, meta_data)
if isinstance(engine_output, (tuple, list)):
# if a list of data in shape: [channel, H, W, [D]], save every item separately
self._saver.save_batch = False
for i, d in enumerate(engine_output):
self._saver(d, {k: meta_data[k][i] for k in meta_data} if meta_data is not None else None)
else:
# if the data is in shape: [batch, channel, H, W, [D]]
self._saver.save_batch = True
self._saver(engine_output, meta_data)
self.logger.info("saved all the model outputs into files.")
94 changes: 94 additions & 0 deletions monai/handlers/transform_inverter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Callable, Optional

from torch.utils.data import DataLoader as TorchDataLoader

from monai.data import BatchInverseTransform
from monai.data.utils import no_collation
from monai.engines.utils import CommonKeys
from monai.transforms import InvertibleTransform, allow_missing_keys_mode
from monai.utils import InverseKeys, exact_version, optional_import

Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
if TYPE_CHECKING:
from ignite.engine import Engine
else:
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")


class TransformInverter:
"""
Ignite handler to automatically invert all the pre-transforms that support `inverse`.
It takes `engine.state.output` as the input data and uses the transforms infomation from `engine.state.batch`.

Note:
This handler is experimental API in v0.5, the interpolation mode in the transforms
and inverse transforms are the same, so maybe it's not correct as we may want to use `bilinear`
for input image but use `nearest` when inverting transforms for model outout.
For this case, a solution is to set `batch_key` to the label field if we have labels.

"""

def __init__(
self,
transform: InvertibleTransform,
loader: TorchDataLoader,
collate_fn: Optional[Callable] = no_collation,
batch_key: str = CommonKeys.IMAGE,
output_key: str = CommonKeys.PRED,
postfix: str = "inverted",
) -> None:
"""
Args:
transform: a callable data transform on input data.
loader: data loader used to generate the batch of data.
collate_fn: how to collate data after inverse transformations.
default won't do any collation, so the output will be a list of size batch size.
batch_key: the key of input data in `ignite.engine.batch`. will get the applied transforms
for this input data, then invert them for the model output, default to "image".
output_key: the key of model output in `ignite.engine.output`, invert transforms on it.
postfix: will save the inverted result into `ignite.engine.output` with key `{ouput_key}_{postfix}`.

"""
self.transform = transform
self.inverter = BatchInverseTransform(transform=transform, loader=loader, collate_fn=collate_fn)
self.batch_key = batch_key
self.output_key = output_key
self.postfix = postfix

def attach(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
engine.add_event_handler(Events.ITERATION_COMPLETED, self)

def __call__(self, engine: Engine) -> None:
"""
Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
transform_key = self.batch_key + InverseKeys.KEY_SUFFIX
if transform_key not in engine.state.batch:
warnings.warn("all the pre-transforms are not InvertibleTransform or no need to invert.")
return

segs_dict = {
self.batch_key: engine.state.output[self.output_key].detach().cpu(),
transform_key: engine.state.batch[transform_key],
}

with allow_missing_keys_mode(self.transform): # type: ignore
inverted_key = f"{self.output_key}_{self.postfix}"
engine.state.output[inverted_key] = [i[self.batch_key] for i in self.inverter(segs_dict)]
17 changes: 16 additions & 1 deletion monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

import copy
import logging
from copy import deepcopy
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch

from monai.config import DtypeLike, KeysCollection, NdarrayTensor
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform, Randomizable
from monai.transforms.utility.array import (
AddChannel,
Expand Down Expand Up @@ -379,7 +381,7 @@ def __call__(
return d


class ToTensord(MapTransform):
class ToTensord(MapTransform, InvertibleTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.ToTensor`.
"""
Expand All @@ -397,9 +399,22 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
d[key] = self.converter(d[key])
return d

def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
inverse_transform = ToNumpy()
# Apply inverse
d[key] = inverse_transform(d[key])
# Remove the applied transform
self.pop_transform(d, key)
return d


class ToNumpyd(MapTransform):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def run_testsuit():
"test_ensure_channel_first",
"test_ensure_channel_firstd",
"test_handler_early_stop",
"test_handler_transform_inverter",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
81 changes: 81 additions & 0 deletions tests/test_handler_transform_inverter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import numpy as np
import torch
from ignite.engine import Engine

from monai.data import CacheDataset, DataLoader, create_test_image_3d
from monai.handlers import TransformInverter
from monai.transforms import (
AddChanneld,
Compose,
LoadImaged,
RandAffined,
RandAxisFlipd,
RandFlipd,
RandRotate90d,
RandRotated,
RandZoomd,
ResizeWithPadOrCropd,
ToTensord,
)
from tests.utils import make_nifti_image

KEYS = ["image", "label"]


class TestTransformInverter(unittest.TestCase):
def test_invert(self):
im_fname, seg_fname = [make_nifti_image(i) for i in create_test_image_3d(101, 100, 107)]
transform = Compose(
[
LoadImaged(KEYS),
AddChanneld(KEYS),
RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
RandAxisFlipd(KEYS, prob=0.5),
RandRotate90d(KEYS, spatial_axes=(1, 2)),
RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
RandRotated(KEYS, prob=0.5, range_x=np.pi),
RandAffined(KEYS, prob=0.5, rotate_range=np.pi),
ResizeWithPadOrCropd(KEYS, 100),
ToTensord(KEYS),
]
)
data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

# num workers = 0 for mac or gpu transforms
num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2

dataset = CacheDataset(data, transform=transform, progress=False)
loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

# set up engine
def _train_func(engine, batch):
self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
return batch

engine = Engine(_train_func)

# set up testing handler
TransformInverter(transform=transform, loader=loader, output_key="image").attach(engine)

engine.run(loader, max_epochs=1)
self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100))
for i in engine.state.output["image_inverted"]:
self.assertTupleEqual(i.shape, (1, 100, 101, 107))


if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion tests/test_inverse_collation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
RandRotated,
RandZoomd,
ResizeWithPadOrCropd,
ToTensord,
)
from monai.utils import optional_import, set_determinism
from tests.utils import make_nifti_image
Expand Down Expand Up @@ -113,7 +114,7 @@ def test_collation(self, _, transform, collate_fn, ndim):
if collate_fn:
modified_transform = transform
else:
modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100)])
modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS)])

# num workers = 0 for mac or gpu transforms
num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2
Expand Down