Skip to content

Commit c632b37

Browse files
committed
[DLMED] add support in SegmentationSaver handler
Signed-off-by: Nic Ma <[email protected]>
1 parent 3ad8aad commit c632b37

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

monai/handlers/segmentation_saver.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from typing import TYPE_CHECKING, Callable, Optional, Union
1414

1515
import numpy as np
16-
1716
from monai.config import DtypeLike
1817
from monai.transforms import SaveImage
1918
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import
@@ -119,7 +118,6 @@ def __init__(
119118
output_dtype=output_dtype,
120119
squeeze_end_dims=squeeze_end_dims,
121120
data_root_dir=data_root_dir,
122-
save_batch=True,
123121
)
124122
self.batch_transform = batch_transform
125123
self.output_transform = output_transform
@@ -147,5 +145,13 @@ def __call__(self, engine: Engine) -> None:
147145
"""
148146
meta_data = self.batch_transform(engine.state.batch)
149147
engine_output = self.output_transform(engine.state.output)
150-
self._saver(engine_output, meta_data)
148+
if isinstance(engine_output, (tuple, list)):
149+
# if the data is a list, save every item separately
150+
self._saver.save_batch = False
151+
for i, d in enumerate(engine_output):
152+
self._saver(d, {k: meta_data[k][i] for k in meta_data} if meta_data is not None else None)
153+
else:
154+
# if the data is in shape: [batch, channel, H, W, [D]]
155+
self._saver.save_batch = True
156+
self._saver(engine_output, meta_data)
151157
self.logger.info("saved all the model outputs into files.")

monai/handlers/transform_inverter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,5 @@ def __call__(self, engine: Engine) -> None:
7979
transform_key: engine.state.batch[transform_key]}
8080

8181
with allow_missing_keys_mode(self.transform):
82-
engine.state.output[f"{self.output_key}_{self.postfix}"] = self.inverter(segs_dict)
82+
inverted_key = f"{self.output_key}_{self.postfix}"
83+
engine.state.output[inverted_key] = [i[self.batch_key] for i in self.inverter(segs_dict)]

0 commit comments

Comments
 (0)