Skip to content

Commit

Permalink
Merge pull request #761 from pyiron/order_things_for_circular
Browse files Browse the repository at this point in the history
  • Loading branch information
liamhuber authored Jul 11, 2023
2 parents 82fdcfd + d0dac28 commit 491d855
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 5 deletions.
4 changes: 4 additions & 0 deletions pyiron_contrib/workflow/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class Function(Node):
updated, and will attempt to update on initialization (after setting _all_ initial
input values).
Output is updated in the `process_run_result` inside the parent class `finish_run`
call, such that output data gets pushed after the node stops running but before
then `ran` signal fires.
Args:
node_function (callable): The function determining the behaviour of the node.
*output_labels (str): A name for each return value of the node function.
Expand Down
8 changes: 3 additions & 5 deletions pyiron_contrib/workflow/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,24 +195,22 @@ def run(self) -> None:

def finish_run(self, run_output: tuple):
"""
Process the run result, then wrap up statuses etc.
Switch the node status, process the run result, then fire the ran signal.
By extracting this as a separate method, we allow the node to pass the actual
execution off to another entity and release the python process to do other
things. In such a case, this function should be registered as a callback
so that the node can finish "running" and, e.g. push its data forward when that
execution is finished.
"""
self.running = False
try:
self.process_run_result(run_output)
self.signals.output.ran()
except Exception as e:
self.running = False
self.failed = True
raise e

self.signals.output.ran()
self.running = False

def _build_signal_channels(self) -> Signals:
signals = Signals()
signals.input.run = InputSignal("run", self, self.run)
Expand Down
76 changes: 76 additions & 0 deletions tests/integration/test_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import time
import unittest

import numpy as np

from pyiron_contrib.workflow.channels import OutputSignal
from pyiron_contrib.workflow.function import Function
from pyiron_contrib.workflow.workflow import Workflow


class TestNothing(unittest.TestCase):
def test_cyclic_graphs(self):
"""
Check that cyclic graphs run.
TODO: Update once logical switches are included in the node library
"""

@Workflow.wrap_as.single_value_node("rand")
def numpy_randint(low=0, high=20):
rand = np.random.randint(low=low, high=high)
print(f"Generating random number between {low} and {high}...{rand}!")
return rand

class GreaterThanLimitSwitch(Function):
"""
A switch class for sending signal output depending on a '>' check
applied to input
"""

def __init__(self, **kwargs):
super().__init__(self.greater_than, "value_gt_limit", **kwargs)
self.signals.output.true = OutputSignal("true", self)
self.signals.output.false = OutputSignal("false", self)

@staticmethod
def greater_than(value, limit=10):
return value > limit

def process_run_result(self, function_output):
"""
Process the output as usual, then fire signals accordingly.
"""
super().process_run_result(function_output)

if self.outputs.value_gt_limit.value:
print(f"{self.inputs.value.value} > {self.inputs.limit.value}")
self.signals.output.true()
else:
print(f"{self.inputs.value.value} <= {self.inputs.limit.value}")
self.signals.output.false()

@Workflow.wrap_as.single_value_node("sqrt")
def numpy_sqrt(value=0):
sqrt = np.sqrt(value)
print(f"sqrt({value}) = {sqrt}")
return sqrt

wf = Workflow("rand_until_big_then_sqrt")

wf.rand = numpy_randint(update_on_instantiation=False)

wf.gt_switch = GreaterThanLimitSwitch(run_on_updates=False)
wf.gt_switch.inputs.value = wf.rand

wf.sqrt = numpy_sqrt(run_on_updates=False)
wf.sqrt.inputs.value = wf.rand

wf.gt_switch.signals.input.run = wf.rand.signals.output.ran
wf.sqrt.signals.input.run = wf.gt_switch.signals.output.true
wf.rand.signals.input.run = wf.gt_switch.signals.output.false

wf.rand.update()
self.assertAlmostEqual(
np.sqrt(wf.rand.outputs.rand.value), wf.sqrt.outputs.sqrt.value, 6
)

0 comments on commit 491d855

Please sign in to comment.