Skip to content

Commit

Permalink
Merge pull request #708 from pyiron/allow_self
Browse files Browse the repository at this point in the history
Allow self in node definition
  • Loading branch information
samwaseda authored Jun 14, 2023
2 parents 96e0f26 + 77d7a7b commit adaf1ef
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 11 deletions.
22 changes: 19 additions & 3 deletions pyiron_contrib/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,12 +364,25 @@ def __init__(
if update_on_instantiation:
self.update()

@property
def _input_args(self):
return inspect.signature(self.node_function).parameters

def _build_input_channels(self, storage_priority: dict[str:int]):
channels = []
type_hints = get_type_hints(self.node_function)
parameters = inspect.signature(self.node_function).parameters

for label, value in parameters.items():
for ii, (label, value) in enumerate(self._input_args.items()):
if label == "self":
if ii == 0:
continue
else:
warnings.warn(
"`self` is used as an argument but not in the first"
" position, so it is treated as a normal function"
" argument. If it is to be treated as the node object,"
" use it as a first argument"
)
if label in self._init_keywords:
# We allow users to parse arbitrary kwargs as channel initialization
# So don't let them choose bad channel names
Expand Down Expand Up @@ -481,7 +494,10 @@ def run(self) -> None:

if self.server is None:
try:
function_output = self.node_function(**self.inputs.to_value_dict())
if "self" in self._input_args:
function_output = self.node_function(self=self, **self.inputs.to_value_dict())
else:
function_output = self.node_function(**self.inputs.to_value_dict())
except Exception as e:
self.running = False
self.failed = True
Expand Down
38 changes: 30 additions & 8 deletions tests/unit/workflow/test_node.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import TestCase, skipUnless
import unittest
from sys import version_info
from typing import Optional, Union
import warnings

from pyiron_contrib.workflow.node import (
FastNode, Node, SingleValueNode, node, single_value_node
Expand All @@ -19,8 +20,8 @@ def no_default(x, y):
return x + y + 1


@skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestNode(TestCase):
@unittest.skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestNode(unittest.TestCase):
def test_defaults(self):
Node(plus_one, "y")

Expand Down Expand Up @@ -156,18 +157,35 @@ def test_statuses(self):
# self.assertFalse(n.running)
self.assertFalse(n.failed, msg="Re-running should reset failed status")


@skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestFastNode(TestCase):
def test_with_self(self):
def with_self(self, x: float) -> float:
return x + 0.1
node = Node(with_self, "output")
self.assertTrue("x" in node.inputs.labels)
self.assertFalse("self" in node.inputs.labels)
node.inputs.x = 1
node.run()
self.assertEqual(node.outputs.output.value, 1.1)
def with_messed_self(x: float, self) -> float:
return x + 0.1
with warnings.catch_warnings(record=True) as warning_list:
node = Node(with_messed_self, "output")
self.assertTrue("self" in node.inputs.labels)
self.assertEqual(len(warning_list), 1)



@unittest.skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestFastNode(unittest.TestCase):
def test_instantiation(self):
has_defaults_is_ok = FastNode(plus_one, "y")

with self.assertRaises(ValueError):
missing_defaults_should_fail = FastNode(no_default, "z")


@skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestSingleValueNode(TestCase):
@unittest.skipUnless(version_info[0] == 3 and version_info[1] >= 10, "Only supported for 3.10+")
class TestSingleValueNode(unittest.TestCase):
def test_instantiation(self):
has_defaults_and_one_return = SingleValueNode(plus_one, "y")

Expand Down Expand Up @@ -310,3 +328,7 @@ def my_node(x: int = 0, y: int = 0, z: int = 0):
n.inputs.z.waiting_for_update,
msg="After the run, all three should now be waiting for updates again"
)


if __name__ == '__main__':
unittest.main()

0 comments on commit adaf1ef

Please sign in to comment.