diff --git a/holoviews/element/graphs.py b/holoviews/element/graphs.py index 36ec08c018..60f68a53ce 100644 --- a/holoviews/element/graphs.py +++ b/holoviews/element/graphs.py @@ -405,7 +405,7 @@ def from_networkx(cls, G, positions, nodes=None, **kwargs): Returns: Graph element """ - if isinstance(positions, dict): + if not isinstance(positions, dict): positions = positions(G, **kwargs) edges = defaultdict(list) for start, end in G.edges(): @@ -428,30 +428,41 @@ def from_networkx(cls, G, positions, nodes=None, **kwargs): xdim, ydim, idim = cls.node_type.kdims[:3] if nodes: - xs, ys = zip(*[v for k, v in sorted(positions.items())]) - indices = list(nodes.dimension_values(idim)) - edges = [edge for edge in edges if edge[0] in indices and edge[1] in indices] - nodes = nodes.select(**{idim.name: [eid for e in edges for eid in e]}).sort() - nodes = nodes.add_dimension(xdim, 0, xs) - nodes = nodes.add_dimension(ydim, 1, ys).clone(new_type=cls.node_type) + node_columns = nodes.columns() + idx_dim = nodes.kdims[0].name + info_cols, values = zip(*((k, v) for k, v in node_columns.items() if k != idx_dim)) + node_info = {i: vals for i, vals in zip(node_columns[idx_dim], zip(*values))} else: - nodes = defaultdict(list) - for idx, pos in sorted(positions.items()): - x, y = pos - nodes[xdim.name].append(x) - nodes[ydim.name].append(y) - for attr, value in G.nodes[idx].items(): - if isinstance(value, (list, dict)): - continue - nodes[attr].append(value) - if isinstance(idx, tuple): - idx = str(idx) # Tuple node indexes handled as strings - nodes[idim.name].append(idx) - node_cols = sorted([k for k in edges if k not in cls.node_type.kdims - and len(nodes[k]) == len(nodes[xdim.name])]) - vdims = [str(col) if isinstance(col, int) else col for col in node_cols] - node_data = tuple(nodes[col] for col in [xdim.name, ydim.name, idim.name]+node_cols) - nodes = cls.node_type(node_data, vdims=vdims) + info_cols = [] + node_info = None + node_columns = defaultdict(list) + for idx, pos in sorted(positions.items()): + x, y = pos + node_columns[xdim.name].append(x) + node_columns[ydim.name].append(y) + for attr, value in G.nodes[idx].items(): + if isinstance(value, (list, dict)): + continue + node_columns[attr].append(value) + for i, col in enumerate(info_cols): + node_columns[col].append(node_info[idx][i]) + if isinstance(idx, tuple): + idx = str(idx) # Tuple node indexes handled as strings + node_columns[idim.name].append(idx) + node_cols = sorted([k for k in node_columns if k not in cls.node_type.kdims + and len(node_columns[k]) == len(node_columns[xdim.name])]) + vdims = [] + for col in node_cols: + if isinstance(col, int): + dim = str(col) + elif nodes is not None and col in nodes.vdims: + dim = nodes.get_dimension(col) + else: + dim = col + vdims.append(dim) + columns = [xdim.name, ydim.name, idim.name]+node_cols+list(info_cols) + node_data = tuple(node_columns[col] for col in columns) + nodes = cls.node_type(node_data, vdims=vdims) return cls((edge_data, nodes), vdims=edge_vdims) diff --git a/holoviews/tests/element/testgraphelement.py b/holoviews/tests/element/testgraphelement.py index c8f6f5a4eb..29f33f602b 100644 --- a/holoviews/tests/element/testgraphelement.py +++ b/holoviews/tests/element/testgraphelement.py @@ -129,12 +129,16 @@ def test_graph_redim_nodes(self): self.assertEqual(redimmed.nodes, graph.nodes.redim(x='x2', y='y2')) self.assertEqual(redimmed.edgepaths, graph.edgepaths.redim(x='x2', y='y2')) - @attr(optional=1) - def test_from_networkx_with_node_attrs(self): +class FromNetworkXTests(ComparisonTestCase): + + def setUp(self): try: import networkx as nx except: raise SkipTest('Test requires networkx to be installed') + + def test_from_networkx_with_node_attrs(self): + import networkx as nx G = nx.karate_club_graph() graph = Graph.from_networkx(G, nx.circular_layout) clubs = np.array([ @@ -146,41 +150,56 @@ def test_from_networkx_with_node_attrs(self): 'Officer', 'Officer', 'Officer', 'Officer']) self.assertEqual(graph.nodes.dimension_values('club'), clubs) - @attr(optional=1) + def test_from_networkx_with_invalid_node_attrs(self): + import networkx as nx + FG = nx.Graph() + FG.add_node(1, test=[]) + FG.add_node(2, test=[]) + FG.add_edge(1, 2) + graph = Graph.from_networkx(FG, nx.circular_layout) + self.assertEqual(graph.nodes.vdims, []) + self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2])) + self.assertEqual(graph.array(), np.array([(1, 2)])) + def test_from_networkx_with_edge_attrs(self): - try: - import networkx as nx - except: - raise SkipTest('Test requires networkx to be installed') + import networkx as nx FG = nx.Graph() FG.add_weighted_edges_from([(1,2,0.125), (1,3,0.75), (2,4,1.2), (3,4,0.375)]) graph = Graph.from_networkx(FG, nx.circular_layout) self.assertEqual(graph.dimension_values('weight'), np.array([0.125, 0.75, 1.2, 0.375])) - @attr(optional=1) + def test_from_networkx_with_invalid_edge_attrs(self): + import networkx as nx + FG = nx.Graph() + FG.add_weighted_edges_from([(1,2,[]), (1,3,[]), (2,4,[]), (3,4,[])]) + graph = Graph.from_networkx(FG, nx.circular_layout) + self.assertEqual(graph.vdims, []) + def test_from_networkx_only_nodes(self): - try: - import networkx as nx - except: - raise SkipTest('Test requires networkx to be installed') + import networkx as nx G = nx.Graph() G.add_nodes_from([1, 2, 3]) graph = Graph.from_networkx(G, nx.circular_layout) self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2, 3])) - @attr(optional=1) def test_from_networkx_custom_nodes(self): - try: - import networkx as nx - except: - raise SkipTest('Test requires networkx to be installed') + import networkx as nx FG = nx.Graph() FG.add_weighted_edges_from([(1,2,0.125), (1,3,0.75), (2,4,1.2), (3,4,0.375)]) nodes = Dataset([(1, 'A'), (2, 'B'), (3, 'A'), (4, 'B')], 'index', 'some_attribute') graph = Graph.from_networkx(FG, nx.circular_layout, nodes=nodes) self.assertEqual(graph.nodes.dimension_values('some_attribute'), np.array(['A', 'B', 'A', 'B'])) + def test_from_networkx_dictionary_positions(self): + import networkx as nx + G = nx.Graph() + G.add_nodes_from([1, 2, 3]) + positions = nx.circular_layout(G) + graph = Graph.from_networkx(G, positions) + self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2, 3])) + + class ChordTests(ComparisonTestCase): def setUp(self):