-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Draft PR] Relay IR visualizer. #8448
Changes from 2 commits
5ed350f
481a462
c352cf4
1a3eb80
b3cdb07
c2f569c
984c1b4
ce9d5b6
a6872cd
bcb4d8c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
<!--- Licensed to the Apache Software Foundation (ASF) under one --> | ||
<!--- or more contributor license agreements. See the NOTICE file --> | ||
<!--- distributed with this work for additional information --> | ||
<!--- regarding copyright ownership. The ASF licenses this file --> | ||
<!--- to you under the Apache License, Version 2.0 (the --> | ||
<!--- "License"); you may not use this file except in compliance --> | ||
<!--- with the License. You may obtain a copy of the License at --> | ||
|
||
<!--- http://www.apache.org/licenses/LICENSE-2.0 --> | ||
|
||
<!--- Unless required by applicable law or agreed to in writing, --> | ||
<!--- software distributed under the License is distributed on an --> | ||
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> | ||
<!--- KIND, either express or implied. See the License for the --> | ||
<!--- specific language governing permissions and limitations --> | ||
<!--- under the License. --> | ||
|
||
|
||
# IR Visualization | ||
|
||
This tool target to visualize Relay IR. | ||
|
||
# Table of Contents | ||
1. [Requirement](#Requirement) | ||
2. [Usage](#Usage) | ||
3. [Credits](#Credits) | ||
3. [TODO](#TODO) | ||
|
||
## Requirement | ||
|
||
1. TVM | ||
2. graphviz and graphviz-dev | ||
2. bokeh==2.3.1 | ||
3. pygraphviz==1.6 | ||
4. networkx==2.5.1 | ||
|
||
``` | ||
# To install TVM, please refer to https://tvm.apache.org/docs/install/from_source.html | ||
|
||
# requirements of pygraphviz | ||
apt-get install graphviz graphviz-dev | ||
# pygraphviz | ||
pip install pygraphviz==1.6 | ||
|
||
# networkx | ||
pip install networkx==2.5.1 | ||
|
||
# bokeh | ||
pip install bokeh==2.3.1 | ||
``` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it worth having a look at https://github.com/apache/tvm/blob/main/python/gen_requirements.py to add there dependencies there, so that users don't need to figure out by themselves the dependencies. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @leandron ! |
||
|
||
## Usage | ||
|
||
``` | ||
from tvm.contrib import relay_viz | ||
mod, params = tvm.relay.frontend.from_onnx(net, shape_dict) | ||
vizer = relay_viz.RelayVisualizer(mod, relay_param=params) | ||
vizer.render("output.html") | ||
``` | ||
|
||
## Credits | ||
|
||
1. https://github.com/apache/tvm/pull/4370 | ||
|
||
2. https://tvm.apache.org/2020/07/14/bert-pytorch-tvm | ||
|
||
3. https://discuss.tvm.apache.org/t/rfc-visualizing-relay-program-as-graph/4825/17 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Relay IR Visualizer""" | ||
import logging | ||
import copy | ||
import tvm | ||
from tvm import relay | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
# TODO: add python typing hint for arguments. | ||
|
||
|
||
def _dft_render_cb(plotter, node_to_id, relay_param): | ||
"""a callback to Add nodes and edges to the plotter. | ||
|
||
Parameters | ||
---------- | ||
plotter : class plotter.Plotter | ||
|
||
node_to_id : Dict | ||
|
||
relay_param : Dict | ||
""" | ||
# Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm | ||
unknown_type = "unknown" | ||
for node, node_id in node_to_id.items(): | ||
if isinstance(node, relay.Function): | ||
plotter.node(node_id, "Func", "") | ||
plotter.edge(node_to_id[node.body], node_id) | ||
elif isinstance(node, relay.Var): | ||
name_hint = node.name_hint | ||
node_detail = "" | ||
node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" | ||
if node.type_annotation is not None: | ||
if hasattr(node.type_annotation, "shape"): | ||
shape = tuple(map(int, node.type_annotation.shape)) | ||
dtype = node.type_annotation.dtype | ||
node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format( | ||
name_hint, shape, dtype | ||
) | ||
else: | ||
node_detail = str(node.type_annotation) | ||
plotter.node(node_id, node_type, node_detail) | ||
elif isinstance(node, relay.Tuple): | ||
plotter.node(node_id, "Tuple", "") | ||
for field in node.fields: | ||
plotter.edge(node_to_id[field], node_id) | ||
elif isinstance(node, relay.expr.Constant): | ||
node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) | ||
plotter.node(node_id, "Const", node_detail) | ||
elif isinstance(node, relay.expr.Call): | ||
op_name = unknown_type | ||
node_details = [] | ||
if isinstance(node.op, tvm.ir.Op): | ||
op_name = node.op.name | ||
if node.attrs: | ||
node_details = [ | ||
"{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys() | ||
] | ||
elif isinstance(node.op, relay.Function): | ||
func_attrs = node.op.attrs | ||
op_name = "Anonymous Func" | ||
if func_attrs: | ||
node_details = [ | ||
"{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() | ||
] | ||
if "Composite" in func_attrs.keys(): | ||
op_name = func_attrs["Composite"] | ||
else: | ||
op_name = str(type(node.op)).split(".")[-1].split("'")[0] | ||
|
||
plotter.node(node_id, op_name, "\n".join(node_details)) | ||
args = [node_to_id[arg] for arg in node.args] | ||
for arg in args: | ||
plotter.edge(arg, node_id) | ||
elif isinstance(node, relay.expr.TupleGetItem): | ||
plotter.node(node_id, "TupleGetItem", "idx: {}".format(node.index)) | ||
plotter.edge(node_to_id[node.tuple_value], node_id) | ||
elif isinstance(node, tvm.ir.Op): | ||
pass | ||
elif isinstance(node, relay.Let): | ||
plotter.node(node_id, "Let", "") | ||
plotter.edge(node_to_id[node.value], node_id) | ||
plotter.edge(node_id, node_to_id[node.var]) | ||
else: | ||
unknown_info = "Unknown node: {}".format(type(node)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One issue I found when partitioning
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, Thanks! will handle global var. |
||
_LOGGER.warning(unknown_info) | ||
plotter.node(node_id, unknown_type, unknown_info) | ||
|
||
|
||
class PlotterBackend: | ||
"""Enumeration for available plotters.""" | ||
|
||
BOKEH = "bokeh" | ||
|
||
|
||
class RelayVisualizer: | ||
"""Relay IR Visualizer""" | ||
|
||
def __init__( | ||
self, relay_mod, relay_param=None, plotter_be=PlotterBackend.BOKEH, render_cb=_dft_render_cb | ||
): | ||
"""Visualize Relay IR. | ||
|
||
Parameters | ||
---------- | ||
relay_mod : object | ||
Relay IR module | ||
relay_param: dict | ||
Relay parameter dictionary | ||
plotter_be: PlotterBackend. | ||
The backend of plotting. Default "bokeh" | ||
render_cb: callable[Plotter, Dict, Dict] | ||
A callable accepting plotter, node_to_id, relay_param. | ||
See _dft_render_cb(plotter, node_to_id, relay_param) as | ||
an example. | ||
""" | ||
self._node_to_id = {} | ||
self._plotter = get_plotter(plotter_be) | ||
self._render_cb = render_cb | ||
self._relay_param = relay_param if relay_param is not None else {} | ||
|
||
relay.analysis.post_order_visit( | ||
relay_mod["main"], | ||
lambda node: self._traverse_expr(node), # pylint: disable=unnecessary-lambda | ||
) | ||
|
||
def _traverse_expr(self, node): | ||
# based on https://github.com/apache/tvm/pull/4370 | ||
if node in self._node_to_id: | ||
return | ||
self._node_to_id[node] = len(self._node_to_id) | ||
|
||
def render(self, filename): | ||
# shallow copy to prevent callback modify self._node_to_id | ||
self._render_cb(self._plotter, copy.copy(self._node_to_id), self._relay_param) | ||
return self._plotter.render(filename=filename) | ||
|
||
|
||
def get_plotter(backend): | ||
if backend == PlotterBackend.BOKEH: | ||
from ._bokeh import BokehPlotter # pylint: disable=import-outside-toplevel | ||
|
||
return BokehPlotter() | ||
|
||
raise ValueError("Unknown plotter backend {}".format(backend)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: there is no
TODO
in the document.