Skip to content

Commit

Permalink
Support multiple GlobalVar. One global var, one graph
Browse files Browse the repository at this point in the history
  • Loading branch information
chiwwang committed Aug 21, 2021
1 parent c2f569c commit 260c9bb
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 80 deletions.
93 changes: 61 additions & 32 deletions python/tvm/contrib/relay_viz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
_LOGGER = logging.getLogger(__name__)


def _dft_render_cb(plotter, node_to_id, relay_param):
"""a callback to Add nodes and edges to the plotter.
def _dft_render_cb(graph, node_to_id, relay_param):
"""a callback to Add nodes and edges to the graph.
Parameters
----------
plotter : class plotter.Plotter
graph : class plotter.Graph
node_to_id : Dict[relay.expr, int]
Expand All @@ -38,8 +38,19 @@ def _dft_render_cb(plotter, node_to_id, relay_param):
unknown_type = "unknown"
for node, node_id in node_to_id.items():
if isinstance(node, relay.Function):
plotter.node(node_id, "Func", "")
plotter.edge(node_to_id[node.body], node_id)
node_details = []
name = ""
func_attrs = node.attrs
if func_attrs:
node_details = [
"{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()
]
# "Composite" might from relay.transform.MergeComposite
if "Composite" in func_attrs.keys():
name = func_attrs["Composite"]

graph.node(node_id, f"Func {name}", "\n".join(node_details))
graph.edge(node_to_id[node.body], node_id)
elif isinstance(node, relay.Var):
name_hint = node.name_hint
node_detail = ""
Expand All @@ -55,19 +66,19 @@ def _dft_render_cb(plotter, node_to_id, relay_param):
node_detail = "name_hint: {}\ntype_annotation: {}".format(
name_hint, node.type_annotation
)
plotter.node(node_id, node_type, node_detail)
graph.node(node_id, node_type, node_detail)
elif isinstance(node, relay.GlobalVar):
name_hint = node.name_hint
node_detail = "name_hint: {}\n".format(name_hint)
node_type = "GlobalVar"
plotter.node(node_id, node_type, node_detail)
# Dont render this because GlobalVar is put to another graph.
pass
elif isinstance(node, relay.Tuple):
plotter.node(node_id, "Tuple", "")
graph.node(node_id, "Tuple", "")
for field in node.fields:
plotter.edge(node_to_id[field], node_id)
graph.edge(node_to_id[field], node_id)
elif isinstance(node, relay.expr.Constant):
node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype)
plotter.node(node_id, "Const", node_detail)
node_detail = "shape: {}, dtype: {}, str(node): {}".format(
node.data.shape, node.data.dtype, str(node)
)
graph.node(node_id, "Const", node_detail)
elif isinstance(node, relay.expr.Call):
op_name = unknown_type
node_details = []
Expand All @@ -84,28 +95,32 @@ def _dft_render_cb(plotter, node_to_id, relay_param):
node_details = [
"{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()
]
# "Composite" might from relay.transform.MergeComposite
if "Composite" in func_attrs.keys():
op_name = func_attrs["Composite"]
elif isinstance(node.op, relay.GlobalVar):
op_name = "GlobalVar"
node_details = [f"name_hint: {node.op.name_hint}"]
else:
op_name = str(type(node.op)).split(".")[-1].split("'")[0]

plotter.node(node_id, op_name, "\n".join(node_details))
graph.node(node_id, op_name, "\n".join(node_details))
args = [node_to_id[arg] for arg in node.args]
for arg in args:
plotter.edge(arg, node_id)
graph.edge(arg, node_id)
elif isinstance(node, relay.expr.TupleGetItem):
plotter.node(node_id, "TupleGetItem", "idx: {}".format(node.index))
plotter.edge(node_to_id[node.tuple_value], node_id)
graph.node(node_id, "TupleGetItem", "idx: {}".format(node.index))
graph.edge(node_to_id[node.tuple_value], node_id)
elif isinstance(node, tvm.ir.Op):
pass
elif isinstance(node, relay.Let):
plotter.node(node_id, "Let", "")
plotter.edge(node_to_id[node.value], node_id)
plotter.edge(node_id, node_to_id[node.var])
graph.node(node_id, "Let", "")
graph.edge(node_to_id[node.value], node_id)
graph.edge(node_id, node_to_id[node.var])
else:
unknown_info = "Unknown node: {}".format(type(node))
_LOGGER.warning(unknown_info)
plotter.node(node_id, unknown_type, unknown_info)
graph.node(node_id, unknown_type, unknown_info)


