Skip to content

Commit

Permalink
add license header
Browse files Browse the repository at this point in the history
  • Loading branch information
chiwwang committed Jul 12, 2021
1 parent 5ed350f commit 481a462
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 81 deletions.
17 changes: 17 additions & 0 deletions python/tvm/contrib/relay_viz/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->


# IR Visualization

Expand Down
60 changes: 43 additions & 17 deletions python/tvm/contrib/relay_viz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Relay IR Visualizer"""
import logging
import copy
import tvm
Expand All @@ -8,8 +24,9 @@

# TODO: add python typing hint for arguments.


def _dft_render_cb(plotter, node_to_id, relay_param):
""" a callback to Add nodes and edges to the plotter.
"""a callback to Add nodes and edges to the plotter.
Parameters
----------
Expand All @@ -34,7 +51,8 @@ def _dft_render_cb(plotter, node_to_id, relay_param):
shape = tuple(map(int, node.type_annotation.shape))
dtype = node.type_annotation.dtype
node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format(
name_hint, shape, dtype)
name_hint, shape, dtype
)
else:
node_detail = str(node.type_annotation)
plotter.node(node_id, node_type, node_detail)
Expand All @@ -51,14 +69,16 @@ def _dft_render_cb(plotter, node_to_id, relay_param):
if isinstance(node.op, tvm.ir.Op):
op_name = node.op.name
if node.attrs:
node_details = ["{}: {}".format(k, node.attrs.get_str(k))
for k in node.attrs.keys()]
node_details = [
"{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys()
]
elif isinstance(node.op, relay.Function):
func_attrs = node.op.attrs
op_name = "Anonymous Func"
if func_attrs:
node_details = ["{}: {}".format(k, func_attrs.get_str(k))
for k in func_attrs.keys()]
node_details = [
"{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()
]
if "Composite" in func_attrs.keys():
op_name = func_attrs["Composite"]
else:
Expand All @@ -82,17 +102,20 @@ def _dft_render_cb(plotter, node_to_id, relay_param):
_LOGGER.warning(unknown_info)
plotter.node(node_id, unknown_type, unknown_info)


class PlotterBackend:
"""Enumeration for available plotters."""

BOKEH = "bokeh"


class RelayVisualizer:
"""Relay IR Visualizer"""

def __init__(self,
relay_mod,
relay_param=None,
plotter_be=PlotterBackend.BOKEH,
render_cb=_dft_render_cb):
"""
def __init__(
self, relay_mod, relay_param=None, plotter_be=PlotterBackend.BOKEH, render_cb=_dft_render_cb
):
"""Visualize Relay IR.
Parameters
----------
Expand All @@ -113,7 +136,9 @@ def __init__(self,
self._relay_param = relay_param if relay_param is not None else {}

relay.analysis.post_order_visit(
relay_mod["main"], lambda node: self._traverse_expr(node))
relay_mod["main"],
lambda node: self._traverse_expr(node), # pylint: disable=unnecessary-lambda
)

def _traverse_expr(self, node):
# based on https://github.com/apache/tvm/pull/4370
Expand All @@ -126,10 +151,11 @@ def render(self, filename):
self._render_cb(self._plotter, copy.copy(self._node_to_id), self._relay_param)
return self._plotter.render(filename=filename)


def get_plotter(backend):
if backend == PlotterBackend.BOKEH:
from ._bokeh import BokehPlotter
from ._bokeh import BokehPlotter # pylint: disable=import-outside-toplevel

return BokehPlotter()
else:
raise ValueError("Unknown plotter backend {}".format(backend))

raise ValueError("Unknown plotter backend {}".format(backend))
140 changes: 83 additions & 57 deletions python/tvm/contrib/relay_viz/_bokeh.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Bokeh backend for Relay IR Visualizer."""
import os
import html
import logging
Expand All @@ -15,7 +31,7 @@
ColumnDataSource,
Text,
Rect,
#NodesAndLinkedEdges,
# NodesAndLinkedEdges,
HoverTool,
MultiLine,
Legend,
Expand All @@ -38,6 +54,7 @@


