Skip to content

Commit

Permalink
Merge pull request #633 from pyiron/introduce_control_channels
Browse files Browse the repository at this point in the history
Introduce control channels
  • Loading branch information
liamhuber authored Apr 24, 2023
2 parents a5e14cc + 09e5f0a commit afaace5
Show file tree
Hide file tree
Showing 7 changed files with 427 additions and 98 deletions.
85 changes: 84 additions & 1 deletion notebooks/workflow_example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "9aa721332f2644d5a30634bf7dfac067",
"model_id": "ceee180feeb547a0b959ccacf8aef2ae",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -453,6 +453,89 @@
"wf.nodes"
]
},
{
"cell_type": "markdown",
"id": "9b9220b0-833d-4c6a-9929-5dfa60a47d14",
"metadata": {},
"source": [
"# Flow control\n",
"\n",
"By default, when a node runs and updates its output, this triggers outputs in all downstream connections. This is useful when all your node functions are small and light, but there may come times when you want something other than this simple \"push\" flow.\n",
"\n",
"In addition to input and output data channels, nodes also have \"signal\" channels available. Input signals are bound to a callback function (typically one of its node's methods), and output signals trigger the callbacks for all the input signal channels they're connected to.\n",
"\n",
"Standard nodes have a `run` input signal (which is, unsurprisingly, bound to the `run` method), and a `ran` output signal (which, again, hopefully with no great surprise, is triggered at the end of the `run` method.)\n",
"\n",
"Below is a super simple example of how these signal channels can be used to delay execution and manually control flow:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2e418abf-7059-4e1e-9b9f-b3dc0a4b5e35",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 None\n"
]
}
],
"source": [
"@node(\"y\")\n",
"def linear(x):\n",
" return x\n",
"\n",
"@node(\"z\")\n",
"def times_two(y):\n",
" return 2 * y\n",
"\n",
"l = linear(x=1)\n",
"t2 = times_two(\n",
" y=l.outputs.y, update_on_instantiation=False, run_automatically=False\n",
")\n",
"print(t2.inputs.y, t2.outputs.z)"
]
},
{
"cell_type": "markdown",
"id": "37aa4455-9b98-4be5-a365-363e3c490bb6",
"metadata": {},
"source": [
"Now the input of `t2` got updated when the connection is made, but we told this node not to do any automatic updates, so the output has its uninitialized value of `None`.\n",
"\n",
"Often, you will probably want to have nodes with data connections to have signal connections, but this is not strictly required. Here, we'll introduce a third node to control the execution of `t2`.\n",
"\n",
"Note that we have all the same syntacic sugar from data channels when creating connections between signal channels."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "3310eac4-04f6-421b-9824-19bb2d680be6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2\n"
]
}
],
"source": [
"@node(\"void\")\n",
"def control():\n",
" return\n",
"\n",
"c = control()\n",
"t2.signals.input.run = c.signals.output.ran\n",
"c.run()\n",
"print(t2.outputs.z.value)"
]
},
{
"cell_type": "markdown",
"id": "2671dc36-42a4-466b-848d-067ef7bd1d1d",
Expand Down
195 changes: 147 additions & 48 deletions pyiron_contrib/workflow/channels.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing
from abc import ABC, abstractmethod
from warnings import warn

from pyiron_contrib.workflow.has_to_dict import HasToDict
Expand All @@ -12,16 +13,72 @@
from pyiron_contrib.workflow.node import Node


class Channel(HasToDict):
class Channel(HasToDict, ABC):
"""
Channels control the flow of data on the graph.
Channels facilitate the flow of information (data or control signals) into and
out of nodes.
They have a label and belong to a node.
They may optionally have a type hint.
They may optionally have a storage priority (but this doesn't do anything yet).
(In the future they may optionally have an ontological type.)
Input/output channels can be (dis)connected from other output/input channels, and
store all of their current connections in a list.
"""

def __init__(
self,
label: str,
node: Node,
):
self.label = label
self.node = node
self.connections = []

@abstractmethod
def __str__(self):
pass

@abstractmethod
def connect(self, *others: Channel):
pass

def disconnect(self, *others: Channel):
for other in others:
if other in self.connections:
self.connections.remove(other)
other.disconnect(self)

def disconnect_all(self):
self.disconnect(*self.connections)

@property
def connected(self):
return len(self.connections) > 0

def _already_connected(self, other: Channel):
return other in self.connections

def __iter__(self):
return self.connections.__iter__()

def __len__(self):
return len(self.connections)

def to_dict(self):
return {
"label": self.label,
"connected": self.connected,
"connections": [f"{c.node.label}.{c.label}" for c in self.connections]
}


class DataChannel(Channel, ABC):
"""
Data channels control the flow of data on the graph.
They store this data in a `value` attribute.
They may optionally have a type hint.
They have a `ready` attribute which tells whether their value matches their type
hint.
They may optionally have a storage priority (but this doesn't do anything yet).
(In the future they may optionally have an ontological type.)
The `value` held by a channel can be manually assigned, but should normally be set
by the `update` method.
Expand All @@ -31,7 +88,7 @@ class Channel(HasToDict):
connected to.
Type hinting is strictly enforced in one situation: when making connections to
other channels and at least one channel has a non-None value for its type hint.
other channels and at least one data channel has a non-None value for its type hint.
In this case, we insist that the output type hint be _as or more more specific_ than
the input type hint, to ensure that the input always receives output of a type it
expects. This behaviour can be disabled and all connections allowed by setting
Expand Down Expand Up @@ -73,14 +130,12 @@ def __init__(
storage_priority: int = 0,
strict_connections: bool = True,
):
self.label = label
self.node = node
super().__init__(label=label, node=node)
self.default = default
self.value = default
self.type_hint = type_hint
self.storage_priority = storage_priority
self.strict_connections = True
self.connections = []
self.strict_connections = strict_connections

@property
def ready(self):
Expand All @@ -89,15 +144,15 @@ def ready(self):
else:
return True

def connect(self, *others: Channel):
def connect(self, *others: DataChannel):
for other in others:
if self._valid_connection(other):
self.connections.append(other)
other.connections.append(self)
out, inp = self._figure_out_who_is_who(other)
inp.update(out.value)
else:
if isinstance(other, Channel):
if isinstance(other, DataChannel):
warn(
f"{self.label} ({self.__class__.__name__}) and {other.label} "
f"({other.__class__.__name__}) were not a valid connection"
Expand All @@ -124,51 +179,25 @@ def _valid_connection(self, other):
else:
return False

def _is_IO_pair(self, other: Channel):
return isinstance(other, Channel) and type(self) != type(other)
def _is_IO_pair(self, other: DataChannel):
return isinstance(other, DataChannel) and not isinstance(other, self.__class__)

def _already_connected(self, other: Channel):
return other in self.connections

def _both_typed(self, other: Channel):
def _both_typed(self, other: DataChannel):
return self.type_hint is not None and other.type_hint is not None

def _figure_out_who_is_who(self, other: Channel) -> (OutputChannel, InputChannel):
return (self, other) if isinstance(self, OutputChannel) else (other, self)

def disconnect(self, *others: Channel):
for other in others:
if other in self.connections:
self.connections.remove(other)
other.disconnect(self)

def disconnect_all(self):
self.disconnect(*self.connections)

@property
def connected(self):
return len(self.connections) > 0

def __iter__(self):
return self.connections.__iter__()

def __len__(self):
return len(self.connections)
def _figure_out_who_is_who(self, other: DataChannel) -> (OutputData, InputData):
return (self, other) if isinstance(self, OutputData) else (other, self)

def __str__(self):
return str(self.value)

def to_dict(self):
return {
"label": self.label,
"value": str(self.value),
"ready": self.ready,
"connected": self.connected,
"connections": [f"{c.node.label}.{c.label}" for c in self.connections]
}
d = super().to_dict()
d["value"] = self.value
d["ready"] = self.ready


class InputChannel(Channel):
class InputData(DataChannel):
def update(self, value):
self.value = value
self.node.update()
Expand All @@ -180,8 +209,78 @@ def deactivate_strict_connections(self):
self.strict_connections = False


class OutputChannel(Channel):
class OutputData(DataChannel):
def update(self, value):
self.value = value
for inp in self.connections:
inp.update(self.value)


class SignalChannel(Channel, ABC):
"""
Signal channels give the option control execution flow by triggering callback
functions.
Output channels can be called to trigger the callback functions of all input
channels to which they are connected.
"""

@abstractmethod
def __call__(self):
pass

def connect(self, *others: Channel):
for other in others:
if self._valid_connection(other):
self.connections.append(other)
other.connections.append(self)
else:
if isinstance(other, SignalChannel):
warn(
f"{self.label} ({self.__class__.__name__}) and {other.label} "
f"({other.__class__.__name__}) were not a valid connection"
)
else:
raise TypeError(
f"Can only connect two signal channels, but {self.label} "
f"({self.__class__.__name__}) got a {other} ({type(other)})"
)

def _valid_connection(self, other) -> bool:
return self._is_IO_pair(other) and not self._already_connected(other)

def _is_IO_pair(self, other) -> bool:
return isinstance(other, SignalChannel) \
and not isinstance(other, self.__class__)


class InputSignal(SignalChannel):
def __init__(
self,
label: str,
node: Node,
callback: callable,
):
super().__init__(label=label, node=node)
self.callback: callable = callback

def __call__(self):
self.callback()

def __str__(self):
return f"{self.label} runs {self.callback.__name__}"

def to_dict(self):
d = super().to_dict()
d["callback"] = self.callback.__name__
return d


class OutputSignal(SignalChannel):
def __call__(self):
for c in self.connections:
c()

def __str__(self):
return f"{self.label} activates " \
f"{[f'{c.node.label}.{c.label}' for c in self.connections]}"
Loading

0 comments on commit afaace5

Please sign in to comment.