class PlotterBackend:
Expand All @@ -130,20 +145,36 @@ def __init__(
Relay parameter dictionary
plotter_be: PlotterBackend.
The backend of plotting. Default "bokeh"
render_cb: callable[Plotter, Dict, Dict]
A callable accepting plotter, node_to_id, relay_param.
See _dft_render_cb(plotter, node_to_id, relay_param) as
render_cb: callable[plotter.Graph, Dict, Dict]
A callable accepting plotter.Graph, node_to_id, relay_param.
See _dft_render_cb(graph, node_to_id, relay_param) as
an example.
"""
self._node_to_id = {}
self._plotter = get_plotter(plotter_be)
self._render_cb = render_cb
self._relay_param = relay_param if relay_param is not None else {}
# This field is used for book-keeping for each graph.
self._node_to_id = {}

relay.analysis.post_order_visit(
relay_mod["main"],
lambda node: self._traverse_expr(node), # pylint: disable=unnecessary-lambda
)
global_vars = relay_mod.get_global_vars()
graph_names = []
# If we have main function, put it to the first.
for gv_name in global_vars:
if gv_name.name_hint == "main":
graph_names.insert(0, gv_name.name_hint)
else:
graph_names.append(gv_name.name_hint)

for name in graph_names:
# clear previous graph
self._node_to_id = {}
relay.analysis.post_order_visit(
relay_mod[name],
lambda node: self._traverse_expr(node), # pylint: disable=unnecessary-lambda
)
graph = self._plotter.create_graph(name)
# shallow copy to prevent callback modify self._node_to_id
self._render_cb(graph, copy.copy(self._node_to_id), self._relay_param)

def _traverse_expr(self, node):
# based on https://github.com/apache/tvm/pull/4370
Expand All @@ -152,8 +183,6 @@ def _traverse_expr(self, node):
self._node_to_id[node] = len(self._node_to_id)

def render(self, filename):
# shallow copy to prevent callback modify self._node_to_id
self._render_cb(self._plotter, copy.copy(self._node_to_id), self._relay_param)
return self._plotter.render(filename=filename)


Expand Down
114 changes: 68 additions & 46 deletions python/tvm/contrib/relay_viz/_bokeh.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Bokeh backend for Relay IR Visualizer."""
import os
import html
import logging
import functools
Expand Down Expand Up @@ -43,8 +42,12 @@
from bokeh.palettes import (
d3,
)
from bokeh.layouts import column

from .plotter import Plotter
from .plotter import (
Plotter,
Graph,
)

_LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -174,7 +177,7 @@ def _get_node_attr(self, node_name, attr_name, default_val):
return val


class BokehPlotter(Plotter):
class BokehGraph(Graph):
"""Use Bokeh library to plot Relay IR."""

def __init__(self):
Expand All @@ -195,25 +198,19 @@ def edge(self, id_start, id_end):
id_start, id_end = str(id_start), str(id_end)
self._pydot_digraph.add_edge(pydot.Edge(id_start, id_end))

def render(self, filename):

if filename.endswith(".html"):
graph_name = os.path.splitext(os.path.basename(filename))[0]
else:
graph_name = filename
filename = "{}.html".format(filename)
def render(self, plot):
"""Render a bokeh.models.Plot with provided nodes/edges."""

plot = Plot(
title=graph_name,
width=1600,
height=900,
align="center",
sizing_mode="stretch_both",
margin=(0, 0, 0, 50),
shaper = GraphShaper(
self._pydot_digraph,
prog="dot",
args=["-Grankdir=TB", "-Gsplines=ortho", "-Gfontsize=14", "-Nordering=in"],
)

layout_dom = self._create_layout_dom(plot)
self._save_html(filename, layout_dom)
self._create_graph(plot, shaper)

self._add_scalable_glyph(plot, shaper)
return plot

def _get_type_to_color_map(self):
category20 = d3["Category20"][20]
Expand Down Expand Up @@ -287,7 +284,8 @@ def cnvt_to_html(s):
renderer.selection_glyph = Rect(
fill_color=type_to_color[label], line_color="firebrick", line_width=3
)
# Though it is called "muted_glyph", we actually use it to emphasize nodes in this renderer.
# Though it is called "muted_glyph", we actually use it
# to emphasize nodes in this renderer.
renderer.muted_glyph = Rect(
fill_color=type_to_color[label], line_color="firebrick", line_width=3
)
Expand All @@ -304,7 +302,8 @@ def cnvt_to_html(s):
inactive_fill_alpha=0.2,
)
legend.click_policy = "mute"
plot.add_layout(legend, "left")
legend.location = "top_right"
plot.add_layout(legend)

# add tooltips
tooltips = [
Expand Down Expand Up @@ -345,34 +344,36 @@ def populate_detail(n_type, n_detail):
text="text",
text_align="center",
text_baseline="middle",
text_font_size={"value": "11px"},
text_font_size={"value": "14px"},
)
node_annotation = plot.add_glyph(text_source, text_glyph)

def get_scatter_loc(xs, xe, ys, ye, end_node):
def get_scatter_loc(x_start, x_end, y_start, y_end, end_node):
"""return x, y, angle as a tuple"""
node_x, node_y = shaper.get_node_pos(end_node)
node_w = shaper.get_node_width(end_node)
node_h = shaper.get_node_height(end_node)

# only 4 direction
if xe - xs > 0:
return node_x - node_w / 2, ye, -np.pi / 2
if xe - xs < 0:
return node_x + node_w / 2, ye, np.pi / 2
if ye - ys < 0:
return xe, node_y + node_h / 2, np.pi
return xe, node_y - node_h / 2, 0
if x_end - x_start > 0:
return node_x - node_w / 2, y_end, -np.pi / 2
if x_end - x_start < 0:
return node_x + node_w / 2, y_end, np.pi / 2
if y_end - y_start < 0:
return x_end, node_y + node_h / 2, np.pi
return x_end, node_y - node_h / 2, 0

scatter_source = {"x": [], "y": [], "angle": []}
for edge in self._pydot_digraph.get_edges():
id_start = edge.get_source()
id_end = edge.get_destination()
x_pts, y_pts = shaper.get_edge_path(id_start, id_end)
x, y, angle = get_scatter_loc(x_pts[-2], x_pts[-1], y_pts[-2], y_pts[-1], id_end)
x_loc, y_loc, angle = get_scatter_loc(
x_pts[-2], x_pts[-1], y_pts[-2], y_pts[-1], id_end
)
scatter_source["angle"].append(angle)
scatter_source["x"].append(x)
scatter_source["y"].append(y)
scatter_source["x"].append(x_loc)
scatter_source["y"].append(y_loc)

scatter_glyph = Scatter(
x="x",
Expand Down Expand Up @@ -424,18 +425,43 @@ def get_scatter_loc(xs, xe, ys, ye, end_node):
),
)

