diff --git a/pyiron_workflow/mixin/semantics.py b/pyiron_workflow/mixin/semantics.py index de083b87..e207ab92 100644 --- a/pyiron_workflow/mixin/semantics.py +++ b/pyiron_workflow/mixin/semantics.py @@ -13,9 +13,10 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from difflib import get_close_matches from pathlib import Path +from typing import Generic, TypeVar from bidict import bidict @@ -31,12 +32,12 @@ class Semantic(UsesState, HasLabel, HasParent, ABC): accessible. """ - semantic_delimiter = "/" + semantic_delimiter: str = "/" def __init__( self, label: str, *args, parent: SemanticParent | None = None, **kwargs ): - self._label = None + self._label = "" self._parent = None self._detached_parent_path = None self.label = label @@ -61,6 +62,13 @@ def parent(self) -> SemanticParent | None: @parent.setter def parent(self, new_parent: SemanticParent | None) -> None: + self._set_parent(new_parent) + + def _set_parent(self, new_parent: SemanticParent | None): + """ + mypy is uncooperative with super calls for setters, so we pull the behaviour + out. + """ if new_parent is self._parent: # Exit early if nothing is changing return @@ -157,7 +165,10 @@ class CyclicPathError(ValueError): """ -class SemanticParent(Semantic, ABC): +ChildType = TypeVar("ChildType", bound=Semantic) + + +class SemanticParent(Semantic, Generic[ChildType], ABC): """ A semantic object with a collection of uniquely-named semantic children. @@ -182,19 +193,29 @@ def __init__( strict_naming: bool = True, **kwargs, ): - self._children = bidict() + self._children: bidict[str, ChildType] = bidict() self.strict_naming = strict_naming super().__init__(*args, label=label, parent=parent, **kwargs) + @classmethod + @abstractmethod + def child_type(cls) -> type[ChildType]: + # Dev note: In principle, this could be a regular attribute + # However, in other situations this is precluded (e.g. in channels) + # since it would result in circular references. + # Here we favour consistency over brevity, + # and maintain the X_type() class method pattern + pass + @property - def children(self) -> bidict[str, Semantic]: + def children(self) -> bidict[str, ChildType]: return self._children @property def child_labels(self) -> tuple[str]: return tuple(child.label for child in self) - def __getattr__(self, key): + def __getattr__(self, key) -> ChildType: try: return self._children[key] except KeyError as key_error: @@ -210,7 +231,7 @@ def __getattr__(self, key): def __iter__(self): return self.children.values().__iter__() - def __len__(self): + def __len__(self) -> int: return len(self.children) def __dir__(self): @@ -218,15 +239,15 @@ def __dir__(self): def add_child( self, - child: Semantic, + child: ChildType, label: str | None = None, strict_naming: bool | None = None, - ) -> Semantic: + ) -> ChildType: """ Add a child, optionally assigning it a new label in the process. Args: - child (Semantic): The child to add. + child (ChildType): The child to add. label (str|None): A (potentially) new label to assign the child. (Default is None, leave the child's label alone.) strict_naming (bool|None): Whether to append a suffix to the label if @@ -234,7 +255,7 @@ def add_child( use the class-level flag.) Returns: - (Semantic): The child being added. + (ChildType): The child being added. Raises: TypeError: When the child is not of an allowed class. @@ -244,18 +265,12 @@ def add_child( `strict_naming` is true. """ - if not isinstance(child, Semantic): + if not isinstance(child, self.child_type()): raise TypeError( - f"{self.label} expected a new child of type {Semantic.__name__} " + f"{self.label} expected a new child of type {self.child_type()} " f"but got {child}" ) - if isinstance(child, ParentMost): - raise ParentMostError( - f"{child.label} is {ParentMost.__name__} and may only take None as a " - f"parent but was added as a child to {self.label}" - ) - self._ensure_path_is_not_cyclic(self, child) self._ensure_child_has_no_other_parent(child) @@ -339,15 +354,15 @@ def _add_suffix_to_label(self, label): ) return new_label - def remove_child(self, child: Semantic | str) -> Semantic: + def remove_child(self, child: ChildType | str) -> ChildType: if isinstance(child, str): child = self.children.pop(child) - elif isinstance(child, Semantic): + elif isinstance(child, self.child_type()): self.children.inv.pop(child) else: raise TypeError( f"{self.label} expected to remove a child of type str or " - f"{Semantic.__name__} but got {child}" + f"{self.child_type()} but got {child}" ) child.parent = None @@ -361,7 +376,7 @@ def parent(self) -> SemanticParent | None: @parent.setter def parent(self, new_parent: SemanticParent | None) -> None: self._ensure_path_is_not_cyclic(new_parent, self) - super(SemanticParent, type(self)).parent.__set__(self, new_parent) + self._set_parent(new_parent) def __getstate__(self): state = super().__getstate__() @@ -396,27 +411,3 @@ def __setstate__(self, state): # children). So, now return their parent to them: for child in self: child.parent = self - - -class ParentMostError(TypeError): - """ - To be raised when assigning a parent to a parent-most object - """ - - -class ParentMost(SemanticParent, ABC): - """ - A semantic parent that cannot have any other parent. - """ - - @property - def parent(self) -> None: - return None - - @parent.setter - def parent(self, new_parent: None): - if new_parent is not None: - raise ParentMostError( - f"{self.label} is {ParentMost.__name__} and may only take None as a " - f"parent but got {type(new_parent)}" - ) diff --git a/pyiron_workflow/nodes/composite.py b/pyiron_workflow/nodes/composite.py index 11d50583..e3a06cab 100644 --- a/pyiron_workflow/nodes/composite.py +++ b/pyiron_workflow/nodes/composite.py @@ -54,7 +54,7 @@ class FailedChildError(RuntimeError): """Raise when one or more child nodes raise exceptions.""" -class Composite(SemanticParent, HasCreator, Node, ABC): +class Composite(SemanticParent[Node], HasCreator, Node, ABC): """ A base class for nodes that have internal graph structure -- i.e. they hold a collection of child nodes and their computation is to execute that graph. @@ -154,6 +154,10 @@ def __init__( **kwargs, ) + @classmethod + def child_type(cls) -> type[Node]: + return Node + def activate_strict_hints(self): super().activate_strict_hints() for node in self: @@ -420,8 +424,6 @@ def executor_shutdown(self, wait=True, *, cancel_futures=False): def __setattr__(self, key: str, node: Node): if isinstance(node, Composite) and key in ["_parent", "parent"]: # This is an edge case for assigning a node to an attribute - # We either defer to the setter with super, or directly assign the private - # variable (as requested in the setter) super().__setattr__(key, node) elif isinstance(node, Node): self.add_child(node, label=key) diff --git a/pyiron_workflow/workflow.py b/pyiron_workflow/workflow.py index 791e17c8..8a5ddb29 100644 --- a/pyiron_workflow/workflow.py +++ b/pyiron_workflow/workflow.py @@ -11,7 +11,6 @@ from bidict import bidict from pyiron_workflow.io import Inputs, Outputs -from pyiron_workflow.mixin.semantics import ParentMost from pyiron_workflow.nodes.composite import Composite if TYPE_CHECKING: @@ -20,7 +19,13 @@ from pyiron_workflow.storage import StorageInterface -class Workflow(ParentMost, Composite): +class ParentMostError(TypeError): + """ + To be raised when assigning a parent to a parent-most object + """ + + +class Workflow(Composite): """ Workflows are a dynamic composite node -- i.e. they hold and run a collection of nodes (a subgraph) which can be dynamically modified (adding and removing nodes, @@ -495,3 +500,15 @@ def replace_child( raise e return owned_node + + @property + def parent(self) -> None: + return None + + @parent.setter + def parent(self, new_parent: None): + if new_parent is not None: + raise ParentMostError( + f"{self.label} is a {self.__class__} and may only take None as a " + f"parent but got {type(new_parent)}" + ) diff --git a/tests/unit/mixin/test_semantics.py b/tests/unit/mixin/test_semantics.py index dd40c5f0..0b63b94f 100644 --- a/tests/unit/mixin/test_semantics.py +++ b/tests/unit/mixin/test_semantics.py @@ -3,18 +3,23 @@ from pyiron_workflow.mixin.semantics import ( CyclicPathError, - ParentMost, Semantic, SemanticParent, ) +class ConcreteParent(SemanticParent[Semantic]): + @classmethod + def child_type(cls) -> type[Semantic]: + return Semantic + + class TestSemantics(unittest.TestCase): def setUp(self): - self.root = ParentMost("root") + self.root = ConcreteParent("root") self.child1 = Semantic("child1", parent=self.root) - self.middle1 = SemanticParent("middle", parent=self.root) - self.middle2 = SemanticParent("middle_sub", parent=self.middle1) + self.middle1 = ConcreteParent("middle", parent=self.root) + self.middle2 = ConcreteParent("middle_sub", parent=self.middle1) self.child2 = Semantic("child2", parent=self.middle2) def test_getattr(self): @@ -58,18 +63,6 @@ def test_parent(self): self.assertEqual(self.child1.parent, self.root) self.assertEqual(self.root.parent, None) - with self.subTest(f"{ParentMost.__name__} exceptions"): - with self.assertRaises( - TypeError, msg=f"{ParentMost.__name__} instances can't have parent" - ): - self.root.parent = SemanticParent(label="foo") - - with self.assertRaises( - TypeError, msg=f"{ParentMost.__name__} instances can't be children" - ): - some_parent = SemanticParent(label="bar") - some_parent.add_child(self.root) - with self.subTest("Cyclicity exceptions"): with self.assertRaises(CyclicPathError): self.middle1.parent = self.middle2 diff --git a/tests/unit/test_workflow.py b/tests/unit/test_workflow.py index 174013d3..f19032b7 100644 --- a/tests/unit/test_workflow.py +++ b/tests/unit/test_workflow.py @@ -9,9 +9,8 @@ from pyiron_workflow._tests import ensure_tests_in_python_path from pyiron_workflow.channels import NOT_DATA -from pyiron_workflow.mixin.semantics import ParentMostError from pyiron_workflow.storage import TypeNotFoundError, available_backends -from pyiron_workflow.workflow import Workflow +from pyiron_workflow.workflow import ParentMostError, Workflow ensure_tests_in_python_path() @@ -155,7 +154,7 @@ def test_io_map_bijectivity(self): self.assertEqual(3, len(wf.inputs_map), msg="All entries should be stored") self.assertEqual(0, len(wf.inputs), msg="No IO should be left exposed") - def test_is_parentmost(self): + def test_takes_no_parent(self): wf = Workflow("wf") wf2 = Workflow("wf2")