Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

v0.4.0 #7

Merged
merged 6 commits into from
Mar 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 47 additions & 19 deletions pyungo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ def topological_sort(data):

class Node:
ID = 0
def __init__(self, fct, input_names, output_names):
def __init__(self, fct, input_names, output_names, args=None, kwargs=None):
Node.ID += 1
self._id = str(Node.ID)
self._fct = fct
self._input_names = input_names
self._args = args if args else []
self._kwargs = kwargs if kwargs else []
self._output_names = output_names

def __repr__(self):
Expand All @@ -39,16 +41,23 @@ def __repr__(self):
self._input_names, self._output_names
)

def __call__(self, args):
return self._fct(*args)
def __call__(self, args, **kwargs):
return self._fct(*args, **kwargs)

@property
def id(self):
return self._id

@property
def input_names(self):
return self._input_names
input_names = self._input_names
input_names.extend(self._args)
input_names.extend(self._kwargs)
return input_names

@property
def kwargs(self):
return self._kwargs

@property
def output_names(self):
Expand Down Expand Up @@ -82,16 +91,35 @@ def sim_outputs(self):
outputs.extend(node.output_names)
return outputs

@property
def dag(self):
""" return the ordered nodes graph """
ordered_nodes = []
for node_ids in topological_sort(self._dependencies()):
nodes = [self._get_node(node_id) for node_id in node_ids]
ordered_nodes.append(nodes)
return ordered_nodes

def _register(self, f, **kwargs):
input_names = kwargs.get('inputs')
args_names = kwargs.get('args')
kwargs_names = kwargs.get('kwargs')
output_names = kwargs.get('outputs')
self._create_node(
f, input_names, output_names, args_names, kwargs_names
)

def register(self, **kwargs):
def decorator(f):
input_names = kwargs.get('inputs')
output_names = kwargs.get('outputs')
self._create_node(f, input_names, output_names)
self._register(f, **kwargs)
return f
return decorator

def _create_node(self, fct, input_names, output_names):
node = Node(fct, input_names, output_names)
def add_node(self, function, **kwargs):
self._register(function, **kwargs)

def _create_node(self, fct, input_names, output_names, args_names, kwargs_names):
node = Node(fct, input_names, output_names, args_names, kwargs_names)
# assume that we cannot have two nodes with the same output names
for n in self._nodes:
for out_name in n.output_names:
Expand Down Expand Up @@ -139,17 +167,17 @@ def calculate(self, data):
for items in sorted_dep:
for item in items:
node = self._get_node(item)
args = node.input_names
args = [i_name for i_name in node.input_names if i_name not in node.kwargs]
data_to_pass = []
for arg in args:
data_to_pass.append(self._data[arg])
res = node(data_to_pass)
try:
iter(res)
except TypeError:
res = [res]
for i, out in enumerate(node.output_names):
self._data[out] = res[i]
if len(res) == 1:
return res[0]
kwargs_to_pass = {}
for kwarg in node.kwargs:
kwargs_to_pass[kwarg] = self._data[kwarg]
res = node(data_to_pass, **kwargs_to_pass)
if len(node.output_names) == 1:
self._data[node.output_names[0]] = res
else:
for i, out in enumerate(node.output_names):
self._data[out] = res[i]
return res
64 changes: 64 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,29 @@ def f_my_function2(c):
res = graph.calculate(data={'a': 2, 'b': 3})

assert res == -1.5
assert graph.data['e'] == -1.5


def test_simple_without_decorator():
graph = Graph()

def f_my_function(a, b):
return a + b

def f_my_function3(d, a):
return d - a

def f_my_function2(c):
return c / 10.

graph.add_node(f_my_function, inputs=['a', 'b'], outputs=['c'])
graph.add_node(f_my_function3, inputs=['d', 'a'], outputs=['e'])
graph.add_node(f_my_function2, inputs=['c'], outputs=['d'])

res = graph.calculate(data={'a': 2, 'b': 3})

assert res == -1.5
assert graph.data['e'] == -1.5


def test_multiple_outputs():
Expand All @@ -36,6 +59,7 @@ def f_my_function2(c, d):
res = graph.calculate(data={'a': 2, 'b': 3})

assert res == 11
assert graph.data['e'] == 11


def test_same_output_names():
Expand Down Expand Up @@ -119,6 +143,7 @@ def f_my_function(a, b):
res = graph.calculate(data={'a': 2, 'b': 3})

assert res == [0, 1, 3]
assert graph.data['c'] == [0, 1, 3]


def test_multiple_outputs_with_iterable():
Expand All @@ -135,3 +160,42 @@ def f_my_function(a, b):
assert graph.data['d'] == 30
assert res[0] == [0, 1, 3]
assert res[1] == 30


def test_args_kwargs():
graph = Graph()

@graph.register(
inputs=['a', 'b'],
args=['c'],
kwargs=['d'],
outputs=['e']
)
def f_my_function(a, b, *args, **kwargs):
return a + b + args[0] + kwargs['d']

res = graph.calculate(data={'a': 2, 'b': 3, 'c': 4, 'd': 5})

assert res == 14
assert graph.data['e'] == 14


def test_dag_pretty_print():
graph = Graph()

@graph.register(inputs=['a', 'b'], outputs=['c'])
def f_my_function(a, b):
return a + b

@graph.register(inputs=['d', 'a'], outputs=['e'])
def f_my_function3(d, a):
return d - a

@graph.register(inputs=['c'], outputs=['d'])
def f_my_function2(c):
return c / 10.

expected = ['f_my_function', 'f_my_function2', 'f_my_function3']
dag = graph.dag
for i, fct_name in enumerate(expected):
assert dag[i][0].fct_name == fct_name