Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Order things for circular graphs #761

Merged
merged 7 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyiron_contrib/workflow/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class Function(Node):
updated, and will attempt to update on initialization (after setting _all_ initial
input values).

Output is updated in the `process_run_result` inside the parent class `finish_run`
call, such that output data gets pushed after the node stops running but before
then `ran` signal fires.

Args:
node_function (callable): The function determining the behaviour of the node.
*output_labels (str): A name for each return value of the node function.
Expand Down
8 changes: 3 additions & 5 deletions pyiron_contrib/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,24 +195,22 @@ def run(self) -> None:

def finish_run(self, run_output: tuple):
"""
Process the run result, then wrap up statuses etc.
Switch the node status, process the run result, then fire the ran signal.

By extracting this as a separate method, we allow the node to pass the actual
execution off to another entity and release the python process to do other
things. In such a case, this function should be registered as a callback
so that the node can finish "running" and, e.g. push its data forward when that
execution is finished.
"""
liamhuber marked this conversation as resolved.
Show resolved Hide resolved
self.running = False
try:
self.process_run_result(run_output)
self.signals.output.ran()
except Exception as e:
self.running = False
self.failed = True
raise e

self.signals.output.ran()
self.running = False

def _build_signal_channels(self) -> Signals:
signals = Signals()
signals.input.run = InputSignal("run", self, self.run)
Expand Down
76 changes: 76 additions & 0 deletions tests/integration/test_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import time
import unittest

import numpy as np

from pyiron_contrib.workflow.channels import OutputSignal
from pyiron_contrib.workflow.function import Function
from pyiron_contrib.workflow.workflow import Workflow


class TestNothing(unittest.TestCase):
def test_cyclic_graphs(self):
"""
Check that cyclic graphs run.

TODO: Update once logical switches are included in the node library
"""

@Workflow.wrap_as.single_value_node("rand")
def numpy_randint(low=0, high=20):
rand = np.random.randint(low=low, high=high)
print(f"Generating random number between {low} and {high}...{rand}!")
return rand

class GreaterThanLimitSwitch(Function):
"""
A switch class for sending signal output depending on a '>' check
applied to input
"""

def __init__(self, **kwargs):
super().__init__(self.greater_than, "value_gt_limit", **kwargs)
self.signals.output.true = OutputSignal("true", self)
self.signals.output.false = OutputSignal("false", self)

@staticmethod
def greater_than(value, limit=10):
return value > limit

def process_run_result(self, function_output):
"""
Process the output as usual, then fire signals accordingly.
"""
super().process_run_result(function_output)

if self.outputs.value_gt_limit.value:
print(f"{self.inputs.value.value} > {self.inputs.limit.value}")
self.signals.output.true()
else:
print(f"{self.inputs.value.value} <= {self.inputs.limit.value}")
self.signals.output.false()

@Workflow.wrap_as.single_value_node("sqrt")
def numpy_sqrt(value=0):
sqrt = np.sqrt(value)
print(f"sqrt({value}) = {sqrt}")
return sqrt

wf = Workflow("rand_until_big_then_sqrt")

wf.rand = numpy_randint(update_on_instantiation=False)

wf.gt_switch = GreaterThanLimitSwitch(run_on_updates=False)
wf.gt_switch.inputs.value = wf.rand

wf.sqrt = numpy_sqrt(run_on_updates=False)
wf.sqrt.inputs.value = wf.rand

wf.gt_switch.signals.input.run = wf.rand.signals.output.ran
wf.sqrt.signals.input.run = wf.gt_switch.signals.output.true
wf.rand.signals.input.run = wf.gt_switch.signals.output.false

wf.rand.update()
self.assertAlmostEqual(
np.sqrt(wf.rand.outputs.rand.value), wf.sqrt.outputs.sqrt.value, 6
)