def _create_layout_dom(self, plot):
@staticmethod
def _get_graph_name(plot):
return plot.title

shaper = GraphShaper(
self._pydot_digraph,
prog="dot",
args=["-Grankdir=TB", "-Gsplines=ortho", "-Gfontsize=14", "-Nordering=in"],
)

self._create_graph(plot, shaper)
class BokehPlotter(Plotter):
"""Use Bokeh library to plot Relay IR."""

self._add_scalable_glyph(plot, shaper)
return plot
def __init__(self):
self._name_to_graph = {}

def create_graph(self, name):
if name in self._name_to_graph:
_LOGGER.warning("Graph name %s exists. ")
else:
self._name_to_graph[name] = BokehGraph()
return self._name_to_graph[name]

def render(self, filename):

if not filename.endswith(".html"):
filename = "{}.html".format(filename)

dom_list = []
for name, graph in self._name_to_graph.items():
plot = Plot(
title=name,
width=1600,
height=900,
align="center",
margin=(0, 0, 0, 70),
)

dom = graph.render(plot)
dom_list.append(dom)

self._save_html(filename, column(*dom_list))

def _save_html(self, filename, layout_dom):

Expand All @@ -452,7 +478,3 @@ def _save_html(self, filename, layout_dom):
"""

save(layout_dom, filename=filename, title=filename, template=template)

@staticmethod
def _get_graph_name(plot):
return plot.title
24 changes: 22 additions & 2 deletions python/tvm/contrib/relay_viz/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import abc


class Plotter(abc.ABC):
"""Abstract class for plotters.
class Graph(abc.ABC):
"""Abstract class for graph.
Implement this interface for various graph libraries.
"""
Expand Down Expand Up @@ -53,6 +53,26 @@ def edge(self, id_start, id_end):
the ID to the ending node.
"""


class Plotter(abc.ABC):
"""Abstract class for plotters.
Implement this interface for various graph libraries.
"""

@abc.abstractmethod
def create_graph(self, name):
"""Create a graph
Parameters
----------
name : string, the name of the graph
Return
------
Graph instance.
"""

@abc.abstractmethod
def render(self, filename):
"""Render the graph as a file.
Expand Down

0 comments on commit 260c9bb

Please sign in to comment.