Skip to content

Commit 75b598a

Browse files
authored
Merge pull request #3184 from HippocampusGirl/workflow-connect-performance
[ENH] Workflow connect performance
2 parents f2a5c4b + 0372982 commit 75b598a

File tree

3 files changed

+80
-7
lines changed

3 files changed

+80
-7
lines changed

.zenodo.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,11 @@
157157
"name": "De La Vega, Alejandro",
158158
"orcid": "0000-0001-9062-3778"
159159
},
160+
{
161+
"affiliation": "Charite Universitatsmedizin Berlin, Germany",
162+
"name": "Waller, Lea",
163+
"orcid": "0000-0002-3239-6957"
164+
},
160165
{
161166
"affiliation": "MIT",
162167
"name": "Kaczmarzyk, Jakub",

nipype/pipeline/engine/tests/test_workflows.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,22 @@ def test_doubleconnect():
8383
assert "Trying to connect" in str(excinfo.value)
8484

8585

86+
def test_nested_workflow_doubleconnect():
87+
# double input with nested workflows
88+
a = pe.Node(niu.IdentityInterface(fields=["a", "b"]), name="a")
89+
b = pe.Node(niu.IdentityInterface(fields=["a", "b"]), name="b")
90+
c = pe.Node(niu.IdentityInterface(fields=["a", "b"]), name="c")
91+
flow1 = pe.Workflow(name="test1")
92+
flow2 = pe.Workflow(name="test2")
93+
flow3 = pe.Workflow(name="test3")
94+
flow1.add_nodes([b])
95+
flow2.connect(a, "a", flow1, "b.a")
96+
with pytest.raises(Exception) as excinfo:
97+
flow3.connect(c, "a", flow2, "test1.b.a")
98+
assert "Some connections were not found" in str(excinfo.value)
99+
flow3.connect(c, "b", flow2, "test1.b.b")
100+
101+
86102
def test_duplicate_node_check():
87103

88104
wf = pe.Workflow(name="testidentity")

nipype/pipeline/engine/workflows.py

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -770,16 +770,68 @@ def _check_nodes(self, nodes):
770770
def _has_attr(self, parameter, subtype="in"):
771771
"""Checks if a parameter is available as an input or output
772772
"""
773+
hierarchy = parameter.split(".")
774+
775+
# Connecting to a workflow needs at least two values,
776+
# the name of the child node and the name of the input/output
777+
if len(hierarchy) < 2:
778+
return False
779+
780+
attrname = hierarchy.pop()
781+
nodename = hierarchy.pop()
782+
783+
def _check_is_already_connected(workflow, node, attrname):
784+
for _, _, d in workflow._graph.in_edges(nbunch=node, data=True):
785+
for cd in d["connect"]:
786+
if attrname == cd[1]:
787+
return False
788+
return True
789+
790+
targetworkflow = self
791+
while hierarchy:
792+
workflowname = hierarchy.pop(0)
793+
workflow = None
794+
for node in targetworkflow._graph.nodes():
795+
if node.name == workflowname:
796+
if isinstance(node, Workflow):
797+
workflow = node
798+
break
799+
if workflow is None:
800+
return False
801+
# Verify input does not already have an incoming connection
802+
# in the hierarchy of workflows
803+
if subtype == "in":
804+
hierattrname = ".".join(hierarchy + [nodename, attrname])
805+
if not _check_is_already_connected(
806+
targetworkflow, workflow, hierattrname):
807+
return False
808+
targetworkflow = workflow
809+
810+
targetnode = None
811+
for node in targetworkflow._graph.nodes():
812+
if node.name == nodename:
813+
if isinstance(node, Workflow):
814+
return False
815+
else:
816+
targetnode = node
817+
break
818+
if targetnode is None:
819+
return False
820+
773821
if subtype == "in":
774-
subobject = self.inputs
822+
if not hasattr(targetnode.inputs, attrname):
823+
return False
775824
else:
776-
subobject = self.outputs
777-
attrlist = parameter.split(".")
778-
cur_out = subobject
779-
for attr in attrlist:
780-
if not hasattr(cur_out, attr):
825+
if not hasattr(targetnode.outputs, attrname):
781826
return False
782-
cur_out = getattr(cur_out, attr)
827+
828+
# Verify input does not already have an incoming connection
829+
# in the target workflow
830+
if subtype == "in":
831+
if not _check_is_already_connected(
832+
targetworkflow, targetnode, attrname):
833+
return False
834+
783835
return True
784836

785837
def _get_parameter_node(self, parameter, subtype="in"):

0 commit comments

Comments
 (0)