Skip to content

Commit 0544124

Browse files
committed
enh: remove aggregate_outputs duplications, fix InputMultiPath
1 parent 774f41d commit 0544124

File tree

3 files changed

+37
-38
lines changed

3 files changed

+37
-38
lines changed

nipype/interfaces/ants/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# Local imports
88
from ... import logging, LooseVersion
99
from ..base import CommandLine, CommandLineInputSpec, traits, isdefined, PackageInfo
10+
from ...utils.imagemanip import copy_header as _copy_header
1011

1112
iflogger = logging.getLogger("nipype.interface")
1213

@@ -121,3 +122,17 @@ def set_default_num_threads(cls, num_threads):
121122
@property
122123
def version(self):
123124
return Info.version()
125+
126+
127+
class FixHeaderANTSCommand(ANTSCommand):
128+
"""Fix header if the copy_header input is on."""
129+
130+
def aggregate_outputs(self, runtime=None, needed_outputs=None):
131+
"""Overload the aggregation with header replacement, if required."""
132+
outputs = super(FixHeaderANTSCommand, self).aggregate_outputs(
133+
runtime, needed_outputs)
134+
if self.inputs.copy_header: # Fix headers
135+
_copy_header(
136+
self.inputs.op1, outputs["output_image"], keep_dtype=True
137+
)
138+
return outputs
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# AUTO-GENERATED by tools/checkspecs.py - DO NOT EDIT
2+
from ..base import FixHeaderANTSCommand
3+
4+
5+
def test_FixHeaderANTSCommand_inputs():
6+
input_map = dict(
7+
args=dict(argstr="%s",),
8+
environ=dict(nohash=True, usedefault=True,),
9+
num_threads=dict(nohash=True, usedefault=True,),
10+
)
11+
inputs = FixHeaderANTSCommand.input_spec()
12+
13+
for key, metadata in list(input_map.items()):
14+
for metakey, value in list(metadata.items()):
15+
assert getattr(inputs.traits()[key], metakey) == value

nipype/interfaces/ants/utils.py

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""ANTs' utilities."""
22
import os
3-
from ...utils.imagemanip import copy_header as _copy_header
43
from ..base import traits, isdefined, TraitedSpec, File, Str, InputMultiObject
5-
from .base import ANTSCommandInputSpec, ANTSCommand
4+
from .base import ANTSCommandInputSpec, ANTSCommand, FixHeaderANTSCommand
65

76

87
class _ImageMathInputSpec(ANTSCommandInputSpec):
@@ -68,7 +67,7 @@ class _ImageMathOuputSpec(TraitedSpec):
6867
output_image = File(exists=True, desc="output image file")
6968

7069

71-
class ImageMath(ANTSCommand):
70+
class ImageMath(FixHeaderANTSCommand):
7271
"""
7372
Operations over images.
7473
@@ -98,16 +97,6 @@ class ImageMath(ANTSCommand):
9897
input_spec = _ImageMathInputSpec
9998
output_spec = _ImageMathOuputSpec
10099

101-
def aggregate_outputs(self, runtime=None, needed_outputs=None):
102-
"""Overload the aggregation with header replacement, if required."""
103-
outputs = super(ImageMath, self).aggregate_outputs(
104-
runtime, needed_outputs)
105-
if self.inputs.copy_header: # Fix headers
106-
_copy_header(
107-
self.inputs.op1, outputs["output_image"], keep_dtype=True
108-
)
109-
return outputs
110-
111100

112101
class _ResampleImageBySpacingInputSpec(ANTSCommandInputSpec):
113102
dimension = traits.Int(
@@ -157,7 +146,7 @@ class _ResampleImageBySpacingOutputSpec(TraitedSpec):
157146
output_image = File(exists=True, desc="resampled file")
158147

159148

160-
class ResampleImageBySpacing(ANTSCommand):
149+
class ResampleImageBySpacing(FixHeaderANTSCommand):
161150
"""
162151
Resample an image with a given spacing.
163152
@@ -203,16 +192,6 @@ def _format_arg(self, name, trait_spec, value):
203192

204193
return super(ResampleImageBySpacing, self)._format_arg(name, trait_spec, value)
205194

206-
def aggregate_outputs(self, runtime=None, needed_outputs=None):
207-
"""Overload the aggregation with header replacement, if required."""
208-
outputs = super(ResampleImageBySpacing, self).aggregate_outputs(
209-
runtime, needed_outputs)
210-
if self.inputs.copy_header: # Fix headers
211-
_copy_header(
212-
self.inputs.input_image, outputs["output_image"], keep_dtype=True
213-
)
214-
return outputs
215-
216195

217196
class _ThresholdImageInputSpec(ANTSCommandInputSpec):
218197
dimension = traits.Int(
@@ -269,7 +248,7 @@ class _ThresholdImageOutputSpec(TraitedSpec):
269248
output_image = File(exists=True, desc="resampled file")
270249

271250

272-
class ThresholdImage(ANTSCommand):
251+
class ThresholdImage(FixHeaderANTSCommand):
273252
"""
274253
Apply thresholds on images.
275254
@@ -299,16 +278,6 @@ class ThresholdImage(ANTSCommand):
299278
input_spec = _ThresholdImageInputSpec
300279
output_spec = _ThresholdImageOutputSpec
301280

302-
def aggregate_outputs(self, runtime=None, needed_outputs=None):
303-
"""Overload the aggregation with header replacement, if required."""
304-
outputs = super(ThresholdImage, self).aggregate_outputs(
305-
runtime, needed_outputs)
306-
if self.inputs.copy_header: # Fix headers
307-
_copy_header(
308-
self.inputs.input_image, outputs["output_image"], keep_dtype=True
309-
)
310-
return outputs
311-
312281

313282
class _AIInputSpec(ANTSCommandInputSpec):
314283
dimension = traits.Enum(
@@ -465,7 +434,7 @@ class AverageAffineTransformInputSpec(ANTSCommandInputSpec):
465434
position=1,
466435
desc="Outputfname.txt: the name of the resulting transform.",
467436
)
468-
transforms = InputMultiPath(
437+
transforms = InputMultiObject(
469438
File(exists=True),
470439
argstr="%s",
471440
mandatory=True,
@@ -526,7 +495,7 @@ class AverageImagesInputSpec(ANTSCommandInputSpec):
526495
desc="Normalize: if true, the 2nd image is divided by its mean. "
527496
"This will select the largest image to average into.",
528497
)
529-
images = InputMultiPath(
498+
images = InputMultiObject(
530499
File(exists=True),
531500
argstr="%s",
532501
mandatory=True,
@@ -767,7 +736,7 @@ class ComposeMultiTransformInputSpec(ANTSCommandInputSpec):
767736
position=2,
768737
desc="Reference image (only necessary when output is warpfield)",
769738
)
770-
transforms = InputMultiPath(
739+
transforms = InputMultiObject(
771740
File(exists=True),
772741
argstr="%s",
773742
mandatory=True,

0 commit comments

Comments
 (0)