Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Visualization of Relay IR #8668

Closed
wants to merge 22 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions python/tvm/contrib/relay_viz/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->

<!--- http://www.apache.org/licenses/LICENSE-2.0 -->

<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->


# IR Visualization
areusch marked this conversation as resolved.
Show resolved Hide resolved

This tool target to visualize Relay IR.

# Table of Contents
1. [Requirement](#Requirement)
2. [Usage](#Usage)
3. [Credits](#Credits)

## Requirement

1. TVM
2. graphviz
2. pydot
3. bokeh >= 2.3.1

```
# To install TVM, please refer to https://tvm.apache.org/docs/install/from_source.html

# requirements of pydot
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'd ideally like to add these to python/gen_requirements.py, but i'm not sure it's the best idea. bokeh in particular is pretty heavyweight. before we can do that, we'll need to split the IR parsing stuff into another python package which can be depended on from both this utility and TVM. so for now let's leave them out of gen_requirements.py.

apt-get install graphviz

# pydot and bokeh
pip install pydot bokeh==2.3.1
```

## Usage

```
from tvm.contrib import relay_viz
mod, params = tvm.relay.frontend.from_onnx(net, shape_dict)
vizer = relay_viz.RelayVisualizer(mod, relay_param=params)
vizer.render("output.html")
```

## Credits

1. https://github.com/apache/tvm/pull/4370

2. https://tvm.apache.org/2020/07/14/bert-pytorch-tvm

3. https://discuss.tvm.apache.org/t/rfc-visualizing-relay-program-as-graph/4825/17
160 changes: 160 additions & 0 deletions python/tvm/contrib/relay_viz/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Relay IR Visualizer"""
import logging
import copy
from enum import Enum
from tvm import relay
from .plotter import Plotter
from .render_callback import RenderCallback


_LOGGER = logging.getLogger(__name__)


class PlotterBackend(Enum):
"""Enumeration for available plotters."""

BOKEH = "bokeh"
TERMINAL = "terminal"


class RelayVisualizer:
"""Relay IR Visualizer"""

def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kueitang if you have time, would be awesome to add type annotations here

"""Visualize Relay IR.

Parameters
----------
relay_mod : object
Relay IR module
relay_param: dict
Relay parameter dictionary
backend: PlotterBackend or a tuple
PlotterBackend: The backend of plotting. Default "bokeh"
Tuple: A tuple with two arguments. First is user-defined Plotter, \
the second is user-defined RenderCallback
"""

self._plotter, self._render_rules = get_plotter_and_render_rules(backend)
self._relay_param = relay_param if relay_param is not None else {}
# This field is used for book-keeping for each graph.
self._node_to_id = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can use collections.OrderedDict() to explicitly show the order is important. (It's post-order of a Relay GlobalVar.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dict keeps insertion order after python3.7. Current tvm uses Python 3.6, for the CPython implementation of Python, dictionaries remember the order of items inserted.
Also Python 3.6 ends its life on 23/12/2021. Maybe we could keep it a dictionary right now.


global_vars = relay_mod.get_global_vars()
graph_names = []
# If we have main function, put it to the first.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you say why?

for gv_name in global_vars:
if gv_name.name_hint == "main":
graph_names.insert(0, gv_name.name_hint)
else:
graph_names.append(gv_name.name_hint)

for name in graph_names:
# clear previous graph
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason to carry this as a class variable?

self._node_to_id = {}
relay.analysis.post_order_visit(
relay_mod[name],
lambda node: self._traverse_expr(node), # pylint: disable=unnecessary-lambda
)
graph = self._plotter.create_graph(name)
# shallow copy to prevent callback modify self._node_to_id
self._render_cb(graph, copy.copy(self._node_to_id), self._relay_param)

def _traverse_expr(self, node):
# based on https://github.com/apache/tvm/pull/4370
if node in self._node_to_id:
return
self._node_to_id[node] = len(self._node_to_id)

def _render_cb(self, graph, node_to_id, relay_param):
"""a callback to Add nodes and edges to the graph.

Parameters
----------
graph : class plotter.Graph

node_to_id : Dict[relay.expr, int]

relay_param : Dict[string, NDarray]
"""
# Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm
unknown_type = "unknown"
for node, node_id in node_to_id.items():
if type(node) in self._render_rules: # pylint: disable=unidiomatic-typecheck
graph_info, edge_info = self._render_rules[type(node)](
node, relay_param, node_to_id
)
if graph_info:
graph.node(*graph_info)
for edge in edge_info:
graph.edge(*edge)
else:
unknown_info = "Unknown node: {}".format(type(node))
_LOGGER.warning(unknown_info)
graph.node(node_id, unknown_type, unknown_info)

def render(self, filename):
return self._plotter.render(filename=filename)


def get_plotter_and_render_rules(backend):
"""Specify the Plottor and its render rules

Parameters
----------
backend: PlotterBackend or a tuple
PlotterBackend: The backend of plotting. Default "bokeh"
Tuple: A tuple with two arguments. First is user-defined Plotter, \
the second is user-defined RenderCallback
"""
if type(backend) is tuple and len(backend) == 2: # pylint: disable=unidiomatic-typecheck
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isinstance(backend, tuple), no? also, should assert the tuple length is 2 rather than allowing tuples of length != 2 right?

if not isinstance(backend[0], Plotter):
raise ValueError("First elemnet of the backend should be a plotter")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: element

plotter = backend[0]
if not isinstance(backend[1], RenderCallback):
raise ValueError("Second elemnet of the backend should be a callback")
render = backend[1]
render_rules = render.get_rules()
return plotter, render_rules

if backend in PlotterBackend:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than handle the good case here, handle the bad case and bail:

if backend not in PlotterBackend:
  raise ValueError(...)

then the rest of the function can un-indent.

if backend == PlotterBackend.BOKEH:
# pylint: disable=import-outside-toplevel
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you late-import? add a comment explaining why if you need to do this.

from ._bokeh import (
BokehPlotter,
BokehRenderCallback,
)

plotter = BokehPlotter()
render = BokehRenderCallback()

elif backend == PlotterBackend.TERMINAL:
# pylint: disable=import-outside-toplevel
from ._terminal import (
TermPlotter,
TermRenderCallback,
)

plotter = TermPlotter()
render = TermRenderCallback()

render_rules = render.get_rules()
return plotter, render_rules

raise ValueError("Unknown plotter backend {}".format(backend))
Loading