class NodeDescriptor:
"""Descriptor used by Bokeh plotter."""

def __init__(self, node_id, node_type, node_detail):
self._node_id = node_id
Expand All @@ -60,15 +77,16 @@ def detail(self):


class GraphShaper:
""" Provide the bounding-box, and node location, height, width given by pygraphviz.
"""
"""Provide the bounding-box, and node location, height, width given by pygraphviz."""

def __init__(self, nx_digraph, prog="neato", args=""):
agraph = nx.nx_agraph.to_agraph(nx_digraph)
agraph.layout(prog=prog, args=args)
self._agraph = agraph

def get_edge_path(self, sn, en):
edge = self._agraph.get_edge(sn, en)
def get_edge_path(self, start_node_id, end_node_id):
"""Get explicit path points for MultiLine."""
edge = self._agraph.get_edge(start_node_id, end_node_id)
pos_str = edge.attr["pos"]
tokens = pos_str.split(" ")
s_token = None
Expand Down Expand Up @@ -117,15 +135,14 @@ def _get_node_attr(self, node_name, attr_name, default_val):
try:
val = attr[attr_name]
except KeyError:
_LOGGER.warning("%s does not exist in node %s. "
"Use default %s",
attr_name,
node_name,
default_val)
_LOGGER.warning(
"%s does not exist in node %s. Use default %s", attr_name, node_name, default_val
)
return val


class BokehPlotter(Plotter):
"""Use Bokeh library to plot Relay IR."""

def __init__(self):
self._digraph = nx.DiGraph()
Expand Down Expand Up @@ -167,12 +184,14 @@ def render(self, filename):
graph_name = filename
filename = "{}.html".format(filename)

plot = Plot(title=graph_name,
plot_width=1600,
plot_height=900,
align="center",
sizing_mode="stretch_both",
margin=(0, 0, 0, 50))
plot = Plot(
title=graph_name,
plot_width=1600,
plot_height=900,
align="center",
sizing_mode="stretch_both",
margin=(0, 0, 0, 50),
)

layout_dom = self._create_layout_dom(plot)
self._save_html(filename, layout_dom)
Expand All @@ -197,16 +216,17 @@ def _get_type_to_color_map(self):
category20 = d3["Category20"][20]
# FIXME: a problem is, for different network we have different color
# for the same type.
all_types = list(set([v.node_type for v in self._id_to_node.values()]))
all_types = list({v.node_type for v in self._id_to_node.values()})
all_types.sort()
if len(all_types) > 20:
_LOGGER.warning(
"The number of types %d is larger than 20. "
"Some colors are re-used for different types.",
len(all_types))
len(all_types),
)
type_to_color = {}
for idx, t in enumerate(all_types):
type_to_color[t] = category20[idx%20]
type_to_color[t] = category20[idx % 20]
return type_to_color

def _add_legend(self, plot, graph, label):
Expand All @@ -217,8 +237,10 @@ def _add_legend(self, plot, graph, label):
def _add_tooltip(self, plot):

graph_name = self._get_graph_name(plot)
tooltips = [("node_type", "@label"),
("description", "@node_detail{safe}"),]
tooltips = [
("node_type", "@label"),
("description", "@node_detail{safe}"),
]
inspect_tool = WheelZoomTool()
# only render graph_name
hover_tool = HoverTool(tooltips=tooltips, names=[graph_name])
Expand All @@ -227,18 +249,22 @@ def _add_tooltip(self, plot):

def _create_node_type_toggler(self, plot, node_to_pos):

source = ColumnDataSource({
"x": [pos[0] for pos in node_to_pos.values()],
"y": [pos[1] for pos in node_to_pos.values()],
"text": [self._id_to_node[n].node_type for n in node_to_pos],
})

text_glyph = Text(x="x",
y="y",
text="text",
text_align="center",
text_baseline="middle",
text_font_size={"value": "1.5em"})
source = ColumnDataSource(
{
"x": [pos[0] for pos in node_to_pos.values()],
"y": [pos[1] for pos in node_to_pos.values()],
"text": [self._id_to_node[n].node_type for n in node_to_pos],
}
)

