-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #761 from pyiron/order_things_for_circular
- Loading branch information
Showing
3 changed files
with
83 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import time | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
from pyiron_contrib.workflow.channels import OutputSignal | ||
from pyiron_contrib.workflow.function import Function | ||
from pyiron_contrib.workflow.workflow import Workflow | ||
|
||
|
||
class TestNothing(unittest.TestCase): | ||
def test_cyclic_graphs(self): | ||
""" | ||
Check that cyclic graphs run. | ||
TODO: Update once logical switches are included in the node library | ||
""" | ||
|
||
@Workflow.wrap_as.single_value_node("rand") | ||
def numpy_randint(low=0, high=20): | ||
rand = np.random.randint(low=low, high=high) | ||
print(f"Generating random number between {low} and {high}...{rand}!") | ||
return rand | ||
|
||
class GreaterThanLimitSwitch(Function): | ||
""" | ||
A switch class for sending signal output depending on a '>' check | ||
applied to input | ||
""" | ||
|
||
def __init__(self, **kwargs): | ||
super().__init__(self.greater_than, "value_gt_limit", **kwargs) | ||
self.signals.output.true = OutputSignal("true", self) | ||
self.signals.output.false = OutputSignal("false", self) | ||
|
||
@staticmethod | ||
def greater_than(value, limit=10): | ||
return value > limit | ||
|
||
def process_run_result(self, function_output): | ||
""" | ||
Process the output as usual, then fire signals accordingly. | ||
""" | ||
super().process_run_result(function_output) | ||
|
||
if self.outputs.value_gt_limit.value: | ||
print(f"{self.inputs.value.value} > {self.inputs.limit.value}") | ||
self.signals.output.true() | ||
else: | ||
print(f"{self.inputs.value.value} <= {self.inputs.limit.value}") | ||
self.signals.output.false() | ||
|
||
@Workflow.wrap_as.single_value_node("sqrt") | ||
def numpy_sqrt(value=0): | ||
sqrt = np.sqrt(value) | ||
print(f"sqrt({value}) = {sqrt}") | ||
return sqrt | ||
|
||
wf = Workflow("rand_until_big_then_sqrt") | ||
|
||
wf.rand = numpy_randint(update_on_instantiation=False) | ||
|
||
wf.gt_switch = GreaterThanLimitSwitch(run_on_updates=False) | ||
wf.gt_switch.inputs.value = wf.rand | ||
|
||
wf.sqrt = numpy_sqrt(run_on_updates=False) | ||
wf.sqrt.inputs.value = wf.rand | ||
|
||
wf.gt_switch.signals.input.run = wf.rand.signals.output.ran | ||
wf.sqrt.signals.input.run = wf.gt_switch.signals.output.true | ||
wf.rand.signals.input.run = wf.gt_switch.signals.output.false | ||
|
||
wf.rand.update() | ||
self.assertAlmostEqual( | ||
np.sqrt(wf.rand.outputs.rand.value), wf.sqrt.outputs.sqrt.value, 6 | ||
) |