Skip to content

Commit

Permalink
Further improvements and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Jan 25, 2019
1 parent 501de7d commit 6485c35
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 41 deletions.
59 changes: 35 additions & 24 deletions holoviews/element/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def from_networkx(cls, G, positions, nodes=None, **kwargs):
Returns:
Graph element
"""
if isinstance(positions, dict):
if not isinstance(positions, dict):
positions = positions(G, **kwargs)
edges = defaultdict(list)
for start, end in G.edges():
Expand All @@ -428,30 +428,41 @@ def from_networkx(cls, G, positions, nodes=None, **kwargs):

xdim, ydim, idim = cls.node_type.kdims[:3]
if nodes:
xs, ys = zip(*[v for k, v in sorted(positions.items())])
indices = list(nodes.dimension_values(idim))
edges = [edge for edge in edges if edge[0] in indices and edge[1] in indices]
nodes = nodes.select(**{idim.name: [eid for e in edges for eid in e]}).sort()
nodes = nodes.add_dimension(xdim, 0, xs)
nodes = nodes.add_dimension(ydim, 1, ys).clone(new_type=cls.node_type)
node_columns = nodes.columns()
idx_dim = nodes.kdims[0].name
info_cols, values = zip(*((k, v) for k, v in node_columns.items() if k != idx_dim))
node_info = {i: vals for i, vals in zip(node_columns[idx_dim], zip(*values))}
else:
nodes = defaultdict(list)
for idx, pos in sorted(positions.items()):
x, y = pos
nodes[xdim.name].append(x)
nodes[ydim.name].append(y)
for attr, value in G.nodes[idx].items():
if isinstance(value, (list, dict)):
continue
nodes[attr].append(value)
if isinstance(idx, tuple):
idx = str(idx) # Tuple node indexes handled as strings
nodes[idim.name].append(idx)
node_cols = sorted([k for k in edges if k not in cls.node_type.kdims
and len(nodes[k]) == len(nodes[xdim.name])])
vdims = [str(col) if isinstance(col, int) else col for col in node_cols]
node_data = tuple(nodes[col] for col in [xdim.name, ydim.name, idim.name]+node_cols)
nodes = cls.node_type(node_data, vdims=vdims)
info_cols = []
node_info = None
node_columns = defaultdict(list)
for idx, pos in sorted(positions.items()):
x, y = pos
node_columns[xdim.name].append(x)
node_columns[ydim.name].append(y)
for attr, value in G.nodes[idx].items():
if isinstance(value, (list, dict)):
continue
node_columns[attr].append(value)
for i, col in enumerate(info_cols):
node_columns[col].append(node_info[idx][i])
if isinstance(idx, tuple):
idx = str(idx) # Tuple node indexes handled as strings
node_columns[idim.name].append(idx)
node_cols = sorted([k for k in node_columns if k not in cls.node_type.kdims
and len(node_columns[k]) == len(node_columns[xdim.name])])
vdims = []
for col in node_cols:
if isinstance(col, int):
dim = str(col)
elif nodes is not None and col in nodes.vdims:
dim = nodes.get_dimension(col)
else:
dim = col
vdims.append(dim)
columns = [xdim.name, ydim.name, idim.name]+node_cols+list(info_cols)
node_data = tuple(node_columns[col] for col in columns)
nodes = cls.node_type(node_data, vdims=vdims)
return cls((edge_data, nodes), vdims=edge_vdims)


Expand Down
53 changes: 36 additions & 17 deletions holoviews/tests/element/testgraphelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,16 @@ def test_graph_redim_nodes(self):
self.assertEqual(redimmed.nodes, graph.nodes.redim(x='x2', y='y2'))
self.assertEqual(redimmed.edgepaths, graph.edgepaths.redim(x='x2', y='y2'))

@attr(optional=1)
def test_from_networkx_with_node_attrs(self):
class FromNetworkXTests(ComparisonTestCase):

def setUp(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')

def test_from_networkx_with_node_attrs(self):
import networkx as nx
G = nx.karate_club_graph()
graph = Graph.from_networkx(G, nx.circular_layout)
clubs = np.array([
Expand All @@ -146,41 +150,56 @@ def test_from_networkx_with_node_attrs(self):
'Officer', 'Officer', 'Officer', 'Officer'])
self.assertEqual(graph.nodes.dimension_values('club'), clubs)

@attr(optional=1)
def test_from_networkx_with_invalid_node_attrs(self):
import networkx as nx
FG = nx.Graph()
FG.add_node(1, test=[])
FG.add_node(2, test=[])
FG.add_edge(1, 2)
graph = Graph.from_networkx(FG, nx.circular_layout)
self.assertEqual(graph.nodes.vdims, [])
self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2]))
self.assertEqual(graph.array(), np.array([(1, 2)]))

def test_from_networkx_with_edge_attrs(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')
import networkx as nx
FG = nx.Graph()
FG.add_weighted_edges_from([(1,2,0.125), (1,3,0.75), (2,4,1.2), (3,4,0.375)])
graph = Graph.from_networkx(FG, nx.circular_layout)
self.assertEqual(graph.dimension_values('weight'), np.array([0.125, 0.75, 1.2, 0.375]))

@attr(optional=1)
def test_from_networkx_with_invalid_edge_attrs(self):
import networkx as nx
FG = nx.Graph()
FG.add_weighted_edges_from([(1,2,[]), (1,3,[]), (2,4,[]), (3,4,[])])
graph = Graph.from_networkx(FG, nx.circular_layout)
self.assertEqual(graph.vdims, [])

def test_from_networkx_only_nodes(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')
import networkx as nx
G = nx.Graph()
G.add_nodes_from([1, 2, 3])
graph = Graph.from_networkx(G, nx.circular_layout)
self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2, 3]))

@attr(optional=1)
def test_from_networkx_custom_nodes(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')
import networkx as nx
FG = nx.Graph()
FG.add_weighted_edges_from([(1,2,0.125), (1,3,0.75), (2,4,1.2), (3,4,0.375)])
nodes = Dataset([(1, 'A'), (2, 'B'), (3, 'A'), (4, 'B')], 'index', 'some_attribute')
graph = Graph.from_networkx(FG, nx.circular_layout, nodes=nodes)
self.assertEqual(graph.nodes.dimension_values('some_attribute'), np.array(['A', 'B', 'A', 'B']))

def test_from_networkx_dictionary_positions(self):
import networkx as nx
G = nx.Graph()
G.add_nodes_from([1, 2, 3])
positions = nx.circular_layout(G)
graph = Graph.from_networkx(G, positions)
self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2, 3]))



class ChordTests(ComparisonTestCase):

def setUp(self):
Expand Down

0 comments on commit 6485c35

Please sign in to comment.