Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft PR] Relay IR visualizer. #8448

Closed
wants to merge 10 commits into from
68 changes: 68 additions & 0 deletions python/tvm/contrib/relay_viz/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->


# IR Visualization

This tool target to visualize Relay IR.

# Table of Contents
1. [Requirement](#Requirement)
2. [Usage](#Usage)
3. [Credits](#Credits)
3. [TODO](#TODO)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: there is no TODO in the document.


## Requirement

1. TVM
2. graphviz and graphviz-dev
2. bokeh==2.3.1
3. pygraphviz==1.6
4. networkx==2.5.1

```
# To install TVM, please refer to https://tvm.apache.org/docs/install/from_source.html

# requirements of pygraphviz
apt-get install graphviz graphviz-dev
# pygraphviz
pip install pygraphviz==1.6

# networkx
pip install networkx==2.5.1

# bokeh
pip install bokeh==2.3.1
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it worth having a look at https://github.com/apache/tvm/blob/main/python/gen_requirements.py to add there dependencies there, so that users don't need to figure out by themselves the dependencies.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @leandron !
I will check that .py after I stripped out at least networkx dependency (though I am thinking it is possible to get rid of all graphviz-dev, pygraphviz and networkx dependency, still need to figure out DOT language.)


## Usage

```
from tvm.contrib import relay_viz
mod, params = tvm.relay.frontend.from_onnx(net, shape_dict)
vizer = relay_viz.RelayVisualizer(mod, relay_param=params)
vizer.render("output.html")
```

## Credits

1. https://github.com/apache/tvm/pull/4370

2. https://tvm.apache.org/2020/07/14/bert-pytorch-tvm

3. https://discuss.tvm.apache.org/t/rfc-visualizing-relay-program-as-graph/4825/17

161 changes: 161 additions & 0 deletions python/tvm/contrib/relay_viz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Relay IR Visualizer"""
import logging
import copy
import tvm
from tvm import relay

_LOGGER = logging.getLogger(__name__)

# TODO: add python typing hint for arguments.


def _dft_render_cb(plotter, node_to_id, relay_param):
"""a callback to Add nodes and edges to the plotter.

Parameters
----------
plotter : class plotter.Plotter

node_to_id : Dict

relay_param : Dict
"""
# Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm
unknown_type = "unknown"
for node, node_id in node_to_id.items():
if isinstance(node, relay.Function):
plotter.node(node_id, "Func", "")
plotter.edge(node_to_id[node.body], node_id)
elif isinstance(node, relay.Var):
name_hint = node.name_hint
node_detail = ""
node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)"
if node.type_annotation is not None:
if hasattr(node.type_annotation, "shape"):
shape = tuple(map(int, node.type_annotation.shape))
dtype = node.type_annotation.dtype
node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format(
name_hint, shape, dtype
)
else:
node_detail = str(node.type_annotation)
plotter.node(node_id, node_type, node_detail)
elif isinstance(node, relay.Tuple):
plotter.node(node_id, "Tuple", "")
for field in node.fields:
plotter.edge(node_to_id[field], node_id)
elif isinstance(node, relay.expr.Constant):
node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype)
plotter.node(node_id, "Const", node_detail)
elif isinstance(node, relay.expr.Call):
op_name = unknown_type
node_details = []
if isinstance(node.op, tvm.ir.Op):
op_name = node.op.name
if node.attrs:
node_details = [
"{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys()
]
elif isinstance(node.op, relay.Function):
func_attrs = node.op.attrs
op_name = "Anonymous Func"
if func_attrs:
node_details = [
"{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()
]
if "Composite" in func_attrs.keys():
op_name = func_attrs["Composite"]
else:
op_name = str(type(node.op)).split(".")[-1].split("'")[0]

plotter.node(node_id, op_name, "\n".join(node_details))
args = [node_to_id[arg] for arg in node.args]
for arg in args:
plotter.edge(arg, node_id)
elif isinstance(node, relay.expr.TupleGetItem):
plotter.node(node_id, "TupleGetItem", "idx: {}".format(node.index))
plotter.edge(node_to_id[node.tuple_value], node_id)
elif isinstance(node, tvm.ir.Op):
pass
elif isinstance(node, relay.Let):
plotter.node(node_id, "Let", "")
plotter.edge(node_to_id[node.value], node_id)
plotter.edge(node_id, node_to_id[node.var])
else:
unknown_info = "Unknown node: {}".format(type(node))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One issue I found when partitioning mobilenet_v1_1.0_224_quant.tflite with ethos-n before building it, I got:

Unknown node: <class 'tvm.ir.expr.GlobalVar'>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, Thanks! will handle global var.

_LOGGER.warning(unknown_info)
plotter.node(node_id, unknown_type, unknown_info)


class PlotterBackend:
"""Enumeration for available plotters."""

BOKEH = "bokeh"


class RelayVisualizer:
"""Relay IR Visualizer"""

def __init__(
self, relay_mod, relay_param=None, plotter_be=PlotterBackend.BOKEH, render_cb=_dft_render_cb
):
"""Visualize Relay IR.

Parameters
----------
relay_mod : object
Relay IR module
relay_param: dict
Relay parameter dictionary
plotter_be: PlotterBackend.
The backend of plotting. Default "bokeh"
render_cb: callable[Plotter, Dict, Dict]
A callable accepting plotter, node_to_id, relay_param.
See _dft_render_cb(plotter, node_to_id, relay_param) as
an example.
"""
self._node_to_id = {}
self._plotter = get_plotter(plotter_be)
self._render_cb = render_cb
self._relay_param = relay_param if relay_param is not None else {}

relay.analysis.post_order_visit(
relay_mod["main"],
lambda node: self._traverse_expr(node), # pylint: disable=unnecessary-lambda
)

def _traverse_expr(self, node):
# based on https://github.com/apache/tvm/pull/4370
if node in self._node_to_id:
return
self._node_to_id[node] = len(self._node_to_id)

def render(self, filename):
# shallow copy to prevent callback modify self._node_to_id
self._render_cb(self._plotter, copy.copy(self._node_to_id), self._relay_param)
return self._plotter.render(filename=filename)


def get_plotter(backend):
if backend == PlotterBackend.BOKEH:
from ._bokeh import BokehPlotter # pylint: disable=import-outside-toplevel

return BokehPlotter()

raise ValueError("Unknown plotter backend {}".format(backend))
Loading