text_glyph = Text(
x="x",
y="y",
text="text",
text_align="center",
text_baseline="middle",
text_font_size={"value": "1.5em"},
)
node_annotation = plot.add_glyph(source, text_glyph, visible=False)

# widgets
Expand All @@ -253,34 +279,36 @@ def _create_graph(self, plot, shaper, node_to_pos):

# TODO: I want to plot the network with lower-level bokeh APIs in the future,
# which may not support NodesAndLinkedEdges() policy. So comment out here.
#graph.selection_policy = NodesAndLinkedEdges()
# graph.selection_policy = NodesAndLinkedEdges()

# edge
edge_line_width = 3
graph.edge_renderer.glyph = MultiLine(line_color="#888888", line_width=edge_line_width)
xs = []
ys = []
for e in self._digraph.edges():
x_pts, y_pts = shaper.get_edge_path(e[0], e[1])
xs.append(x_pts)
ys.append(y_pts)
graph.edge_renderer.data_source.data["xs"] = xs
graph.edge_renderer.data_source.data["ys"] = ys
x_path_list = []
y_path_list = []
for edge in self._digraph.edges():
x_pts, y_pts = shaper.get_edge_path(edge[0], edge[1])
x_path_list.append(x_pts)
y_path_list.append(y_pts)
graph.edge_renderer.data_source.data["xs"] = x_path_list
graph.edge_renderer.data_source.data["ys"] = y_path_list

# node
graph.node_renderer.glyph = Rect(
width="w", height="h", fill_color="fill_color")
graph.node_renderer.glyph = Rect(width="w", height="h", fill_color="fill_color")
graph.node_renderer.hover_glyph = Rect(
width="w", height="h", fill_color="fill_color", line_color="firebrick", line_width=3)
width="w", height="h", fill_color="fill_color", line_color="firebrick", line_width=3
)
graph.node_renderer.selection_glyph = Rect(
width="w", height="h", fill_color="fill_color", line_color="firebrick", line_width=3)
width="w", height="h", fill_color="fill_color", line_color="firebrick", line_width=3
)
graph.node_renderer.nonselection_glyph = Rect(
width="w", height="h", fill_color="fill_color")
width="w", height="h", fill_color="fill_color"
)

# decide rect size
px_per_inch = 72
rect_w = [shaper.get_node_width(n)*px_per_inch for n in node_to_pos]
rect_h = [shaper.get_node_height(n)*px_per_inch for n in node_to_pos]
rect_w = [shaper.get_node_width(n) * px_per_inch for n in node_to_pos]
rect_h = [shaper.get_node_height(n) * px_per_inch for n in node_to_pos]

# get type-color map
type_to_color = self._get_type_to_color_map()
Expand All @@ -292,7 +320,8 @@ def _create_graph(self, plot, shaper, node_to_pos):
h=rect_h,
label=[self._id_to_node[i].node_type for i in node_to_pos],
fill_color=[type_to_color[self._id_to_node[i].node_type] for i in node_to_pos],
node_detail=[self._id_to_node[i].detail for i in node_to_pos])
node_detail=[self._id_to_node[i].detail for i in node_to_pos],
)

return graph

Expand All @@ -319,14 +348,14 @@ def _create_layout_dom(self, plot):
[
[Spacer(sizing_mode="stretch_width"), node_type_toggler],
[plot],
])
]
)
return layout_dom

def _save_html(self, filename, layout_dom):

output_file(filename, title=filename)

# https://stackoverflow.com/a/62601727
template = """
{% block postamble %}
<style>
Expand All @@ -337,10 +366,7 @@ def _save_html(self, filename, layout_dom):
{% endblock %}
"""

save(layout_dom,
filename=filename,
title=filename,
template=template)
save(layout_dom, filename=filename, title=filename, template=template)

@staticmethod
def _get_graph_name(plot):
Expand Down
Loading

0 comments on commit 481a462

Please sign in to comment.