|
13 | 13 | from typing import TYPE_CHECKING, Callable, Optional, Union
|
14 | 14 |
|
15 | 15 | import numpy as np
|
16 |
| - |
17 | 16 | from monai.config import DtypeLike
|
18 | 17 | from monai.transforms import SaveImage
|
19 | 18 | from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, exact_version, optional_import
|
@@ -119,7 +118,6 @@ def __init__(
|
119 | 118 | output_dtype=output_dtype,
|
120 | 119 | squeeze_end_dims=squeeze_end_dims,
|
121 | 120 | data_root_dir=data_root_dir,
|
122 |
| - save_batch=True, |
123 | 121 | )
|
124 | 122 | self.batch_transform = batch_transform
|
125 | 123 | self.output_transform = output_transform
|
@@ -147,5 +145,13 @@ def __call__(self, engine: Engine) -> None:
|
147 | 145 | """
|
148 | 146 | meta_data = self.batch_transform(engine.state.batch)
|
149 | 147 | engine_output = self.output_transform(engine.state.output)
|
150 |
| - self._saver(engine_output, meta_data) |
| 148 | + if isinstance(engine_output, (tuple, list)): |
| 149 | + # if the data is a list, save every item separately |
| 150 | + self._saver.save_batch = False |
| 151 | + for i, d in enumerate(engine_output): |
| 152 | + self._saver(d, {k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) |
| 153 | + else: |
| 154 | + # if the data is in shape: [batch, channel, H, W, [D]] |
| 155 | + self._saver.save_batch = True |
| 156 | + self._saver(engine_output, meta_data) |
151 | 157 | self.logger.info("saved all the model outputs into files.")
|
0 commit comments