From af7361be656263d76a1588a3863a30c325f6a8ef Mon Sep 17 00:00:00 2001 From: kueitang Date: Fri, 9 Jul 2021 13:34:18 +0800 Subject: [PATCH 01/16] [Visualization Relay IR on terminal] -Add a AST dump pass -It provides a snap shot to the relay IR graph --- python/tvm/contrib/retv.py | 307 ++++++++++++++++++++++++++++++ tests/python/contrib/test_retv.py | 223 ++++++++++++++++++++++ 2 files changed, 530 insertions(+) create mode 100755 python/tvm/contrib/retv.py create mode 100755 tests/python/contrib/test_retv.py diff --git a/python/tvm/contrib/retv.py b/python/tvm/contrib/retv.py new file mode 100755 index 000000000000..7b0d75b912b2 --- /dev/null +++ b/python/tvm/contrib/retv.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relay Expression Terminal Visualization (RETV), visualizing Relay Expression on Terminal""" +from tvm import relay +from tvm import ir +from tvm.relay import Tuple + + +class Node: + """Base unit of a relay IR visualization node. + + Parameters + ---------- + expr : expr + Relay IR expression. + + name : str + The name of the relay IR node. + + parent : Node + The parent node of the relay IR node. + + is_last : bool + Whether the node is the last within same level nodes. + """ + + def __init__(self, expr, name, parent, is_last): + self.expr = expr + self.name = name + self.parent = parent + self.is_last = is_last + self.children = [] + self.prefix = "" + + +@ir.transform.module_pass(opt_level=1) +class ASTVisualization: + """To visualize the relay IR graph on terminal.""" + + def __init__(self): + self.output = [] + + def get_output(self): + """ + Returns + ------- + output: str + The graph. + """ + output = "== The AST view of the IRModule is ==\n" + for subout in self.output[1:]: + output += subout + "\n" + output += self.output[0] + "\n" # "main" function + return output + + def transform_module(self, mod, ctx): + """A module pass""" + + class ASTVisitor(relay.ExprVisitor): + """ + A visitor over Expr. + + It traverses the AST recursively, and each node information into a sequence. + """ + + def __init__(self): + super(ASTVisitor, self).__init__() + self.sequence = [] + self.parent_stack = [] + self.last_stack = [] + self.current_subgraph = "" + + def seen_node(self, new_node, expr): + """Record those seen expression""" + self.sequence.append(new_node) + self.parent_stack.append(new_node) + for expr_child in self.memo_map[expr].children: + new_node = Node( + expr=expr_child, + name=self.memo_map[expr_child].name, + parent=self.parent_stack[-1], + is_last=self.memo_map[expr_child].is_last, + ) + self.seen_node(new_node, expr_child) + self.parent_stack.pop() + + def visit(self, expr): + if expr in self.memo_map: + new_node = Node( + expr=expr, + name=self.memo_map[expr].name, + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + self.seen_node(new_node, expr) + else: + super(ASTVisitor, self).visit(expr) + + def visit_tuple(self, tup): + node = Node( + expr=tup, + name="(tuple)", + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + self.sequence.append(node) + node.parent.children.append(tup) + self.parent_stack.append(node) + for i, x in enumerate(tup.fields): + if i == len(tup.fields) - 1: + self.last_stack.append(True) + else: + self.last_stack.append(False) + self.visit(x) + self.last_stack.pop() + self.parent_stack.pop() + return node + + def visit_var(self, var): + node = Node( + expr=var, + name=var.name_hint, + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + self.sequence.append(node) + node.parent.children.append(var) + return node + + def visit_function(self, fn): + if len(self.sequence) == 0: # entry function call + layer_name = "@" + self.current_subgraph + "(" + str(fn.params) + ")" + self.parent_stack = [None] + self.last_stack = [True] + else: + layer_name = "Function_" + str(fn.__hash__()) + "(" + str(fn.params) + ")" + + node = Node( + expr=fn, + name=layer_name, + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + self.sequence.append(node) + if node.parent is not None: + node.parent.children.append(fn) + + is_last = True + self.last_stack.append(is_last) + self.parent_stack.append(node) + self.visit(fn.body) + self.parent_stack.pop() + self.last_stack.pop() + return node + + def visit_call(self, call): + layer_name = "(call)" + node = Node( + expr=call, + name=layer_name, + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + self.sequence.append(node) + node.parent.children.append(call) + self.parent_stack.append(node) + self.last_stack.append(len(call.args) == 0) + self.visit(call.op) + self.last_stack.pop() + + for i, arg in enumerate(call.args): + is_last = i == len(call.args) - 1 + self.last_stack.append(is_last) + self.visit(arg) + self.last_stack.pop() + self.parent_stack.pop() + return node + + def visit_constant(self, const): + node = Node( + expr=const, + name=const, + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + self.sequence.append(node) + node.parent.children.append(const) + return node + + def visit_if(self, i): + layer_name = "if(cond, true, false)" + node = Node( + expr=i, + name=layer_name, + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + node.parent.children.append(node) + self.sequence.append(node) + self.parent_stack.append(node) + self.last_stack.append(False) + self.visit(i.cond) + self.last_stack[-1] = False + self.visit(i.true_branch) + self.last_stack[-1] = True + self.visit(i.false_branch) + self.last_stack.pop() + self.parent_stack.pop() + return node + + def visit_let(self, let): + layer_name = "let(var, val, body)" + node = Node( + expr=let, + name=layer_name, + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + self.sequence.append(node) + node.parent.children.append(let) + self.parent_stack.append(node) + self.last_stack.append(False) + self.visit(let.var) + self.last_stack[-1] = False + self.visit(let.value) + self.last_stack[-1] = True + self.visit(let.body) + self.last_stack.pop() + self.parent_stack.pop() + return node + + def visit_global_var(self, gv): + layer_name = "@" + str(gv.name_hint) + node = Node( + expr=gv, + name=layer_name, + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + self.sequence.append(node) + node.parent.children.append(gv) + return node + + def visit_op(self, op): + node = Node( + expr=op, + name=str(op.name), + parent=self.parent_stack[-1], + is_last=self.last_stack[-1], + ) + self.sequence.append(node) + node.parent.children.append(op) + return node + + def prettyprint(self): + """Prettyprint the result""" + + if len(self.sequence) <= 1: + raise RuntimeError("It is an empty IRmodule") + res = "" + res += self.sequence[0].name + "\n" + for node in self.sequence[1:]: + if node.parent is None: + part_a = "" + part_b = "" + else: + part_a = node.parent.prefix[:-3] + part_b = " " * 3 if node.parent.is_last else "| " + part_c = "`--" if node.is_last else "|--" + if isinstance(node.expr, Tuple): + name = "" + for child in node.children: + name += str(self.memo_map[child].name) + ", " + name = "(" + name[:-2] + ")" + node.name = name + node.prefix = part_a + part_b + part_c + res += node.prefix + str(node.name) + "\n" + return res + + printer = ASTVisitor() + printer.current_subgraph = "main" + printer.visit(mod["main"]) + self.output.append(printer.prettyprint()) + for subgraph in mod.get_global_vars(): + name = subgraph.name_hint + if name != "main": + printer.sequence = [] + printer.parent_stack = [] + printer.last_stack = [] + printer.current_subgraph = name + printer.visit(mod[name]) + self.output.append(printer.prettyprint()) + return mod diff --git a/tests/python/contrib/test_retv.py b/tests/python/contrib/test_retv.py new file mode 100755 index 000000000000..f2588ea65470 --- /dev/null +++ b/tests/python/contrib/test_retv.py @@ -0,0 +1,223 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.contrib import retv +import re +import numpy as np +import pytest + + +def test_model_A(): + x = relay.var("x") + y = relay.const(1) + z = relay.add(x, y) + z = relay.multiply(x, z) + mod = tvm.ir.IRModule().from_expr(z) + viz_res = retv.ASTVisualization() + mod = viz_res(mod) + res = viz_res.get_output() + golden = ( + "== The AST view of the IRModule is ==\n" + "@main([Var(x)])\n" + " `--(call)\n" + " |--multiply\n" + " |--x\n" + " `--(call)\n" + " |--add\n" + " |--x\n" + " `--1\n\n" + ) + assert res == golden + + +def test_tuple(): + x = relay.const(1) + y = relay.const(2) + z = relay.add(x, y) + t = relay.Tuple([x, z]) + mod = tvm.ir.IRModule().from_expr(t) + viz_res = retv.ASTVisualization() + mod = viz_res(mod) + res = viz_res.get_output() + golden = ( + "== The AST view of the IRModule is ==\n" + "@main([])\n" + " `--(1, (call))\n" + " |--1\n" + " `--(call)\n" + " |--add\n" + " |--1\n" + " `--2\n\n" + ) + assert res == golden + + +def test_function(): + mod = tvm.IRModule() + x = relay.var("x", shape=(2,)) + y = relay.var("y", shape=(2,)) + f = relay.Function(relay.analysis.free_vars(x + y), x + y) + mod["main"] = relay.Function(relay.analysis.free_vars(f), f) + viz_res = retv.ASTVisualization() + mod = viz_res(mod) + res = viz_res.get_output() + golden = ( + "== The AST view of the IRModule is ==\n" + "@main([])\n" + " `--Function_28372544([Var(x, ty=TensorType([2], float32)), Var(y, ty=TensorType([2], float32))])\n" + " `--(call)\n" + " |--add\n" + " |--x\n" + " `--y\n\n" + ) + match = re.search(r"Function_.\d*\(", res) + if match: + res = res.replace(match.group(0), "Function_28372544(") + else: + assert False + assert res == golden + + +def test_if(): + cond = relay.Var("cond") + left = relay.Var("left") + right = relay.Var("right") + ife = relay.If(cond, left, right) + mod = tvm.ir.IRModule().from_expr(ife) + viz_res = retv.ASTVisualization() + mod = viz_res(mod) + res = viz_res.get_output() + golden = ( + "== The AST view of the IRModule is ==\n" + "@main([Var(cond), Var(left), Var(right)])\n" + " `--if(cond, true, false)\n" + " |--cond\n" + " |--left\n" + " `--right\n\n" + ) + assert res == golden + + +def test_global_var(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.nn.relu(y) + mod = tvm.IRModule() + foo = relay.GlobalVar("foo") + mod[foo] = relay.Function([x, weight], y) + mod = transform.InferType()(mod) + mod["main"] = relay.Function([x, weight], foo(x, weight)) + mod = transform.InferType()(mod) + viz_res = retv.ASTVisualization() + mod = viz_res(mod) + res = viz_res.get_output() + golden = ( + "== The AST view of the IRModule is ==\n" + "@foo([Var(x, ty=TensorType([1, 64, 56, 56], float32)), Var(weight, ty=TensorType([64, 64, 3, 3], float32))])\n" + " `--(call)\n" + " |--nn.relu\n" + " `--(call)\n" + " |--nn.conv2d\n" + " |--x\n" + " `--weight\n\n" + "@main([Var(x, ty=TensorType([1, 64, 56, 56], float32)), Var(weight, ty=TensorType([64, 64, 3, 3], float32))])\n" + " `--(call)\n" + " |--@foo\n" + " |--x\n" + " `--weight\n\n" + ) + assert res == golden + + +def test_loop(): + from tvm.relay.loops import while_loop + + i = relay.var("i") + + def _cond(i): + return relay.less(i, relay.const(10)) + + def _body(i): + x = i + relay.const(1) + return (x,) + + loop = while_loop(_cond, [i], _body) + body = loop(relay.const(2)) + func = relay.Function([], body) + mod = tvm.IRModule() + mod["main"] = func + viz_res = retv.ASTVisualization() + mod = viz_res(mod) + res = viz_res.get_output() + golden = ( + "== The AST view of the IRModule is ==\n" + "@main([])\n" + " `--(call)\n" + " |--let(var, val, body)\n" + " | |--while_loop\n" + " | |--Function_34141568([Var(i)])\n" + " | | `--if(cond, true, false)\n" + " | | |--(call)\n" + " | | | |--less\n" + " | | | |--i\n" + " | | | `--10\n" + " | | |--(call)\n" + " | | | |--while_loop\n" + " | | | `--(call)\n" + " | | | |--add\n" + " | | | |--i\n" + " | | | `--1\n" + " | | `--()\n" + " | | `--i\n" + " | `--while_loop\n" + " `--2\n\n" + ) + match = re.search(r"Function_.\d*\(", res) + if match: + res = res.replace(match.group(0), "Function_34141568(") + else: + assert False + assert res == golden + + +def test_where(): + x = relay.const(np.array([[1, 2], [3, 4]]), dtype="int64") + y = relay.const(np.array([[5, 6], [7, 8]]), dtype="int64") + condition = relay.const(np.array([[1], [0]]), dtype="int64") + where = relay.where(condition, x, y) + mod = tvm.IRModule().from_expr(where) + + viz_res = retv.ASTVisualization() + mod = viz_res(mod) + res = viz_res.get_output() + golden = ( + "== The AST view of the IRModule is ==\n" + "@main([])\n" + " `--(call)\n" + " |--where\n" + " |--meta[relay.Constant][0]\n\n" + " |--meta[relay.Constant][0]\n\n" + " `--meta[relay.Constant][0]\n\n\n" + ) + assert res == golden + + +if __name__ == "__main__": + pytest.main([__file__]) From fcb993b268c283d5b371fed7a32684b0d51a1394 Mon Sep 17 00:00:00 2001 From: kueitang Date: Tue, 31 Aug 2021 03:35:36 +0800 Subject: [PATCH 02/16] Change the design on relay Vizualizer --- python/tvm/contrib/retv.py | 307 ------------------------------ tests/python/contrib/test_retv.py | 223 ---------------------- 2 files changed, 530 deletions(-) delete mode 100755 python/tvm/contrib/retv.py delete mode 100755 tests/python/contrib/test_retv.py diff --git a/python/tvm/contrib/retv.py b/python/tvm/contrib/retv.py deleted file mode 100755 index 7b0d75b912b2..000000000000 --- a/python/tvm/contrib/retv.py +++ /dev/null @@ -1,307 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Relay Expression Terminal Visualization (RETV), visualizing Relay Expression on Terminal""" -from tvm import relay -from tvm import ir -from tvm.relay import Tuple - - -class Node: - """Base unit of a relay IR visualization node. - - Parameters - ---------- - expr : expr - Relay IR expression. - - name : str - The name of the relay IR node. - - parent : Node - The parent node of the relay IR node. - - is_last : bool - Whether the node is the last within same level nodes. - """ - - def __init__(self, expr, name, parent, is_last): - self.expr = expr - self.name = name - self.parent = parent - self.is_last = is_last - self.children = [] - self.prefix = "" - - -@ir.transform.module_pass(opt_level=1) -class ASTVisualization: - """To visualize the relay IR graph on terminal.""" - - def __init__(self): - self.output = [] - - def get_output(self): - """ - Returns - ------- - output: str - The graph. - """ - output = "== The AST view of the IRModule is ==\n" - for subout in self.output[1:]: - output += subout + "\n" - output += self.output[0] + "\n" # "main" function - return output - - def transform_module(self, mod, ctx): - """A module pass""" - - class ASTVisitor(relay.ExprVisitor): - """ - A visitor over Expr. - - It traverses the AST recursively, and each node information into a sequence. - """ - - def __init__(self): - super(ASTVisitor, self).__init__() - self.sequence = [] - self.parent_stack = [] - self.last_stack = [] - self.current_subgraph = "" - - def seen_node(self, new_node, expr): - """Record those seen expression""" - self.sequence.append(new_node) - self.parent_stack.append(new_node) - for expr_child in self.memo_map[expr].children: - new_node = Node( - expr=expr_child, - name=self.memo_map[expr_child].name, - parent=self.parent_stack[-1], - is_last=self.memo_map[expr_child].is_last, - ) - self.seen_node(new_node, expr_child) - self.parent_stack.pop() - - def visit(self, expr): - if expr in self.memo_map: - new_node = Node( - expr=expr, - name=self.memo_map[expr].name, - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - self.seen_node(new_node, expr) - else: - super(ASTVisitor, self).visit(expr) - - def visit_tuple(self, tup): - node = Node( - expr=tup, - name="(tuple)", - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - self.sequence.append(node) - node.parent.children.append(tup) - self.parent_stack.append(node) - for i, x in enumerate(tup.fields): - if i == len(tup.fields) - 1: - self.last_stack.append(True) - else: - self.last_stack.append(False) - self.visit(x) - self.last_stack.pop() - self.parent_stack.pop() - return node - - def visit_var(self, var): - node = Node( - expr=var, - name=var.name_hint, - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - self.sequence.append(node) - node.parent.children.append(var) - return node - - def visit_function(self, fn): - if len(self.sequence) == 0: # entry function call - layer_name = "@" + self.current_subgraph + "(" + str(fn.params) + ")" - self.parent_stack = [None] - self.last_stack = [True] - else: - layer_name = "Function_" + str(fn.__hash__()) + "(" + str(fn.params) + ")" - - node = Node( - expr=fn, - name=layer_name, - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - self.sequence.append(node) - if node.parent is not None: - node.parent.children.append(fn) - - is_last = True - self.last_stack.append(is_last) - self.parent_stack.append(node) - self.visit(fn.body) - self.parent_stack.pop() - self.last_stack.pop() - return node - - def visit_call(self, call): - layer_name = "(call)" - node = Node( - expr=call, - name=layer_name, - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - self.sequence.append(node) - node.parent.children.append(call) - self.parent_stack.append(node) - self.last_stack.append(len(call.args) == 0) - self.visit(call.op) - self.last_stack.pop() - - for i, arg in enumerate(call.args): - is_last = i == len(call.args) - 1 - self.last_stack.append(is_last) - self.visit(arg) - self.last_stack.pop() - self.parent_stack.pop() - return node - - def visit_constant(self, const): - node = Node( - expr=const, - name=const, - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - self.sequence.append(node) - node.parent.children.append(const) - return node - - def visit_if(self, i): - layer_name = "if(cond, true, false)" - node = Node( - expr=i, - name=layer_name, - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - node.parent.children.append(node) - self.sequence.append(node) - self.parent_stack.append(node) - self.last_stack.append(False) - self.visit(i.cond) - self.last_stack[-1] = False - self.visit(i.true_branch) - self.last_stack[-1] = True - self.visit(i.false_branch) - self.last_stack.pop() - self.parent_stack.pop() - return node - - def visit_let(self, let): - layer_name = "let(var, val, body)" - node = Node( - expr=let, - name=layer_name, - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - self.sequence.append(node) - node.parent.children.append(let) - self.parent_stack.append(node) - self.last_stack.append(False) - self.visit(let.var) - self.last_stack[-1] = False - self.visit(let.value) - self.last_stack[-1] = True - self.visit(let.body) - self.last_stack.pop() - self.parent_stack.pop() - return node - - def visit_global_var(self, gv): - layer_name = "@" + str(gv.name_hint) - node = Node( - expr=gv, - name=layer_name, - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - self.sequence.append(node) - node.parent.children.append(gv) - return node - - def visit_op(self, op): - node = Node( - expr=op, - name=str(op.name), - parent=self.parent_stack[-1], - is_last=self.last_stack[-1], - ) - self.sequence.append(node) - node.parent.children.append(op) - return node - - def prettyprint(self): - """Prettyprint the result""" - - if len(self.sequence) <= 1: - raise RuntimeError("It is an empty IRmodule") - res = "" - res += self.sequence[0].name + "\n" - for node in self.sequence[1:]: - if node.parent is None: - part_a = "" - part_b = "" - else: - part_a = node.parent.prefix[:-3] - part_b = " " * 3 if node.parent.is_last else "| " - part_c = "`--" if node.is_last else "|--" - if isinstance(node.expr, Tuple): - name = "" - for child in node.children: - name += str(self.memo_map[child].name) + ", " - name = "(" + name[:-2] + ")" - node.name = name - node.prefix = part_a + part_b + part_c - res += node.prefix + str(node.name) + "\n" - return res - - printer = ASTVisitor() - printer.current_subgraph = "main" - printer.visit(mod["main"]) - self.output.append(printer.prettyprint()) - for subgraph in mod.get_global_vars(): - name = subgraph.name_hint - if name != "main": - printer.sequence = [] - printer.parent_stack = [] - printer.last_stack = [] - printer.current_subgraph = name - printer.visit(mod[name]) - self.output.append(printer.prettyprint()) - return mod diff --git a/tests/python/contrib/test_retv.py b/tests/python/contrib/test_retv.py deleted file mode 100755 index f2588ea65470..000000000000 --- a/tests/python/contrib/test_retv.py +++ /dev/null @@ -1,223 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import tvm -from tvm import relay -from tvm.relay import transform -from tvm.contrib import retv -import re -import numpy as np -import pytest - - -def test_model_A(): - x = relay.var("x") - y = relay.const(1) - z = relay.add(x, y) - z = relay.multiply(x, z) - mod = tvm.ir.IRModule().from_expr(z) - viz_res = retv.ASTVisualization() - mod = viz_res(mod) - res = viz_res.get_output() - golden = ( - "== The AST view of the IRModule is ==\n" - "@main([Var(x)])\n" - " `--(call)\n" - " |--multiply\n" - " |--x\n" - " `--(call)\n" - " |--add\n" - " |--x\n" - " `--1\n\n" - ) - assert res == golden - - -def test_tuple(): - x = relay.const(1) - y = relay.const(2) - z = relay.add(x, y) - t = relay.Tuple([x, z]) - mod = tvm.ir.IRModule().from_expr(t) - viz_res = retv.ASTVisualization() - mod = viz_res(mod) - res = viz_res.get_output() - golden = ( - "== The AST view of the IRModule is ==\n" - "@main([])\n" - " `--(1, (call))\n" - " |--1\n" - " `--(call)\n" - " |--add\n" - " |--1\n" - " `--2\n\n" - ) - assert res == golden - - -def test_function(): - mod = tvm.IRModule() - x = relay.var("x", shape=(2,)) - y = relay.var("y", shape=(2,)) - f = relay.Function(relay.analysis.free_vars(x + y), x + y) - mod["main"] = relay.Function(relay.analysis.free_vars(f), f) - viz_res = retv.ASTVisualization() - mod = viz_res(mod) - res = viz_res.get_output() - golden = ( - "== The AST view of the IRModule is ==\n" - "@main([])\n" - " `--Function_28372544([Var(x, ty=TensorType([2], float32)), Var(y, ty=TensorType([2], float32))])\n" - " `--(call)\n" - " |--add\n" - " |--x\n" - " `--y\n\n" - ) - match = re.search(r"Function_.\d*\(", res) - if match: - res = res.replace(match.group(0), "Function_28372544(") - else: - assert False - assert res == golden - - -def test_if(): - cond = relay.Var("cond") - left = relay.Var("left") - right = relay.Var("right") - ife = relay.If(cond, left, right) - mod = tvm.ir.IRModule().from_expr(ife) - viz_res = retv.ASTVisualization() - mod = viz_res(mod) - res = viz_res.get_output() - golden = ( - "== The AST view of the IRModule is ==\n" - "@main([Var(cond), Var(left), Var(right)])\n" - " `--if(cond, true, false)\n" - " |--cond\n" - " |--left\n" - " `--right\n\n" - ) - assert res == golden - - -def test_global_var(): - x = relay.var("x", shape=(1, 64, 56, 56)) - weight = relay.var("weight", shape=(64, 64, 3, 3)) - y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) - y = relay.nn.relu(y) - mod = tvm.IRModule() - foo = relay.GlobalVar("foo") - mod[foo] = relay.Function([x, weight], y) - mod = transform.InferType()(mod) - mod["main"] = relay.Function([x, weight], foo(x, weight)) - mod = transform.InferType()(mod) - viz_res = retv.ASTVisualization() - mod = viz_res(mod) - res = viz_res.get_output() - golden = ( - "== The AST view of the IRModule is ==\n" - "@foo([Var(x, ty=TensorType([1, 64, 56, 56], float32)), Var(weight, ty=TensorType([64, 64, 3, 3], float32))])\n" - " `--(call)\n" - " |--nn.relu\n" - " `--(call)\n" - " |--nn.conv2d\n" - " |--x\n" - " `--weight\n\n" - "@main([Var(x, ty=TensorType([1, 64, 56, 56], float32)), Var(weight, ty=TensorType([64, 64, 3, 3], float32))])\n" - " `--(call)\n" - " |--@foo\n" - " |--x\n" - " `--weight\n\n" - ) - assert res == golden - - -def test_loop(): - from tvm.relay.loops import while_loop - - i = relay.var("i") - - def _cond(i): - return relay.less(i, relay.const(10)) - - def _body(i): - x = i + relay.const(1) - return (x,) - - loop = while_loop(_cond, [i], _body) - body = loop(relay.const(2)) - func = relay.Function([], body) - mod = tvm.IRModule() - mod["main"] = func - viz_res = retv.ASTVisualization() - mod = viz_res(mod) - res = viz_res.get_output() - golden = ( - "== The AST view of the IRModule is ==\n" - "@main([])\n" - " `--(call)\n" - " |--let(var, val, body)\n" - " | |--while_loop\n" - " | |--Function_34141568([Var(i)])\n" - " | | `--if(cond, true, false)\n" - " | | |--(call)\n" - " | | | |--less\n" - " | | | |--i\n" - " | | | `--10\n" - " | | |--(call)\n" - " | | | |--while_loop\n" - " | | | `--(call)\n" - " | | | |--add\n" - " | | | |--i\n" - " | | | `--1\n" - " | | `--()\n" - " | | `--i\n" - " | `--while_loop\n" - " `--2\n\n" - ) - match = re.search(r"Function_.\d*\(", res) - if match: - res = res.replace(match.group(0), "Function_34141568(") - else: - assert False - assert res == golden - - -def test_where(): - x = relay.const(np.array([[1, 2], [3, 4]]), dtype="int64") - y = relay.const(np.array([[5, 6], [7, 8]]), dtype="int64") - condition = relay.const(np.array([[1], [0]]), dtype="int64") - where = relay.where(condition, x, y) - mod = tvm.IRModule().from_expr(where) - - viz_res = retv.ASTVisualization() - mod = viz_res(mod) - res = viz_res.get_output() - golden = ( - "== The AST view of the IRModule is ==\n" - "@main([])\n" - " `--(call)\n" - " |--where\n" - " |--meta[relay.Constant][0]\n\n" - " |--meta[relay.Constant][0]\n\n" - " `--meta[relay.Constant][0]\n\n\n" - ) - assert res == golden - - -if __name__ == "__main__": - pytest.main([__file__]) From 6fd05e6718b53311498e6ddda947deb8a96f1afa Mon Sep 17 00:00:00 2001 From: chiwwang Date: Mon, 12 Jul 2021 15:51:46 +0800 Subject: [PATCH 03/16] Draft for Relay IR visualizer. * BOKEH BACKEND * Scalable Text for Node type * Add arrow-head and TODO * Use pydot. Remove networkx and pygraphviz dependency * Add interactive legend and show information based on zoom level * Support multiple GlobalVar. One global var, one graph * test reserved post order. try terminal viz by kueitang --- python/tvm/contrib/relay_viz/README.md | 61 +++ python/tvm/contrib/relay_viz/__init__.py | 92 ++++ python/tvm/contrib/relay_viz/_bokeh.py | 580 ++++++++++++++++++++++ python/tvm/contrib/relay_viz/_terminal.py | 177 +++++++ python/tvm/contrib/relay_viz/plotter.py | 82 +++ 5 files changed, 992 insertions(+) create mode 100644 python/tvm/contrib/relay_viz/README.md create mode 100644 python/tvm/contrib/relay_viz/__init__.py create mode 100644 python/tvm/contrib/relay_viz/_bokeh.py create mode 100644 python/tvm/contrib/relay_viz/_terminal.py create mode 100644 python/tvm/contrib/relay_viz/plotter.py diff --git a/python/tvm/contrib/relay_viz/README.md b/python/tvm/contrib/relay_viz/README.md new file mode 100644 index 000000000000..5f5eb135fe80 --- /dev/null +++ b/python/tvm/contrib/relay_viz/README.md @@ -0,0 +1,61 @@ + + + + + + + + + + + + + + + + + + +# IR Visualization + +This tool target to visualize Relay IR. + +# Table of Contents +1. [Requirement](#Requirement) +2. [Usage](#Usage) +3. [Credits](#Credits) +3. [TODO](#TODO) + +## Requirement + +1. TVM +2. graphviz +2. pydot +3. bokeh >= 2.3.1 + +``` +# To install TVM, please refer to https://tvm.apache.org/docs/install/from_source.html + +# requirements of pydot +apt-get install graphviz + +# pydot and bokeh +pip install pydot bokeh==2.3.1 +``` + +## Usage + +``` +from tvm.contrib import relay_viz +mod, params = tvm.relay.frontend.from_onnx(net, shape_dict) +vizer = relay_viz.RelayVisualizer(mod, relay_param=params) +vizer.render("output.html") +``` + +## Credits + +1. https://github.com/apache/tvm/pull/4370 + +2. https://tvm.apache.org/2020/07/14/bert-pytorch-tvm + +3. https://discuss.tvm.apache.org/t/rfc-visualizing-relay-program-as-graph/4825/17 \ No newline at end of file diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py new file mode 100644 index 000000000000..eaaf86441f2e --- /dev/null +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relay IR Visualizer""" +import copy +from tvm import relay + + +class PlotterBackend: + """Enumeration for available plotters.""" + + BOKEH = "bokeh" + TERMINAL = "terminal" + + +class RelayVisualizer: + """Relay IR Visualizer""" + + def __init__( + self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH + ): + """Visualize Relay IR. + + Parameters + ---------- + relay_mod : object + Relay IR module + relay_param: dict + Relay parameter dictionary + backend: PlotterBackend. + The backend of plotting. Default "bokeh" + """ + + self._plotter, self._render_cb = get_plotter_and_render_cb(backend) + self._relay_param = relay_param if relay_param is not None else {} + # This field is used for book-keeping for each graph. + self._node_to_id = {} + + global_vars = relay_mod.get_global_vars() + graph_names = [] + # If we have main function, put it to the first. + for gv_name in global_vars: + if gv_name.name_hint == "main": + graph_names.insert(0, gv_name.name_hint) + else: + graph_names.append(gv_name.name_hint) + + for name in graph_names: + # clear previous graph + self._node_to_id = {} + relay.analysis.post_order_visit( + relay_mod[name], + lambda node: self._traverse_expr(node), # pylint: disable=unnecessary-lambda + ) + graph = self._plotter.create_graph(name) + # shallow copy to prevent callback modify self._node_to_id + self._render_cb(graph, copy.copy(self._node_to_id), self._relay_param) + + def _traverse_expr(self, node): + # based on https://github.com/apache/tvm/pull/4370 + if node in self._node_to_id: + return + self._node_to_id[node] = len(self._node_to_id) + + def render(self, filename): + return self._plotter.render(filename=filename) + + +def get_plotter_and_render_cb(backend): + if backend == PlotterBackend.BOKEH: + from ._bokeh import BokehPlotter, relay_render_cb # pylint: disable=import-outside-toplevel + + return BokehPlotter(), relay_render_cb + if backend == PlotterBackend.TERMINAL: + from ._terminal import TermPlotter, render_cb + + return TermPlotter(), render_cb + + raise ValueError("Unknown plotter backend {}".format(backend)) diff --git a/python/tvm/contrib/relay_viz/_bokeh.py b/python/tvm/contrib/relay_viz/_bokeh.py new file mode 100644 index 000000000000..afbda2d47c32 --- /dev/null +++ b/python/tvm/contrib/relay_viz/_bokeh.py @@ -0,0 +1,580 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Bokeh backend for Relay IR Visualizer.""" +import os +import html +import logging +import functools + +import numpy as np +import pydot + +from bokeh.io import output_file, save +from bokeh.models import ( + ColumnDataSource, + CustomJS, + Text, + Rect, + HoverTool, + MultiLine, + Legend, + Scatter, + Plot, + TapTool, + PanTool, + ResetTool, + WheelZoomTool, + SaveTool, +) +from bokeh.palettes import ( + d3, +) +from bokeh.layouts import column + +from .plotter import ( + Plotter, + Graph, +) + +import tvm +from tvm import relay + +_LOGGER = logging.getLogger(__name__) + + +def relay_render_cb(graph, node_to_id, relay_param): + """a callback to Add nodes and edges to the graph. + + Parameters + ---------- + graph : class plotter.Graph + + node_to_id : Dict[relay.expr, int] + + relay_param : Dict[string, NDarray] + """ + # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm + unknown_type = "unknown" + for node, node_id in node_to_id.items(): + if isinstance(node, relay.Function): + node_details = [] + func_attrs = node.attrs + if func_attrs: + node_details = [ + "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() + ] + + graph.node(node_id, f"Func", "\n".join(node_details)) + graph.edge(node_to_id[node.body], node_id) + elif isinstance(node, relay.Var): + name_hint = node.name_hint + node_detail = "" + node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" + if node.type_annotation is not None: + if hasattr(node.type_annotation, "shape"): + shape = tuple(map(int, node.type_annotation.shape)) + dtype = node.type_annotation.dtype + node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format( + name_hint, shape, dtype + ) + else: + node_detail = "name_hint: {}\ntype_annotation: {}".format( + name_hint, node.type_annotation + ) + graph.node(node_id, node_type, node_detail) + elif isinstance(node, relay.GlobalVar): + # Dont render this because GlobalVar is put to another graph. + pass + elif isinstance(node, relay.Tuple): + graph.node(node_id, "Tuple", "") + for field in node.fields: + graph.edge(node_to_id[field], node_id) + elif isinstance(node, relay.expr.Constant): + node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) + graph.node(node_id, "Const", node_detail) + elif isinstance(node, relay.expr.Call): + op_name = unknown_type + node_details = [] + if isinstance(node.op, tvm.ir.Op): + op_name = node.op.name + if node.attrs: + node_details = [ + "{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys() + ] + elif isinstance(node.op, relay.Function): + func_attrs = node.op.attrs + op_name = "Anonymous Func" + if func_attrs: + node_details = [ + "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() + ] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + op_name = func_attrs["Composite"] + elif isinstance(node.op, relay.GlobalVar): + op_name = "GlobalVar" + node_details = [f"name_hint: {node.op.name_hint}"] + else: + op_name = str(type(node.op)).split(".")[-1].split("'")[0] + + graph.node(node_id, op_name, "\n".join(node_details)) + args = [node_to_id[arg] for arg in node.args] + for arg in args: + graph.edge(arg, node_id) + elif isinstance(node, relay.expr.TupleGetItem): + graph.node(node_id, "TupleGetItem", "idx: {}".format(node.index)) + graph.edge(node_to_id[node.tuple_value], node_id) + elif isinstance(node, tvm.ir.Op): + pass + elif isinstance(node, relay.Let): + graph.node(node_id, "Let", "") + graph.edge(node_to_id[node.value], node_id) + graph.edge(node_id, node_to_id[node.var]) + else: + unknown_info = "Unknown node: {}".format(type(node)) + _LOGGER.warning(unknown_info) + graph.node(node_id, unknown_type, unknown_info) + + + +class NodeDescriptor: + """Descriptor used by Bokeh plotter.""" + + def __init__(self, node_id, node_type, node_detail): + self._node_id = node_id + self._node_type = node_type + self._node_detail = node_detail + + @property + def node_id(self): + return self._node_id + + @property + def node_type(self): + return self._node_type + + @property + def detail(self): + return self._node_detail + + +class GraphShaper: + """Provide the bounding-box, and node location, height, width given by pygraphviz.""" + + # defined by graphviz. + _px_per_inch = 72 + + def __init__(self, pydot_graph, prog="dot", args=None): + if args is None: + args = [] + # call the graphviz program to get layout + pydot_graph_str = pydot_graph.create([prog] + args, format="dot").decode() + # remember original nodes + self._nodes = [n.get_name() for n in pydot_graph.get_nodes()] + # parse layout + pydot_graph = pydot.graph_from_dot_data(pydot_graph_str) + if len(pydot_graph) != 1: + # should be unlikely. + _LOGGER.warning( + "Got %d pydot graphs. Only the first one will be used.", len(pydot_graph) + ) + self._pydot_graph = pydot_graph[0] + + def get_nodes(self): + return self._nodes + + @functools.lru_cache() + def get_edge_path(self, start_node_id, end_node_id): + """Get explicit path points for MultiLine.""" + edge = self._pydot_graph.get_edge(str(start_node_id), str(end_node_id)) + if len(edge) != 1: + _LOGGER.warning( + "Got %d edges between %s and %s. Only the first one will be used.", + len(edge), + start_node_id, + end_node_id, + ) + edge = edge[0] + # filter out quotes and newline + pos_str = edge.get_pos().strip('"').replace("\\\n", "") + tokens = pos_str.split(" ") + s_token = None + e_token = None + ret_x_pts = [] + ret_y_pts = [] + for token in tokens: + if token.startswith("e,"): + e_token = token + elif token.startswith("s,"): + s_token = token + else: + x_str, y_str = token.split(",") + ret_x_pts.append(float(x_str)) + ret_y_pts.append(float(y_str)) + if s_token is not None: + _, x_str, y_str = s_token.split(",") + ret_x_pts.insert(0, float(x_str)) + ret_y_pts.insert(0, float(y_str)) + if e_token is not None: + _, x_str, y_str = e_token.split(",") + ret_x_pts.append(float(x_str)) + ret_y_pts.append(float(y_str)) + + return ret_x_pts, ret_y_pts + + @functools.lru_cache() + def get_node_pos(self, node_name): + pos_str = self._get_node_attr(node_name, "pos", "0,0") + return list(map(float, pos_str.split(","))) + + def get_node_height(self, node_name): + height_str = self._get_node_attr(node_name, "height", "20") + return float(height_str) * self._px_per_inch + + def get_node_width(self, node_name): + width_str = self._get_node_attr(node_name, "width", "20") + return float(width_str) * self._px_per_inch + + def _get_node_attr(self, node_name, attr_name, default_val): + + node = self._pydot_graph.get_node(str(node_name)) + if len(node) > 1: + _LOGGER.error( + "There are %d nodes with the name %s. Randomly choose one.", len(node), node_name + ) + if len(node) == 0: + _LOGGER.warning( + "%s does not exist in the graph. Use default %s for attribute %s", + node_name, + default_val, + attr_name, + ) + return default_val + + node = node[0] + try: + val = node.obj_dict["attributes"][attr_name].strip('"') + except KeyError: + _LOGGER.warning( + "%s don't exist in node %s. Use default %s", attr_name, node_name, default_val + ) + val = default_val + return val + + +class BokehGraph(Graph): + """Use Bokeh library to plot Relay IR.""" + + def __init__(self): + self._pydot_digraph = pydot.Dot(graph_type="digraph") + self._id_to_node = {} + + def node(self, node_id, node_type, node_detail): + # need string for pydot + node_id = str(node_id) + if node_id in self._id_to_node: + _LOGGER.warning("node_id %s already exists.", node_id) + return + self._pydot_digraph.add_node(pydot.Node(node_id, label=node_detail)) + self._id_to_node[node_id] = NodeDescriptor(node_id, node_type, node_detail) + + def edge(self, id_start, id_end): + # need string to pydot + id_start, id_end = str(id_start), str(id_end) + self._pydot_digraph.add_edge(pydot.Edge(id_start, id_end)) + + def render(self, plot): + + shaper = GraphShaper( + self._pydot_digraph, + prog="dot", + args=["-Grankdir=TB", "-Gsplines=ortho", "-Gfontsize=14", "-Nordering=in"], + ) + + self._create_graph(plot, shaper) + + self._add_scalable_glyph(plot, shaper) + return plot + + def _get_type_to_color_map(self): + category20 = d3["Category20"][20] + # FIXME: a problem is, for different network we have different color + # for the same type. + all_types = list({v.node_type for v in self._id_to_node.values()}) + all_types.sort() + if len(all_types) > 20: + _LOGGER.warning( + "The number of types %d is larger than 20. " + "Some colors are re-used for different types.", + len(all_types), + ) + type_to_color = {} + for idx, t in enumerate(all_types): + type_to_color[t] = category20[idx % 20] + return type_to_color + + def _create_graph(self, plot, shaper): + + # Add edge first + edges = self._pydot_digraph.get_edges() + x_path_list = [] + y_path_list = [] + for edge in edges: + id_start = edge.get_source() + id_end = edge.get_destination() + x_pts, y_pts = shaper.get_edge_path(id_start, id_end) + x_path_list.append(x_pts) + y_path_list.append(y_pts) + + multi_line_source = ColumnDataSource({"xs": x_path_list, "ys": y_path_list}) + edge_line_color = "#888888" + edge_line_width = 3 + multi_line_glyph = MultiLine(line_color=edge_line_color, line_width=edge_line_width) + plot.add_glyph(multi_line_source, multi_line_glyph) + + # Then add nodes + type_to_color = self._get_type_to_color_map() + + def cnvt_to_html(s): + return html.escape(s).replace("\n", "
") + + label_to_ids = {} + for node_id in shaper.get_nodes(): + label = self._id_to_node[node_id].node_type + if label not in label_to_ids: + label_to_ids[label] = [] + label_to_ids[label].append(node_id) + + renderers = [] + legend_itmes = [] + for label, id_list in label_to_ids.items(): + source = ColumnDataSource( + { + "x": [shaper.get_node_pos(n)[0] for n in id_list], + "y": [shaper.get_node_pos(n)[1] for n in id_list], + "width": [shaper.get_node_width(n) for n in id_list], + "height": [shaper.get_node_height(n) for n in id_list], + "node_detail": [cnvt_to_html(self._id_to_node[n].detail) for n in id_list], + "node_type": [label] * len(id_list), + } + ) + glyph = Rect(fill_color=type_to_color[label]) + renderer = plot.add_glyph(source, glyph) + # set glyph for interactivity + renderer.nonselection_glyph = Rect(fill_color=type_to_color[label]) + renderer.hover_glyph = Rect( + fill_color=type_to_color[label], line_color="firebrick", line_width=3 + ) + renderer.selection_glyph = Rect( + fill_color=type_to_color[label], line_color="firebrick", line_width=3 + ) + # Though it is called "muted_glyph", we actually use it + # to emphasize nodes in this renderer. + renderer.muted_glyph = Rect( + fill_color=type_to_color[label], line_color="firebrick", line_width=3 + ) + name = f"{self._get_graph_name(plot)}_{label}" + renderer.name = name + renderers.append(renderer) + legend_itmes.append((label, [renderer])) + + # add legend + legend = Legend( + items=legend_itmes, + title="Click to highlight", + inactive_fill_color="firebrick", + inactive_fill_alpha=0.2, + ) + legend.click_policy = "mute" + legend.location = "top_right" + plot.add_layout(legend) + + # add tooltips + tooltips = [ + ("node_type", "@node_type"), + ("description", "@node_detail{safe}"), + ] + inspect_tool = WheelZoomTool() + # only render nodes + hover_tool = HoverTool(tooltips=tooltips, renderers=renderers) + plot.add_tools(PanTool(), TapTool(), inspect_tool, hover_tool, ResetTool(), SaveTool()) + plot.toolbar.active_scroll = inspect_tool + + def _add_scalable_glyph(self, plot, shaper): + nodes = shaper.get_nodes() + + def populate_detail(n_type, n_detail): + if n_detail: + return f"{n_type}\n{n_detail}" + return n_type + + text_source = ColumnDataSource( + { + "x": [shaper.get_node_pos(n)[0] for n in nodes], + "y": [shaper.get_node_pos(n)[1] for n in nodes], + "text": [self._id_to_node[n].node_type for n in nodes], + "detail": [ + populate_detail(self._id_to_node[n].node_type, self._id_to_node[n].detail) + for n in nodes + ], + "box_w": [shaper.get_node_width(n) for n in nodes], + "box_h": [shaper.get_node_height(n) for n in nodes], + } + ) + + text_glyph = Text( + x="x", + y="y", + text="text", + text_align="center", + text_baseline="middle", + text_font_size={"value": "14px"}, + ) + node_annotation = plot.add_glyph(text_source, text_glyph) + + def get_scatter_loc(x_start, x_end, y_start, y_end, end_node): + """return x, y, angle as a tuple""" + node_x, node_y = shaper.get_node_pos(end_node) + node_w = shaper.get_node_width(end_node) + node_h = shaper.get_node_height(end_node) + + # only 4 direction + if x_end - x_start > 0: + return node_x - node_w / 2, y_end, -np.pi / 2 + if x_end - x_start < 0: + return node_x + node_w / 2, y_end, np.pi / 2 + if y_end - y_start < 0: + return x_end, node_y + node_h / 2, np.pi + return x_end, node_y - node_h / 2, 0 + + scatter_source = {"x": [], "y": [], "angle": []} + for edge in self._pydot_digraph.get_edges(): + id_start = edge.get_source() + id_end = edge.get_destination() + x_pts, y_pts = shaper.get_edge_path(id_start, id_end) + x_loc, y_loc, angle = get_scatter_loc(x_pts[-2], x_pts[-1], y_pts[-2], y_pts[-1], id_end) + scatter_source["angle"].append(angle) + scatter_source["x"].append(x_loc) + scatter_source["y"].append(y_loc) + + scatter_glyph = Scatter( + x="x", + y="y", + angle="angle", + size=5, + marker="triangle", + fill_color="#AAAAAA", + fill_alpha=0.8, + ) + edge_end_arrow = plot.add_glyph(ColumnDataSource(scatter_source), scatter_glyph) + + plot.y_range.js_on_change( + "start", + CustomJS( + args=dict( + plot=plot, + node_annotation=node_annotation, + text_source=text_source, + edge_end_arrow=edge_end_arrow, + ), + code=""" + // fontsize is in px + var fontsize = 14 + // ratio = data_point/px + var ratio = (this.end - this.start)/plot.height + var text_list = text_source.data["text"] + var detail_list = text_source.data["detail"] + var box_h_list = text_source.data["box_h"] + for(var i = 0; i < text_list.length; i++) { + var line_num = Math.floor((box_h_list[i]/ratio) / (fontsize*1.5)) + if(line_num <= 0) { + // relieve for the first line + if(Math.floor((box_h_list[i]/ratio) / (fontsize)) > 0) { + line_num = 1 + } + } + var lines = detail_list[i].split("\\n") + lines = lines.slice(0, line_num) + text_list[i] = lines.join("\\n") + } + text_source.change.emit() + + node_annotation.glyph.text_font_size = {value: `${fontsize}px`} + + var new_scatter_size = Math.round(fontsize / ratio) + edge_end_arrow.glyph.size = {value: new_scatter_size} + """, + ), + ) + + @staticmethod + def _get_graph_name(plot): + return plot.title + + +class BokehPlotter(Plotter): + """Use Bokeh library to plot Relay IR.""" + + def __init__(self): + self._name_to_graph = {} + + def create_graph(self, name): + if name in self._name_to_graph: + _LOGGER.warning("Graph name %s exists. ") + else: + self._name_to_graph[name] = BokehGraph() + return self._name_to_graph[name] + + def render(self, filename): + + if filename.endswith(".html"): + graph_name = os.path.splitext(os.path.basename(filename))[0] + else: + graph_name = filename + filename = "{}.html".format(filename) + + dom_list = [] + for name, graph in self._name_to_graph.items(): + plot = Plot( + title=name, + width=1600, + height=900, + align="center", + margin=(0, 0, 0, 70), + ) + + dom = graph.render(plot) + dom_list.append(dom) + + self._save_html(filename, column(*dom_list)) + + def _save_html(self, filename, layout_dom): + + output_file(filename, title=filename) + + template = """ + {% block postamble %} + + {% endblock %} + """ + + save(layout_dom, filename=filename, title=filename, template=template) + diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/_terminal.py new file mode 100644 index 000000000000..3a6f7ca95768 --- /dev/null +++ b/python/tvm/contrib/relay_viz/_terminal.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Visualize Relay IR in AST text-form""" + +from collections import deque + +from pyparsing import line + +from .plotter import ( + Plotter, + Graph, +) + +import tvm +from tvm import relay + +def render_cb(graph, node_to_id, relay_param): + # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm + unknown_type = "unknown" + for node, node_id in node_to_id.items(): + if isinstance(node, relay.Function): + graph.node(node_id, f"Func", str(node.params)) + graph.edge(node_to_id[node.body], node_id) + elif isinstance(node, relay.Var): + name_hint = node.name_hint + node_detail = "" + node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" + if node.type_annotation is not None: + if hasattr(node.type_annotation, "shape"): + shape = tuple(map(int, node.type_annotation.shape)) + dtype = node.type_annotation.dtype + node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format( + name_hint, shape, dtype + ) + else: + node_detail = "name_hint: {}\ntype_annotation: {}".format( + name_hint, node.type_annotation + ) + graph.node(node_id, node_type, node_detail) + elif isinstance(node, relay.GlobalVar): + # Dont render this because GlobalVar is put to another graph. + pass + elif isinstance(node, relay.Tuple): + graph.node(node_id, "Tuple", "") + for field in node.fields: + graph.edge(node_to_id[field], node_id) + elif isinstance(node, relay.expr.Constant): + node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) + graph.node(node_id, "Const", str(node)) + elif isinstance(node, relay.expr.Call): + op_name = unknown_type + node_details = [] + if isinstance(node.op, tvm.ir.Op): + op_name = node.op.name + if node.attrs: + node_details = [ + "{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys() + ] + elif isinstance(node.op, relay.Function): + func_attrs = node.op.attrs + op_name = "Anonymous Func" + if func_attrs: + node_details = [ + "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() + ] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + op_name = func_attrs["Composite"] + elif isinstance(node.op, relay.GlobalVar): + op_name = "GlobalVar" + node_details = [f"GlobalVar.name_hint: {node.op.name_hint}"] + else: + op_name = str(type(node.op)).split(".")[-1].split("'")[0] + + graph.node(node_id, f"Call {op_name}", "\n".join(node_details)) + args = [node_to_id[arg] for arg in node.args] + for arg in args: + graph.edge(arg, node_id) + elif isinstance(node, relay.expr.TupleGetItem): + graph.node(node_id, "TupleGetItem", "idx: {}".format(node.index)) + graph.edge(node_to_id[node.tuple_value], node_id) + elif isinstance(node, tvm.ir.Op): + pass + elif isinstance(node, relay.Let): + graph.node(node_id, "Let", "") + graph.edge(node_to_id[node.value], node_id) + graph.edge(node_id, node_to_id[node.var]) + else: + unknown_info = "Unknown node: {}".format(type(node)) + graph.node(node_id, unknown_type, unknown_info) + + +class Node: + def __init__(self, node_type, other_info): + self.type = node_type + self.other_info = other_info.replace("\n", ", ") + + +class TermGraph(Graph): + + def __init__(self, name): + # node_id: [ connected node_id] + self._name = name + self._graph = {} + self._id_to_node = {} + # reversed post order + self._node_ids_rpo = deque() + + def node(self, node_id, node_type, node_detail): + # actually we just need the last one. + self._node_ids_rpo.appendleft(node_id) + + if node_id not in self._graph: + self._graph[node_id] = [] + + node = Node(node_type, node_detail) + self._id_to_node[node_id] = node + + def edge(self, id_start, id_end): + # want reserved post-order + if id_end in self._graph: + self._graph[id_end].append(id_start) + else: + self._graph[id_end] = [id_start] + + def render(self): + + lines = [] + + def gen_line(indent, n_id): + conn_symbol = "|--" + last_idx = len(lines) + len(self._graph[n_id]) - 1 + for next_n_id in self._graph[n_id]: + node = self._id_to_node[next_n_id] + lines.append(f"{indent}{conn_symbol}{node.type} {node.other_info}") + gen_line(f" {indent}", next_n_id) + if len(self._graph[n_id]): + lines[last_idx] = lines[last_idx].replace("|", "`") + + first_node_id = self._node_ids_rpo[0] + node = self._id_to_node[first_node_id] + lines.append(f"@{self._name}({node.other_info})") + gen_line(" ", first_node_id) + + return "\n".join(lines) + + +class TermPlotter(Plotter): + + def __init__(self): + self._name_to_graph = {} + + def create_graph(self, name): + self._name_to_graph[name] = TermGraph(name) + return self._name_to_graph[name] + + def render(self, filename): + # if filename == "stdio", print to terminal. + # Otherwise, print to the file? + lines = [] + for name in self._name_to_graph: + lines.append(self._name_to_graph[name].render()) + print("\n".join(lines)) diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py new file mode 100644 index 000000000000..645860525be3 --- /dev/null +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Abstract class for plotters.""" +import abc + + +class Graph(abc.ABC): + """Abstract class for graph. + + Implement this interface for various graph libraries. + """ + + @abc.abstractmethod + def node(self, node_id, node_type, node_detail): + """Add a node to the underlying graph. + + Parameters + ---------- + node_id : object + Serve as the ID to the node. + + node_type : string + the type of the node. + + node_detail : string + the description of the node. + """ + + @abc.abstractmethod + def edge(self, id_start, id_end): + """Add an edge to the underlying graph. + + Parameters + ---------- + id_start : object + the ID to the starting node. + + id_start : object + the ID to the ending node. + """ + +class Plotter(abc.ABC): + """Abstract class for plotters. + + Implement this interface for various graph libraries. + """ + + @abc.abstractclassmethod + def create_graph(name): + """Create a graph + + Parameters + ---------- + name : string, the name of the graph + + Return + ------ + Graph instance. + """ + + @abc.abstractmethod + def render(self, filename): + """Render the graph as a file. + + Parameters + ---------- + filename : string + """ From 54bb86c3a13136cbc36b03af0edae297a2a90ba3 Mon Sep 17 00:00:00 2001 From: kueitang Date: Tue, 31 Aug 2021 03:07:13 +0800 Subject: [PATCH 04/16] Integrated Version of relay Vizualizer * Add the TERMINAL backend * Integrated two backends * Add a RenderCallback class for reusable nodes type processing --- python/tvm/contrib/relay_viz/__init__.py | 56 ++++++-- python/tvm/contrib/relay_viz/_bokeh.py | 96 +------------ python/tvm/contrib/relay_viz/_terminal.py | 135 +++++++----------- python/tvm/contrib/relay_viz/plotter.py | 2 +- .../tvm/contrib/relay_viz/render_callback.py | 126 ++++++++++++++++ 5 files changed, 228 insertions(+), 187 deletions(-) create mode 100644 python/tvm/contrib/relay_viz/render_callback.py diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index eaaf86441f2e..1a2547d67b54 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -15,11 +15,14 @@ # specific language governing permissions and limitations # under the License. """Relay IR Visualizer""" +import logging import copy from tvm import relay +from enum import Enum +_LOGGER = logging.getLogger(__name__) -class PlotterBackend: +class PlotterBackend(Enum): """Enumeration for available plotters.""" BOKEH = "bokeh" @@ -44,10 +47,11 @@ def __init__( The backend of plotting. Default "bokeh" """ - self._plotter, self._render_cb = get_plotter_and_render_cb(backend) + self._plotter, self.render_rules = get_plotter_and_render_cb(backend) self._relay_param = relay_param if relay_param is not None else {} # This field is used for book-keeping for each graph. self._node_to_id = {} + # self.test_type = set() global_vars = relay_mod.get_global_vars() graph_names = [] @@ -75,18 +79,50 @@ def _traverse_expr(self, node): return self._node_to_id[node] = len(self._node_to_id) + def _render_cb(self, graph, node_to_id, relay_param): + """a callback to Add nodes and edges to the graph. + + Parameters + ---------- + graph : class plotter.Graph + + node_to_id : Dict[relay.expr, int] + + relay_param : Dict[string, NDarray] + """ + # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm + unknown_type = "unknown" + for node, node_id in node_to_id.items(): + # self.test_type.add(type(node)) + if type(node) in self.render_rules: + graph_info, edge_info = self.render_rules[type(node)](node, relay_param, node_to_id) + if graph_info: + graph.node(*graph_info) + while edge_info: + this_edge = edge_info.pop(0) + graph.edge(*this_edge) + else: + unknown_info = "Unknown node: {}".format(type(node)) + _LOGGER.warning(unknown_info) + graph.node(node_id, unknown_type, unknown_info) + def render(self, filename): return self._plotter.render(filename=filename) def get_plotter_and_render_cb(backend): - if backend == PlotterBackend.BOKEH: - from ._bokeh import BokehPlotter, relay_render_cb # pylint: disable=import-outside-toplevel - - return BokehPlotter(), relay_render_cb - if backend == PlotterBackend.TERMINAL: - from ._terminal import TermPlotter, render_cb - - return TermPlotter(), render_cb + if backend in PlotterBackend: + if backend == PlotterBackend.BOKEH: + from ._bokeh import BokehPlotter, BokehRenderCallback # pylint: disable=import-outside-toplevel + plotter = BokehPlotter() + render = BokehRenderCallback() + + elif backend == PlotterBackend.TERMINAL: + from ._terminal import TermPlotter, TermRenderCallback + plotter = TermPlotter() + render = TermRenderCallback() + + render_rules = render.get_rules() + return plotter, render_rules raise ValueError("Unknown plotter backend {}".format(backend)) diff --git a/python/tvm/contrib/relay_viz/_bokeh.py b/python/tvm/contrib/relay_viz/_bokeh.py index afbda2d47c32..e60b8f3d27e6 100644 --- a/python/tvm/contrib/relay_viz/_bokeh.py +++ b/python/tvm/contrib/relay_viz/_bokeh.py @@ -55,101 +55,11 @@ _LOGGER = logging.getLogger(__name__) - -def relay_render_cb(graph, node_to_id, relay_param): - """a callback to Add nodes and edges to the graph. - - Parameters - ---------- - graph : class plotter.Graph - - node_to_id : Dict[relay.expr, int] - - relay_param : Dict[string, NDarray] - """ - # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm - unknown_type = "unknown" - for node, node_id in node_to_id.items(): - if isinstance(node, relay.Function): - node_details = [] - func_attrs = node.attrs - if func_attrs: - node_details = [ - "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() - ] - - graph.node(node_id, f"Func", "\n".join(node_details)) - graph.edge(node_to_id[node.body], node_id) - elif isinstance(node, relay.Var): - name_hint = node.name_hint - node_detail = "" - node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" - if node.type_annotation is not None: - if hasattr(node.type_annotation, "shape"): - shape = tuple(map(int, node.type_annotation.shape)) - dtype = node.type_annotation.dtype - node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format( - name_hint, shape, dtype - ) - else: - node_detail = "name_hint: {}\ntype_annotation: {}".format( - name_hint, node.type_annotation - ) - graph.node(node_id, node_type, node_detail) - elif isinstance(node, relay.GlobalVar): - # Dont render this because GlobalVar is put to another graph. - pass - elif isinstance(node, relay.Tuple): - graph.node(node_id, "Tuple", "") - for field in node.fields: - graph.edge(node_to_id[field], node_id) - elif isinstance(node, relay.expr.Constant): - node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) - graph.node(node_id, "Const", node_detail) - elif isinstance(node, relay.expr.Call): - op_name = unknown_type - node_details = [] - if isinstance(node.op, tvm.ir.Op): - op_name = node.op.name - if node.attrs: - node_details = [ - "{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys() - ] - elif isinstance(node.op, relay.Function): - func_attrs = node.op.attrs - op_name = "Anonymous Func" - if func_attrs: - node_details = [ - "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() - ] - # "Composite" might from relay.transform.MergeComposite - if "Composite" in func_attrs.keys(): - op_name = func_attrs["Composite"] - elif isinstance(node.op, relay.GlobalVar): - op_name = "GlobalVar" - node_details = [f"name_hint: {node.op.name_hint}"] - else: - op_name = str(type(node.op)).split(".")[-1].split("'")[0] - - graph.node(node_id, op_name, "\n".join(node_details)) - args = [node_to_id[arg] for arg in node.args] - for arg in args: - graph.edge(arg, node_id) - elif isinstance(node, relay.expr.TupleGetItem): - graph.node(node_id, "TupleGetItem", "idx: {}".format(node.index)) - graph.edge(node_to_id[node.tuple_value], node_id) - elif isinstance(node, tvm.ir.Op): - pass - elif isinstance(node, relay.Let): - graph.node(node_id, "Let", "") - graph.edge(node_to_id[node.value], node_id) - graph.edge(node_id, node_to_id[node.var]) - else: - unknown_info = "Unknown node: {}".format(type(node)) - _LOGGER.warning(unknown_info) - graph.node(node_id, unknown_type, unknown_info) +from .render_callback import RenderCallback +class BokehRenderCallback(RenderCallback): + pass class NodeDescriptor: """Descriptor used by Bokeh plotter.""" diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/_terminal.py index 3a6f7ca95768..2e33f5e900d2 100644 --- a/python/tvm/contrib/relay_viz/_terminal.py +++ b/python/tvm/contrib/relay_viz/_terminal.py @@ -24,84 +24,53 @@ Plotter, Graph, ) +import functools import tvm from tvm import relay +from .render_callback import RenderCallback -def render_cb(graph, node_to_id, relay_param): - # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm - unknown_type = "unknown" - for node, node_id in node_to_id.items(): - if isinstance(node, relay.Function): - graph.node(node_id, f"Func", str(node.params)) - graph.edge(node_to_id[node.body], node_id) - elif isinstance(node, relay.Var): - name_hint = node.name_hint - node_detail = "" - node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" - if node.type_annotation is not None: - if hasattr(node.type_annotation, "shape"): - shape = tuple(map(int, node.type_annotation.shape)) - dtype = node.type_annotation.dtype - node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format( - name_hint, shape, dtype - ) - else: - node_detail = "name_hint: {}\ntype_annotation: {}".format( - name_hint, node.type_annotation - ) - graph.node(node_id, node_type, node_detail) - elif isinstance(node, relay.GlobalVar): - # Dont render this because GlobalVar is put to another graph. - pass - elif isinstance(node, relay.Tuple): - graph.node(node_id, "Tuple", "") - for field in node.fields: - graph.edge(node_to_id[field], node_id) - elif isinstance(node, relay.expr.Constant): - node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) - graph.node(node_id, "Const", str(node)) - elif isinstance(node, relay.expr.Call): - op_name = unknown_type - node_details = [] - if isinstance(node.op, tvm.ir.Op): - op_name = node.op.name - if node.attrs: - node_details = [ - "{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys() - ] - elif isinstance(node.op, relay.Function): - func_attrs = node.op.attrs - op_name = "Anonymous Func" - if func_attrs: - node_details = [ - "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() - ] - # "Composite" might from relay.transform.MergeComposite - if "Composite" in func_attrs.keys(): - op_name = func_attrs["Composite"] - elif isinstance(node.op, relay.GlobalVar): - op_name = "GlobalVar" - node_details = [f"GlobalVar.name_hint: {node.op.name_hint}"] - else: - op_name = str(type(node.op)).split(".")[-1].split("'")[0] - - graph.node(node_id, f"Call {op_name}", "\n".join(node_details)) - args = [node_to_id[arg] for arg in node.args] - for arg in args: - graph.edge(arg, node_id) - elif isinstance(node, relay.expr.TupleGetItem): - graph.node(node_id, "TupleGetItem", "idx: {}".format(node.index)) - graph.edge(node_to_id[node.tuple_value], node_id) - elif isinstance(node, tvm.ir.Op): - pass - elif isinstance(node, relay.Let): - graph.node(node_id, "Let", "") - graph.edge(node_to_id[node.value], node_id) - graph.edge(node_id, node_to_id[node.var]) - else: - unknown_info = "Unknown node: {}".format(type(node)) - graph.node(node_id, unknown_type, unknown_info) +class TermRenderCallback(RenderCallback): + def __init__(self): + super().__init__() + + def Call_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + graph_info = [node_id, f"Call", ""] + edge_info = [[node_to_id[node.op], node_id]] + args = [node_to_id[arg] for arg in node.args] + for arg in args: + edge_info.append([arg, node_id]) + return graph_info, edge_info + + def Let_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + graph_info = [node_id, "Let", "(var, val, body)"] + edge_info = [[node_to_id[node.var], node_id]] + edge_info.append([node_to_id[node.value], node_id]) + edge_info.append([node_to_id[node.body], node_id]) + return graph_info, edge_info + + def Global_var_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + graph_info = [node_id, "GlobalVar", node.name_hint] + edge_info = [] + return graph_info, edge_info + + def If_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + graph_info = [node_id, "If", "(cond, true, false)"] + edge_info = [[node_to_id[node.cond], node_id]] + edge_info.append([node_to_id[node.true_branch], node_id]) + edge_info.append([node_to_id[node.false_branch], node_id]) + return graph_info, edge_info + + def Op_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + op_name = node.name + graph_info = [node_id, op_name, ""] + edge_info = [] + return graph_info, edge_info class Node: @@ -138,23 +107,23 @@ def edge(self, id_start, id_end): self._graph[id_end] = [id_start] def render(self): - lines = [] + @functools.lru_cache() def gen_line(indent, n_id): - conn_symbol = "|--" - last_idx = len(lines) + len(self._graph[n_id]) - 1 - for next_n_id in self._graph[n_id]: + conn_symbol = ["|--", "`--"] + last = len(self._graph[n_id]) - 1 + for i, next_n_id in enumerate(self._graph[n_id]): node = self._id_to_node[next_n_id] - lines.append(f"{indent}{conn_symbol}{node.type} {node.other_info}") - gen_line(f" {indent}", next_n_id) - if len(self._graph[n_id]): - lines[last_idx] = lines[last_idx].replace("|", "`") + lines.append(f"{indent}{conn_symbol[i==last]}{node.type} {node.other_info}") + next_indent = indent + next_indent += " " if (i == last) else "| " + gen_line(next_indent, next_n_id) first_node_id = self._node_ids_rpo[0] node = self._id_to_node[first_node_id] lines.append(f"@{self._name}({node.other_info})") - gen_line(" ", first_node_id) + gen_line("", first_node_id) return "\n".join(lines) diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py index 645860525be3..25df4d80c9ba 100644 --- a/python/tvm/contrib/relay_viz/plotter.py +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -49,7 +49,7 @@ def edge(self, id_start, id_end): id_start : object the ID to the starting node. - id_start : object + id_end : object the ID to the ending node. """ diff --git a/python/tvm/contrib/relay_viz/render_callback.py b/python/tvm/contrib/relay_viz/render_callback.py new file mode 100644 index 000000000000..e86464144be1 --- /dev/null +++ b/python/tvm/contrib/relay_viz/render_callback.py @@ -0,0 +1,126 @@ + +import tvm +from tvm import relay + +unknown_type = "unknown" + +class RenderCallback(): + '''This is the default callback rules, which is also the _bokeh backend drawing way''' + def __init__(self): + self.render_rules = {} + self.build_rules() + + def Var_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + name_hint = node.name_hint + node_detail = "" + node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" + if node.type_annotation is not None: + if hasattr(node.type_annotation, "shape"): + shape = tuple(map(int, node.type_annotation.shape)) + dtype = node.type_annotation.dtype + node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format( + name_hint, shape, dtype + ) + else: + node_detail = "name_hint: {}\ntype_annotation: {}".format( + name_hint, node.type_annotation + ) + graph_info = [node_id, node_type, node_detail] + edge_info = [] + return graph_info, edge_info + + def Function_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + graph_info = [node_id, f"Func", str(node.params)] + edge_info = [[node_to_id[node.body], node_id]] + return graph_info, edge_info + + def Call_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + op_name = unknown_type + node_details = [] + if isinstance(node.op, tvm.ir.Op): + op_name = node.op.name + if node.attrs: + node_details = [ + "{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys() + ] + elif isinstance(node.op, relay.Function): + func_attrs = node.op.attrs + op_name = "Anonymous Func" + if func_attrs: + node_details = [ + "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() + ] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + op_name = func_attrs["Composite"] + elif isinstance(node.op, relay.GlobalVar): + op_name = "GlobalVar" + node_details = [f"GlobalVar.name_hint: {node.op.name_hint}"] + else: + op_name = str(type(node.op)).split(".")[-1].split("'")[0] + + graph_info = [node_id, f"Call {op_name}", "\n".join(node_details)] + args = [node_to_id[arg] for arg in node.args] + edge_info = [[arg, node_id] for arg in args] + return graph_info, edge_info + + def Let_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + graph_info = [node_id, "Let", ""] + edge_info = [[node_to_id[node.value], node_id]] + edge_info.append([node_id, node_to_id[node.var]]) + return graph_info, edge_info + + def Global_var_node(self, node, relay_param, node_to_id): + graph_info = [] + edge_info = [] + return graph_info, edge_info + + def If_node(self, node, relay_param, node_to_id): + graph_info = [] + edge_info = [] + return graph_info, edge_info + + def Tuple_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + graph_info = [node_id, "Tuple", ""] + edge_info = [[node_to_id[field], node_id] for field in node.fields] + return graph_info, edge_info + + def TupleGetItem_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + graph_info = [node_id, "TupleGetItem", "idx: {}".format(node.index)] + edge_info = [[node_to_id[node.tuple_value], node_id]] + return graph_info, edge_info + + def Constant_node(self, node, relay_param, node_to_id): + node_id = node_to_id[node] + node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) + graph_info = [node_id, "Const", str(node)] + edge_info = [] + return graph_info, edge_info + + def Op_node(self, node, relay_param, node_to_id): + graph_info = [] + edge_info = [] + return graph_info, edge_info + + def build_rules(self): + self.render_rules = { + tvm.relay.Function : self.Function_node, + tvm.relay.expr.Call : self.Call_node, + tvm.relay.expr.Let : self.Let_node, + tvm.relay.expr.Var : self.Var_node, + tvm.relay.expr.GlobalVar : self.Global_var_node, + tvm.relay.expr.If : self.If_node, + tvm.relay.expr.Tuple : self.Tuple_node, + tvm.relay.expr.TupleGetItem : self.TupleGetItem_node, + tvm.relay.expr.Constant : self.Constant_node, + tvm.ir.Op : self.Op_node, + } + + def get_rules(self): + return self.render_rules From d9ec581c0a506780146e7acd2526a6550d69e055 Mon Sep 17 00:00:00 2001 From: kueitang Date: Mon, 13 Sep 2021 20:10:52 +0800 Subject: [PATCH 05/16] - Make users can defined their own backend from input - fix some typos --- python/tvm/contrib/relay_viz/README.md | 1 - python/tvm/contrib/relay_viz/__init__.py | 48 ++++++++++++++++-------- python/tvm/contrib/relay_viz/plotter.py | 3 +- 3 files changed, 35 insertions(+), 17 deletions(-) diff --git a/python/tvm/contrib/relay_viz/README.md b/python/tvm/contrib/relay_viz/README.md index 5f5eb135fe80..5a189f85cbb3 100644 --- a/python/tvm/contrib/relay_viz/README.md +++ b/python/tvm/contrib/relay_viz/README.md @@ -24,7 +24,6 @@ This tool target to visualize Relay IR. 1. [Requirement](#Requirement) 2. [Usage](#Usage) 3. [Credits](#Credits) -3. [TODO](#TODO) ## Requirement diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index 1a2547d67b54..24cac90a69c3 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -19,9 +19,13 @@ import copy from tvm import relay from enum import Enum +from .plotter import Plotter +from .render_callback import RenderCallback + _LOGGER = logging.getLogger(__name__) + class PlotterBackend(Enum): """Enumeration for available plotters.""" @@ -32,9 +36,7 @@ class PlotterBackend(Enum): class RelayVisualizer: """Relay IR Visualizer""" - def __init__( - self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH - ): + def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH): """Visualize Relay IR. Parameters @@ -43,15 +45,15 @@ def __init__( Relay IR module relay_param: dict Relay parameter dictionary - backend: PlotterBackend. - The backend of plotting. Default "bokeh" + backend: PlotterBackend or a tuple + PlotterBackend: The backend of plotting. Default "bokeh" + tuple: A tuple with two arguments. First is user-defined Plotter, the second is user-defined RenderCallback """ - self._plotter, self.render_rules = get_plotter_and_render_cb(backend) + self._plotter, self._render_rules = get_plotter_and_render_rules(backend) self._relay_param = relay_param if relay_param is not None else {} # This field is used for book-keeping for each graph. self._node_to_id = {} - # self.test_type = set() global_vars = relay_mod.get_global_vars() graph_names = [] @@ -93,14 +95,14 @@ def _render_cb(self, graph, node_to_id, relay_param): # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm unknown_type = "unknown" for node, node_id in node_to_id.items(): - # self.test_type.add(type(node)) - if type(node) in self.render_rules: - graph_info, edge_info = self.render_rules[type(node)](node, relay_param, node_to_id) + if type(node) in self._render_rules: + graph_info, edge_info = self._render_rules[type(node)]( + node, relay_param, node_to_id + ) if graph_info: graph.node(*graph_info) - while edge_info: - this_edge = edge_info.pop(0) - graph.edge(*this_edge) + for edge in edge_info: + graph.edge(*edge) else: unknown_info = "Unknown node: {}".format(type(node)) _LOGGER.warning(unknown_info) @@ -110,15 +112,31 @@ def render(self, filename): return self._plotter.render(filename=filename) -def get_plotter_and_render_cb(backend): +def get_plotter_and_render_rules(backend): + if type(backend) is tuple and len(backend) == 2: + if not isinstance(backend[0], Plotter): + raise ValueError("First elemnet of the backend should be a plotter") + plotter = backend[0] + if not isinstance(backend[1], RenderCallback): + raise ValueError("Second elemnet of the backend should be a callback") + render = backend[1] + render_rules = render.get_rules() + return plotter, render_rules + if backend in PlotterBackend: if backend == PlotterBackend.BOKEH: - from ._bokeh import BokehPlotter, BokehRenderCallback # pylint: disable=import-outside-toplevel + from ._bokeh import ( + BokehPlotter, + BokehRenderCallback, + ) # pylint: disable=import-outside-toplevel + plotter = BokehPlotter() render = BokehRenderCallback() elif backend == PlotterBackend.TERMINAL: from ._terminal import TermPlotter, TermRenderCallback + + print(isinstance(TermPlotter(), Plotter)) plotter = TermPlotter() render = TermRenderCallback() diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py index 25df4d80c9ba..013df7aa05a4 100644 --- a/python/tvm/contrib/relay_viz/plotter.py +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -53,13 +53,14 @@ def edge(self, id_start, id_end): the ID to the ending node. """ + class Plotter(abc.ABC): """Abstract class for plotters. Implement this interface for various graph libraries. """ - @abc.abstractclassmethod + @abc.abstractmethod def create_graph(name): """Create a graph From ab3c56c61d2e5a5983a7feee0d1a6fd35b0defe1 Mon Sep 17 00:00:00 2001 From: kueitang Date: Tue, 14 Sep 2021 09:17:55 +0800 Subject: [PATCH 06/16] - Fix format typos --- python/tvm/contrib/relay_viz/_bokeh.py | 6 ++- python/tvm/contrib/relay_viz/_terminal.py | 5 +- .../tvm/contrib/relay_viz/render_callback.py | 52 ++++++++++++------- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/python/tvm/contrib/relay_viz/_bokeh.py b/python/tvm/contrib/relay_viz/_bokeh.py index e60b8f3d27e6..c208da3494dc 100644 --- a/python/tvm/contrib/relay_viz/_bokeh.py +++ b/python/tvm/contrib/relay_viz/_bokeh.py @@ -61,6 +61,7 @@ class BokehRenderCallback(RenderCallback): pass + class NodeDescriptor: """Descriptor used by Bokeh plotter.""" @@ -376,7 +377,9 @@ def get_scatter_loc(x_start, x_end, y_start, y_end, end_node): id_start = edge.get_source() id_end = edge.get_destination() x_pts, y_pts = shaper.get_edge_path(id_start, id_end) - x_loc, y_loc, angle = get_scatter_loc(x_pts[-2], x_pts[-1], y_pts[-2], y_pts[-1], id_end) + x_loc, y_loc, angle = get_scatter_loc( + x_pts[-2], x_pts[-1], y_pts[-2], y_pts[-1], id_end + ) scatter_source["angle"].append(angle) scatter_source["x"].append(x_loc) scatter_source["y"].append(y_loc) @@ -487,4 +490,3 @@ def _save_html(self, filename, layout_dom): """ save(layout_dom, filename=filename, title=filename, template=template) - diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/_terminal.py index 2e33f5e900d2..2fc7a16862df 100644 --- a/python/tvm/contrib/relay_viz/_terminal.py +++ b/python/tvm/contrib/relay_viz/_terminal.py @@ -30,6 +30,7 @@ from tvm import relay from .render_callback import RenderCallback + class TermRenderCallback(RenderCallback): def __init__(self): super().__init__() @@ -54,7 +55,7 @@ def Let_node(self, node, relay_param, node_to_id): def Global_var_node(self, node, relay_param, node_to_id): node_id = node_to_id[node] graph_info = [node_id, "GlobalVar", node.name_hint] - edge_info = [] + edge_info = [] return graph_info, edge_info def If_node(self, node, relay_param, node_to_id): @@ -80,7 +81,6 @@ def __init__(self, node_type, other_info): class TermGraph(Graph): - def __init__(self, name): # node_id: [ connected node_id] self._name = name @@ -129,7 +129,6 @@ def gen_line(indent, n_id): class TermPlotter(Plotter): - def __init__(self): self._name_to_graph = {} diff --git a/python/tvm/contrib/relay_viz/render_callback.py b/python/tvm/contrib/relay_viz/render_callback.py index e86464144be1..58b4cd04304c 100644 --- a/python/tvm/contrib/relay_viz/render_callback.py +++ b/python/tvm/contrib/relay_viz/render_callback.py @@ -1,11 +1,29 @@ - +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Default render callback rules""" import tvm from tvm import relay unknown_type = "unknown" -class RenderCallback(): - '''This is the default callback rules, which is also the _bokeh backend drawing way''' + +class RenderCallback: + """This is the default callback rules, which is also the _bokeh backend drawing way""" + def __init__(self): self.render_rules = {} self.build_rules() @@ -19,9 +37,7 @@ def Var_node(self, node, relay_param, node_to_id): if hasattr(node.type_annotation, "shape"): shape = tuple(map(int, node.type_annotation.shape)) dtype = node.type_annotation.dtype - node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format( - name_hint, shape, dtype - ) + node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format(name_hint, shape, dtype) else: node_detail = "name_hint: {}\ntype_annotation: {}".format( name_hint, node.type_annotation @@ -76,12 +92,12 @@ def Let_node(self, node, relay_param, node_to_id): def Global_var_node(self, node, relay_param, node_to_id): graph_info = [] - edge_info = [] + edge_info = [] return graph_info, edge_info def If_node(self, node, relay_param, node_to_id): graph_info = [] - edge_info = [] + edge_info = [] return graph_info, edge_info def Tuple_node(self, node, relay_param, node_to_id): @@ -110,16 +126,16 @@ def Op_node(self, node, relay_param, node_to_id): def build_rules(self): self.render_rules = { - tvm.relay.Function : self.Function_node, - tvm.relay.expr.Call : self.Call_node, - tvm.relay.expr.Let : self.Let_node, - tvm.relay.expr.Var : self.Var_node, - tvm.relay.expr.GlobalVar : self.Global_var_node, - tvm.relay.expr.If : self.If_node, - tvm.relay.expr.Tuple : self.Tuple_node, - tvm.relay.expr.TupleGetItem : self.TupleGetItem_node, - tvm.relay.expr.Constant : self.Constant_node, - tvm.ir.Op : self.Op_node, + tvm.relay.Function: self.Function_node, + tvm.relay.expr.Call: self.Call_node, + tvm.relay.expr.Let: self.Let_node, + tvm.relay.expr.Var: self.Var_node, + tvm.relay.expr.GlobalVar: self.Global_var_node, + tvm.relay.expr.If: self.If_node, + tvm.relay.expr.Tuple: self.Tuple_node, + tvm.relay.expr.TupleGetItem: self.TupleGetItem_node, + tvm.relay.expr.Constant: self.Constant_node, + tvm.ir.Op: self.Op_node, } def get_rules(self): From a7a44648e5b23c58f8bfda8afcde97baa677d4ec Mon Sep 17 00:00:00 2001 From: kueitang Date: Tue, 14 Sep 2021 14:12:10 +0800 Subject: [PATCH 07/16] Fix more format checks --- python/tvm/contrib/relay_viz/__init__.py | 28 ++++++-- python/tvm/contrib/relay_viz/_bokeh.py | 14 +--- python/tvm/contrib/relay_viz/_terminal.py | 28 ++++---- python/tvm/contrib/relay_viz/plotter.py | 2 +- .../tvm/contrib/relay_viz/render_callback.py | 64 +++++++++---------- 5 files changed, 70 insertions(+), 66 deletions(-) diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index 24cac90a69c3..cfb48390b100 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -17,8 +17,8 @@ """Relay IR Visualizer""" import logging import copy -from tvm import relay from enum import Enum +from tvm import relay from .plotter import Plotter from .render_callback import RenderCallback @@ -47,7 +47,8 @@ def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH): Relay parameter dictionary backend: PlotterBackend or a tuple PlotterBackend: The backend of plotting. Default "bokeh" - tuple: A tuple with two arguments. First is user-defined Plotter, the second is user-defined RenderCallback + Tuple: A tuple with two arguments. First is user-defined Plotter, \ + the second is user-defined RenderCallback """ self._plotter, self._render_rules = get_plotter_and_render_rules(backend) @@ -95,7 +96,7 @@ def _render_cb(self, graph, node_to_id, relay_param): # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm unknown_type = "unknown" for node, node_id in node_to_id.items(): - if type(node) in self._render_rules: + if type(node) in self._render_rules: # pylint: disable=unidiomatic-typecheck graph_info, edge_info = self._render_rules[type(node)]( node, relay_param, node_to_id ) @@ -113,7 +114,16 @@ def render(self, filename): def get_plotter_and_render_rules(backend): - if type(backend) is tuple and len(backend) == 2: + """Specify the Plottor and its render rules + + Parameters + ---------- + backend: PlotterBackend or a tuple + PlotterBackend: The backend of plotting. Default "bokeh" + Tuple: A tuple with two arguments. First is user-defined Plotter, \ + the second is user-defined RenderCallback + """ + if type(backend) is tuple and len(backend) == 2: # pylint: disable=unidiomatic-typecheck if not isinstance(backend[0], Plotter): raise ValueError("First elemnet of the backend should be a plotter") plotter = backend[0] @@ -125,18 +135,22 @@ def get_plotter_and_render_rules(backend): if backend in PlotterBackend: if backend == PlotterBackend.BOKEH: + # pylint: disable=import-outside-toplevel from ._bokeh import ( BokehPlotter, BokehRenderCallback, - ) # pylint: disable=import-outside-toplevel + ) plotter = BokehPlotter() render = BokehRenderCallback() elif backend == PlotterBackend.TERMINAL: - from ._terminal import TermPlotter, TermRenderCallback + # pylint: disable=import-outside-toplevel + from ._terminal import ( + TermPlotter, + TermRenderCallback, + ) - print(isinstance(TermPlotter(), Plotter)) plotter = TermPlotter() render = TermRenderCallback() diff --git a/python/tvm/contrib/relay_viz/_bokeh.py b/python/tvm/contrib/relay_viz/_bokeh.py index c208da3494dc..0cff967de860 100644 --- a/python/tvm/contrib/relay_viz/_bokeh.py +++ b/python/tvm/contrib/relay_viz/_bokeh.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Bokeh backend for Relay IR Visualizer.""" -import os import html import logging import functools @@ -50,13 +49,10 @@ Graph, ) -import tvm -from tvm import relay +from .render_callback import RenderCallback # pylint: disable=import-outside-toplevel _LOGGER = logging.getLogger(__name__) -from .render_callback import RenderCallback - class BokehRenderCallback(RenderCallback): pass @@ -209,7 +205,7 @@ def edge(self, id_start, id_end): self._pydot_digraph.add_edge(pydot.Edge(id_start, id_end)) def render(self, plot): - + """To draw a Bokeh Graph""" shaper = GraphShaper( self._pydot_digraph, prog="dot", @@ -453,11 +449,7 @@ def create_graph(self, name): return self._name_to_graph[name] def render(self, filename): - - if filename.endswith(".html"): - graph_name = os.path.splitext(os.path.basename(filename))[0] - else: - graph_name = filename + if not filename.endswith(".html"): filename = "{}.html".format(filename) dom_list = [] diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/_terminal.py index 2fc7a16862df..2732ee4c54e0 100644 --- a/python/tvm/contrib/relay_viz/_terminal.py +++ b/python/tvm/contrib/relay_viz/_terminal.py @@ -17,34 +17,29 @@ """Visualize Relay IR in AST text-form""" from collections import deque - -from pyparsing import line +import functools from .plotter import ( Plotter, Graph, ) -import functools -import tvm -from tvm import relay from .render_callback import RenderCallback class TermRenderCallback(RenderCallback): - def __init__(self): - super().__init__() + """Terminal render callback""" - def Call_node(self, node, relay_param, node_to_id): + def call_node(self, node, relay_param, node_to_id): node_id = node_to_id[node] - graph_info = [node_id, f"Call", ""] + graph_info = [node_id, "Call", ""] edge_info = [[node_to_id[node.op], node_id]] args = [node_to_id[arg] for arg in node.args] for arg in args: edge_info.append([arg, node_id]) return graph_info, edge_info - def Let_node(self, node, relay_param, node_to_id): + def let_node(self, node, relay_param, node_to_id): node_id = node_to_id[node] graph_info = [node_id, "Let", "(var, val, body)"] edge_info = [[node_to_id[node.var], node_id]] @@ -52,13 +47,13 @@ def Let_node(self, node, relay_param, node_to_id): edge_info.append([node_to_id[node.body], node_id]) return graph_info, edge_info - def Global_var_node(self, node, relay_param, node_to_id): + def global_var_node(self, node, relay_param, node_to_id): node_id = node_to_id[node] graph_info = [node_id, "GlobalVar", node.name_hint] edge_info = [] return graph_info, edge_info - def If_node(self, node, relay_param, node_to_id): + def if_node(self, node, relay_param, node_to_id): node_id = node_to_id[node] graph_info = [node_id, "If", "(cond, true, false)"] edge_info = [[node_to_id[node.cond], node_id]] @@ -66,7 +61,7 @@ def If_node(self, node, relay_param, node_to_id): edge_info.append([node_to_id[node.false_branch], node_id]) return graph_info, edge_info - def Op_node(self, node, relay_param, node_to_id): + def op_node(self, node, relay_param, node_to_id): node_id = node_to_id[node] op_name = node.name graph_info = [node_id, op_name, ""] @@ -81,6 +76,8 @@ def __init__(self, node_type, other_info): class TermGraph(Graph): + """Terminal plot for a relay IR Module""" + def __init__(self, name): # node_id: [ connected node_id] self._name = name @@ -107,6 +104,7 @@ def edge(self, id_start, id_end): self._graph[id_end] = [id_start] def render(self): + """To draw a terminal graph""" lines = [] @functools.lru_cache() @@ -129,6 +127,8 @@ def gen_line(indent, n_id): class TermPlotter(Plotter): + """Terminal plotter""" + def __init__(self): self._name_to_graph = {} @@ -138,7 +138,7 @@ def create_graph(self, name): def render(self, filename): # if filename == "stdio", print to terminal. - # Otherwise, print to the file? + # Otherwise, print to the file lines = [] for name in self._name_to_graph: lines.append(self._name_to_graph[name].render()) diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py index 013df7aa05a4..8e9dc49a29e0 100644 --- a/python/tvm/contrib/relay_viz/plotter.py +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -61,7 +61,7 @@ class Plotter(abc.ABC): """ @abc.abstractmethod - def create_graph(name): + def create_graph(self, name): """Create a graph Parameters diff --git a/python/tvm/contrib/relay_viz/render_callback.py b/python/tvm/contrib/relay_viz/render_callback.py index 58b4cd04304c..a2661db6085f 100644 --- a/python/tvm/contrib/relay_viz/render_callback.py +++ b/python/tvm/contrib/relay_viz/render_callback.py @@ -18,7 +18,7 @@ import tvm from tvm import relay -unknown_type = "unknown" +UNKNOWN_TYPE = "unknown" class RenderCallback: @@ -28,7 +28,8 @@ def __init__(self): self.render_rules = {} self.build_rules() - def Var_node(self, node, relay_param, node_to_id): + def var_node(self, node, relay_param, node_to_id): + """Render rule for a relay var node""" node_id = node_to_id[node] name_hint = node.name_hint node_detail = "" @@ -46,96 +47,93 @@ def Var_node(self, node, relay_param, node_to_id): edge_info = [] return graph_info, edge_info - def Function_node(self, node, relay_param, node_to_id): + def function_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] - graph_info = [node_id, f"Func", str(node.params)] + graph_info = [node_id, "Func", str(node.params)] edge_info = [[node_to_id[node.body], node_id]] return graph_info, edge_info - def Call_node(self, node, relay_param, node_to_id): + def call_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + """Render rule for a relay call node""" node_id = node_to_id[node] - op_name = unknown_type - node_details = [] + op_name = UNKNOWN_TYPE + node_detail = [] if isinstance(node.op, tvm.ir.Op): op_name = node.op.name if node.attrs: - node_details = [ - "{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys() - ] + node_detail = ["{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys()] elif isinstance(node.op, relay.Function): func_attrs = node.op.attrs op_name = "Anonymous Func" if func_attrs: - node_details = [ - "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() - ] + node_detail = ["{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()] # "Composite" might from relay.transform.MergeComposite if "Composite" in func_attrs.keys(): op_name = func_attrs["Composite"] elif isinstance(node.op, relay.GlobalVar): op_name = "GlobalVar" - node_details = [f"GlobalVar.name_hint: {node.op.name_hint}"] + node_detail = [f"GlobalVar.name_hint: {node.op.name_hint}"] else: op_name = str(type(node.op)).split(".")[-1].split("'")[0] - graph_info = [node_id, f"Call {op_name}", "\n".join(node_details)] + graph_info = [node_id, f"Call {op_name}", "\n".join(node_detail)] args = [node_to_id[arg] for arg in node.args] edge_info = [[arg, node_id] for arg in args] return graph_info, edge_info - def Let_node(self, node, relay_param, node_to_id): + def let_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] graph_info = [node_id, "Let", ""] edge_info = [[node_to_id[node.value], node_id]] edge_info.append([node_id, node_to_id[node.var]]) return graph_info, edge_info - def Global_var_node(self, node, relay_param, node_to_id): + def global_var_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument graph_info = [] edge_info = [] return graph_info, edge_info - def If_node(self, node, relay_param, node_to_id): + def if_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument graph_info = [] edge_info = [] return graph_info, edge_info - def Tuple_node(self, node, relay_param, node_to_id): + def tuple_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] graph_info = [node_id, "Tuple", ""] edge_info = [[node_to_id[field], node_id] for field in node.fields] return graph_info, edge_info - def TupleGetItem_node(self, node, relay_param, node_to_id): + def tuple_get_item_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] graph_info = [node_id, "TupleGetItem", "idx: {}".format(node.index)] edge_info = [[node_to_id[node.tuple_value], node_id]] return graph_info, edge_info - def Constant_node(self, node, relay_param, node_to_id): + def constant_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) - graph_info = [node_id, "Const", str(node)] + graph_info = [node_id, "Const", "\n".join(node_detail)] edge_info = [] return graph_info, edge_info - def Op_node(self, node, relay_param, node_to_id): + def op_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument graph_info = [] edge_info = [] return graph_info, edge_info def build_rules(self): self.render_rules = { - tvm.relay.Function: self.Function_node, - tvm.relay.expr.Call: self.Call_node, - tvm.relay.expr.Let: self.Let_node, - tvm.relay.expr.Var: self.Var_node, - tvm.relay.expr.GlobalVar: self.Global_var_node, - tvm.relay.expr.If: self.If_node, - tvm.relay.expr.Tuple: self.Tuple_node, - tvm.relay.expr.TupleGetItem: self.TupleGetItem_node, - tvm.relay.expr.Constant: self.Constant_node, - tvm.ir.Op: self.Op_node, + tvm.relay.Function: self.function_node, + tvm.relay.expr.Call: self.call_node, + tvm.relay.expr.Let: self.let_node, + tvm.relay.expr.Var: self.var_node, + tvm.relay.expr.GlobalVar: self.global_var_node, + tvm.relay.expr.If: self.if_node, + tvm.relay.expr.Tuple: self.tuple_node, + tvm.relay.expr.TupleGetItem: self.tuple_get_item_node, + tvm.relay.expr.Constant: self.constant_node, + tvm.ir.Op: self.op_node, } def get_rules(self): From 4b184f5d37da963b7dcd770544820416d6c8b64f Mon Sep 17 00:00:00 2001 From: chiwwang Date: Sun, 26 Sep 2021 14:20:43 +0000 Subject: [PATCH 08/16] address feedbacks --- python/tvm/contrib/relay_viz/__init__.py | 78 ++++++++++-------- python/tvm/contrib/relay_viz/_bokeh.py | 29 +++++-- python/tvm/contrib/relay_viz/plotter.py | 20 ++--- .../tvm/contrib/relay_viz/render_callback.py | 79 +++++++++++++------ 4 files changed, 131 insertions(+), 75 deletions(-) diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index cfb48390b100..f23088d4a4c4 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -16,12 +16,19 @@ # under the License. """Relay IR Visualizer""" import logging -import copy +from typing import ( + Dict, + Tuple, + Union, +) from enum import Enum +import tvm from tvm import relay from .plotter import Plotter -from .render_callback import RenderCallback - +from .render_callback import ( + RenderCallbackInterface, + RenderCallback, +) _LOGGER = logging.getLogger(__name__) @@ -36,7 +43,10 @@ class PlotterBackend(Enum): class RelayVisualizer: """Relay IR Visualizer""" - def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH): + def __init__(self, + relay_mod: tvm.IRModule, + relay_param: Dict = None, + backend: Union[PlotterBackend, Tuple[Plotter, RenderCallbackInterface]] = PlotterBackend.TERMINAL): """Visualize Relay IR. Parameters @@ -46,19 +56,18 @@ def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH): relay_param: dict Relay parameter dictionary backend: PlotterBackend or a tuple - PlotterBackend: The backend of plotting. Default "bokeh" - Tuple: A tuple with two arguments. First is user-defined Plotter, \ + PlotterBackend: The backend of plotting. Default "terminal" + Tuple: A tuple with two arguments. First is user-defined Plotter, the second is user-defined RenderCallback """ self._plotter, self._render_rules = get_plotter_and_render_rules(backend) self._relay_param = relay_param if relay_param is not None else {} - # This field is used for book-keeping for each graph. - self._node_to_id = {} global_vars = relay_mod.get_global_vars() graph_names = [] # If we have main function, put it to the first. + # Then main function can be shown at the top. for gv_name in global_vars: if gv_name.name_hint == "main": graph_names.insert(0, gv_name.name_hint) @@ -66,21 +75,16 @@ def __init__(self, relay_mod, relay_param=None, backend=PlotterBackend.BOKEH): graph_names.append(gv_name.name_hint) for name in graph_names: - # clear previous graph - self._node_to_id = {} - relay.analysis.post_order_visit( - relay_mod[name], - lambda node: self._traverse_expr(node), # pylint: disable=unnecessary-lambda - ) - graph = self._plotter.create_graph(name) - # shallow copy to prevent callback modify self._node_to_id - self._render_cb(graph, copy.copy(self._node_to_id), self._relay_param) + node_to_id = {} + def traverse_expr(node): + if node in node_to_id: + return + node_to_id[node] = len(node_to_id) - def _traverse_expr(self, node): - # based on https://github.com/apache/tvm/pull/4370 - if node in self._node_to_id: - return - self._node_to_id[node] = len(self._node_to_id) + relay.analysis.post_order_visit(relay_mod[name], traverse_expr) + graph = self._plotter.create_graph(name) + # shallow copy to prevent callback modify node_to_id + self._render_cb(graph, node_to_id.copy(), self._relay_param) def _render_cb(self, graph, node_to_id, relay_param): """a callback to Add nodes and edges to the graph. @@ -93,10 +97,8 @@ def _render_cb(self, graph, node_to_id, relay_param): relay_param : Dict[string, NDarray] """ - # Based on https://tvm.apache.org/2020/07/14/bert-pytorch-tvm - unknown_type = "unknown" for node, node_id in node_to_id.items(): - if type(node) in self._render_rules: # pylint: disable=unidiomatic-typecheck + try: graph_info, edge_info = self._render_rules[type(node)]( node, relay_param, node_to_id ) @@ -104,13 +106,18 @@ def _render_cb(self, graph, node_to_id, relay_param): graph.node(*graph_info) for edge in edge_info: graph.edge(*edge) - else: - unknown_info = "Unknown node: {}".format(type(node)) - _LOGGER.warning(unknown_info) + except KeyError as excp: + unknown_type = "unknown" + unknown_info = f"Failed to render node: {type(node)}" + _LOGGER.warning("When invoking render rule for %s, " + "KeyError with args=%s is raised.", + type(node), + excp.args, + ) graph.node(node_id, unknown_type, unknown_info) - def render(self, filename): - return self._plotter.render(filename=filename) + def render(self, filename: str) -> None: + self._plotter.render(filename=filename) def get_plotter_and_render_rules(backend): @@ -123,17 +130,20 @@ def get_plotter_and_render_rules(backend): Tuple: A tuple with two arguments. First is user-defined Plotter, \ the second is user-defined RenderCallback """ - if type(backend) is tuple and len(backend) == 2: # pylint: disable=unidiomatic-typecheck + if isinstance(backend, (tuple, list)) and len(backend) == 2: if not isinstance(backend[0], Plotter): - raise ValueError("First elemnet of the backend should be a plotter") + raise ValueError(f"First elemnet of backend argument should be derived from {type(Plotter)}") plotter = backend[0] if not isinstance(backend[1], RenderCallback): - raise ValueError("Second elemnet of the backend should be a callback") + raise ValueError(f"Second elemnet of backend argument should be derived from {type(RenderCallbackInterface)}") render = backend[1] render_rules = render.get_rules() return plotter, render_rules if backend in PlotterBackend: + # Plotter modules are Lazy-imported to avoid they become a requirement of TVM. + # Basically we want to keep them as optional -- users can choose which plotter they want, + # and just install libraries required by that plotter. if backend == PlotterBackend.BOKEH: # pylint: disable=import-outside-toplevel from ._bokeh import ( @@ -157,4 +167,4 @@ def get_plotter_and_render_rules(backend): render_rules = render.get_rules() return plotter, render_rules - raise ValueError("Unknown plotter backend {}".format(backend)) + raise ValueError(f"Unknown plotter backend {backend}") diff --git a/python/tvm/contrib/relay_viz/_bokeh.py b/python/tvm/contrib/relay_viz/_bokeh.py index 0cff967de860..487dec19a9ca 100644 --- a/python/tvm/contrib/relay_viz/_bokeh.py +++ b/python/tvm/contrib/relay_viz/_bokeh.py @@ -80,10 +80,19 @@ def detail(self): class GraphShaper: - """Provide the bounding-box, and node location, height, width given by pygraphviz.""" + """Provide the bounding-box, and node location, height, width given by pydot. + To access node attributes, refer to + https://github.com/pydot/pydot/blob/90936e75462c7b0e4bb16d97c1ae7efdf04e895c/src/pydot/core.py#L537 + + To access edge attributes, refer to + https://github.com/pydot/pydot/blob/90936e75462c7b0e4bb16d97c1ae7efdf04e895c/src/pydot/core.py#L645 + + The string format `pos` in an edge follows DOT language spec: + https://graphviz.org/docs/attr-types/splineType/ + """ # defined by graphviz. - _px_per_inch = 72 + _PX_PER_INCH = 72 def __init__(self, pydot_graph, prog="dot", args=None): if args is None: @@ -106,7 +115,13 @@ def get_nodes(self): @functools.lru_cache() def get_edge_path(self, start_node_id, end_node_id): - """Get explicit path points for MultiLine.""" + """Get explicit path points for MultiLine. + Parse points formating an edge. The format of points in an edge is either: + 1. e,x_point,y_point + 2. s,x_point,y_point + 3. x_point,y_point + We don't care about `e` or `s` here, so simplt parse out x_point and y_point. + """ edge = self._pydot_graph.get_edge(str(start_node_id), str(end_node_id)) if len(edge) != 1: _LOGGER.warning( @@ -150,11 +165,11 @@ def get_node_pos(self, node_name): def get_node_height(self, node_name): height_str = self._get_node_attr(node_name, "height", "20") - return float(height_str) * self._px_per_inch + return float(height_str) * self._PX_PER_INCH def get_node_width(self, node_name): width_str = self._get_node_attr(node_name, "width", "20") - return float(width_str) * self._px_per_inch + return float(width_str) * self._PX_PER_INCH def _get_node_attr(self, node_name, attr_name, default_val): @@ -443,14 +458,14 @@ def __init__(self): def create_graph(self, name): if name in self._name_to_graph: - _LOGGER.warning("Graph name %s exists. ") + _LOGGER.warning("Graph name %s already exists.", name) else: self._name_to_graph[name] = BokehGraph() return self._name_to_graph[name] def render(self, filename): if not filename.endswith(".html"): - filename = "{}.html".format(filename) + filename = f"{filename}.html" dom_list = [] for name, graph in self._name_to_graph.items(): diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py index 8e9dc49a29e0..5cd115d8ea49 100644 --- a/python/tvm/contrib/relay_viz/plotter.py +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -16,7 +16,7 @@ # under the License. """Abstract class for plotters.""" import abc - +from typing import Union class Graph(abc.ABC): """Abstract class for graph. @@ -25,31 +25,31 @@ class Graph(abc.ABC): """ @abc.abstractmethod - def node(self, node_id, node_type, node_detail): + def node(self, node_id: Union[int, str], node_type: str, node_detail: str) -> None: """Add a node to the underlying graph. Parameters ---------- - node_id : object + node_id : Union[int, str] Serve as the ID to the node. - node_type : string + node_type : str the type of the node. - node_detail : string + node_detail : str the description of the node. """ @abc.abstractmethod - def edge(self, id_start, id_end): + def edge(self, id_start: Union[int, str], id_end: Union[int, str]): """Add an edge to the underlying graph. Parameters ---------- - id_start : object + id_start : Union[int, str] the ID to the starting node. - id_end : object + id_end : Union[int, str] the ID to the ending node. """ @@ -61,7 +61,7 @@ class Plotter(abc.ABC): """ @abc.abstractmethod - def create_graph(self, name): + def create_graph(self, name: str) -> Graph: """Create a graph Parameters @@ -74,7 +74,7 @@ def create_graph(self, name): """ @abc.abstractmethod - def render(self, filename): + def render(self, filename:str) -> None: """Render the graph as a file. Parameters diff --git a/python/tvm/contrib/relay_viz/render_callback.py b/python/tvm/contrib/relay_viz/render_callback.py index a2661db6085f..e6ca686b8719 100644 --- a/python/tvm/contrib/relay_viz/render_callback.py +++ b/python/tvm/contrib/relay_viz/render_callback.py @@ -14,15 +14,46 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Default render callback rules""" +"""RenderCallback interface""" +import abc +from typing import ( + Dict, + Callable, + Union, + List, +) import tvm from tvm import relay +from tvm.relay.expr import Tuple UNKNOWN_TYPE = "unknown" - -class RenderCallback: - """This is the default callback rules, which is also the _bokeh backend drawing way""" +class RenderCallbackInterface(abc.ABC): + + @abc.abstractmethod + def get_rules(self) -> Dict[ + tvm.ir.op.Op, + Callable[ + [ + tvm.ir.op.Op, + Dict[str, tvm.runtime.NDArray], + Dict[tvm.ir.op.Op, Union[int, str]] + ], + Tuple[List, List], + ] + ]: + """Retrun a dictionary. Relay node type as key and a callable as valeu. + The callable object should return Tuple[List, List], + where the first List is [node_id, node_type, node_detail] used by Plotter interface. + The second one is for edges, with the form [[e0_start, e0_end], [e1_start, e1_end], ...] + """ + pass + +class RenderCallback(RenderCallbackInterface): + """RenderCallback generate nodes and edges information for each Relay type. + This class is a default implementation for common relay types, heavily based on + `visualize` function in https://tvm.apache.org/2020/07/14/bert-pytorch-tvm + """ def __init__(self): self.render_rules = {} @@ -43,15 +74,15 @@ def var_node(self, node, relay_param, node_to_id): node_detail = "name_hint: {}\ntype_annotation: {}".format( name_hint, node.type_annotation ) - graph_info = [node_id, node_type, node_detail] + node_info = [node_id, node_type, node_detail] edge_info = [] - return graph_info, edge_info + return node_info, edge_info def function_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] - graph_info = [node_id, "Func", str(node.params)] + node_info = [node_id, "Func", str(node.params)] edge_info = [[node_to_id[node.body], node_id]] - return graph_info, edge_info + return node_info, edge_info def call_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument """Render rule for a relay call node""" @@ -76,51 +107,51 @@ def call_node(self, node, relay_param, node_to_id): # pylint: disable=unused-ar else: op_name = str(type(node.op)).split(".")[-1].split("'")[0] - graph_info = [node_id, f"Call {op_name}", "\n".join(node_detail)] + node_info = [node_id, f"Call {op_name}", "\n".join(node_detail)] args = [node_to_id[arg] for arg in node.args] edge_info = [[arg, node_id] for arg in args] - return graph_info, edge_info + return node_info, edge_info def let_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] - graph_info = [node_id, "Let", ""] + node_info = [node_id, "Let", ""] edge_info = [[node_to_id[node.value], node_id]] edge_info.append([node_id, node_to_id[node.var]]) - return graph_info, edge_info + return node_info, edge_info def global_var_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument - graph_info = [] + node_info = [] edge_info = [] - return graph_info, edge_info + return node_info, edge_info def if_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument - graph_info = [] + node_info = [] edge_info = [] - return graph_info, edge_info + return node_info, edge_info def tuple_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] - graph_info = [node_id, "Tuple", ""] + node_info = [node_id, "Tuple", ""] edge_info = [[node_to_id[field], node_id] for field in node.fields] - return graph_info, edge_info + return node_info, edge_info def tuple_get_item_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] - graph_info = [node_id, "TupleGetItem", "idx: {}".format(node.index)] + node_info = [node_id, "TupleGetItem", "idx: {}".format(node.index)] edge_info = [[node_to_id[node.tuple_value], node_id]] - return graph_info, edge_info + return node_info, edge_info def constant_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument node_id = node_to_id[node] node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) - graph_info = [node_id, "Const", "\n".join(node_detail)] + node_info = [node_id, "Const", "\n".join(node_detail)] edge_info = [] - return graph_info, edge_info + return node_info, edge_info def op_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument - graph_info = [] + node_info = [] edge_info = [] - return graph_info, edge_info + return node_info, edge_info def build_rules(self): self.render_rules = { From eff3f0ed2daed0b33770a26ff4ba6c1df4b868d2 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Sun, 3 Oct 2021 15:55:35 +0000 Subject: [PATCH 09/16] Refactor. TODO: need doc/type-hint --- python/tvm/contrib/relay_viz/__init__.py | 44 +++++++------- python/tvm/contrib/relay_viz/_bokeh.py | 8 ++- python/tvm/contrib/relay_viz/_terminal.py | 19 +++++-- .../{render_callback.py => node_edge_gen.py} | 57 +++++++++++-------- 4 files changed, 69 insertions(+), 59 deletions(-) rename python/tvm/contrib/relay_viz/{render_callback.py => node_edge_gen.py} (81%) diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index f23088d4a4c4..c2bd1b2a551f 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -25,14 +25,10 @@ import tvm from tvm import relay from .plotter import Plotter -from .render_callback import ( - RenderCallbackInterface, - RenderCallback, -) +from .node_edge_gen import NodeEdgeGenerator _LOGGER = logging.getLogger(__name__) - class PlotterBackend(Enum): """Enumeration for available plotters.""" @@ -46,7 +42,7 @@ class RelayVisualizer: def __init__(self, relay_mod: tvm.IRModule, relay_param: Dict = None, - backend: Union[PlotterBackend, Tuple[Plotter, RenderCallbackInterface]] = PlotterBackend.TERMINAL): + backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL): """Visualize Relay IR. Parameters @@ -61,7 +57,7 @@ def __init__(self, the second is user-defined RenderCallback """ - self._plotter, self._render_rules = get_plotter_and_render_rules(backend) + self._plotter, self._ne_generator = get_plotter_and_generator(backend) self._relay_param = relay_param if relay_param is not None else {} global_vars = relay_mod.get_global_vars() @@ -84,10 +80,10 @@ def traverse_expr(node): relay.analysis.post_order_visit(relay_mod[name], traverse_expr) graph = self._plotter.create_graph(name) # shallow copy to prevent callback modify node_to_id - self._render_cb(graph, node_to_id.copy(), self._relay_param) + self._render(graph, node_to_id.copy(), self._relay_param) - def _render_cb(self, graph, node_to_id, relay_param): - """a callback to Add nodes and edges to the graph. + def _render(self, graph, node_to_id, relay_param): + """render nodes and edges to the graph. Parameters ---------- @@ -99,7 +95,7 @@ def _render_cb(self, graph, node_to_id, relay_param): """ for node, node_id in node_to_id.items(): try: - graph_info, edge_info = self._render_rules[type(node)]( + graph_info, edge_info = self._ne_generator.get_node_edges( node, relay_param, node_to_id ) if graph_info: @@ -116,12 +112,12 @@ def _render_cb(self, graph, node_to_id, relay_param): ) graph.node(node_id, unknown_type, unknown_info) - def render(self, filename: str) -> None: + def render(self, filename: str = None) -> None: self._plotter.render(filename=filename) -def get_plotter_and_render_rules(backend): - """Specify the Plottor and its render rules +def get_plotter_and_generator(backend): + """Specify the Plottor and its NodeEdgeGenerator Parameters ---------- @@ -134,11 +130,10 @@ def get_plotter_and_render_rules(backend): if not isinstance(backend[0], Plotter): raise ValueError(f"First elemnet of backend argument should be derived from {type(Plotter)}") plotter = backend[0] - if not isinstance(backend[1], RenderCallback): - raise ValueError(f"Second elemnet of backend argument should be derived from {type(RenderCallbackInterface)}") - render = backend[1] - render_rules = render.get_rules() - return plotter, render_rules + if not isinstance(backend[1], NodeEdgeGenerator): + raise ValueError(f"Second elemnet of backend argument should be derived from {type(NodeEdgeGenerator)}") + ne_generator = backend[1] + return plotter, ne_generator if backend in PlotterBackend: # Plotter modules are Lazy-imported to avoid they become a requirement of TVM. @@ -148,23 +143,22 @@ def get_plotter_and_render_rules(backend): # pylint: disable=import-outside-toplevel from ._bokeh import ( BokehPlotter, - BokehRenderCallback, + BokehNodeEdgeGenerator, ) plotter = BokehPlotter() - render = BokehRenderCallback() + ne_generator = BokehNodeEdgeGenerator() elif backend == PlotterBackend.TERMINAL: # pylint: disable=import-outside-toplevel from ._terminal import ( TermPlotter, - TermRenderCallback, + TermNodeEdgeGenerator, ) plotter = TermPlotter() - render = TermRenderCallback() + ne_generator = TermNodeEdgeGenerator() - render_rules = render.get_rules() - return plotter, render_rules + return plotter, ne_generator raise ValueError(f"Unknown plotter backend {backend}") diff --git a/python/tvm/contrib/relay_viz/_bokeh.py b/python/tvm/contrib/relay_viz/_bokeh.py index 487dec19a9ca..c65799294d20 100644 --- a/python/tvm/contrib/relay_viz/_bokeh.py +++ b/python/tvm/contrib/relay_viz/_bokeh.py @@ -49,12 +49,12 @@ Graph, ) -from .render_callback import RenderCallback # pylint: disable=import-outside-toplevel +from .node_edge_gen import DefaultNodeEdgeGenerator _LOGGER = logging.getLogger(__name__) -class BokehRenderCallback(RenderCallback): +class BokehNodeEdgeGenerator(DefaultNodeEdgeGenerator): pass @@ -464,7 +464,9 @@ def create_graph(self, name): return self._name_to_graph[name] def render(self, filename): - if not filename.endswith(".html"): + if filename is None: + filename = "bokeh_plotter_result.html" + elif not filename.endswith(".html"): filename = f"{filename}.html" dom_list = [] diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/_terminal.py index 2732ee4c54e0..1ecbca2aa9e9 100644 --- a/python/tvm/contrib/relay_viz/_terminal.py +++ b/python/tvm/contrib/relay_viz/_terminal.py @@ -24,11 +24,11 @@ Graph, ) -from .render_callback import RenderCallback +from .node_edge_gen import DefaultNodeEdgeGenerator -class TermRenderCallback(RenderCallback): - """Terminal render callback""" +class TermNodeEdgeGenerator(DefaultNodeEdgeGenerator): + """Terminal nodes and edges generator.""" def call_node(self, node, relay_param, node_to_id): node_id = node_to_id[node] @@ -68,6 +68,11 @@ def op_node(self, node, relay_param, node_to_id): edge_info = [] return graph_info, edge_info + def function_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + node_id = node_to_id[node] + node_info = [node_id, "Func", str(node.params)] + edge_info = [[node_to_id[node.body], node_id]] + return node_info, edge_info class Node: def __init__(self, node_type, other_info): @@ -137,9 +142,11 @@ def create_graph(self, name): return self._name_to_graph[name] def render(self, filename): - # if filename == "stdio", print to terminal. - # Otherwise, print to the file lines = [] for name in self._name_to_graph: lines.append(self._name_to_graph[name].render()) - print("\n".join(lines)) + if filename is None: + print("\n".join(lines)) + else: + with open(filename, "w") as fp: + fp.write("\n".join(lines)) diff --git a/python/tvm/contrib/relay_viz/render_callback.py b/python/tvm/contrib/relay_viz/node_edge_gen.py similarity index 81% rename from python/tvm/contrib/relay_viz/render_callback.py rename to python/tvm/contrib/relay_viz/node_edge_gen.py index e6ca686b8719..3cad75da1516 100644 --- a/python/tvm/contrib/relay_viz/render_callback.py +++ b/python/tvm/contrib/relay_viz/node_edge_gen.py @@ -18,39 +18,26 @@ import abc from typing import ( Dict, - Callable, Union, - List, ) import tvm from tvm import relay -from tvm.relay.expr import Tuple UNKNOWN_TYPE = "unknown" -class RenderCallbackInterface(abc.ABC): +class NodeEdgeGenerator(abc.ABC): + """Abstract class generating nodes and edgs for Graph interface.""" @abc.abstractmethod - def get_rules(self) -> Dict[ - tvm.ir.op.Op, - Callable[ - [ - tvm.ir.op.Op, - Dict[str, tvm.runtime.NDArray], - Dict[tvm.ir.op.Op, Union[int, str]] - ], - Tuple[List, List], - ] - ]: - """Retrun a dictionary. Relay node type as key and a callable as valeu. - The callable object should return Tuple[List, List], - where the first List is [node_id, node_type, node_detail] used by Plotter interface. - The second one is for edges, with the form [[e0_start, e0_end], [e1_start, e1_end], ...] - """ + def get_node_edges(self, + node: relay.expr.ExprWithOp, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]]): pass -class RenderCallback(RenderCallbackInterface): - """RenderCallback generate nodes and edges information for each Relay type. + +class DefaultNodeEdgeGenerator(NodeEdgeGenerator): + """NodeEdgeGenerator generate for nodes and edges consumed by Graph. This class is a default implementation for common relay types, heavily based on `visualize` function in https://tvm.apache.org/2020/07/14/bert-pytorch-tvm """ @@ -79,8 +66,18 @@ def var_node(self, node, relay_param, node_to_id): return node_info, edge_info def function_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + node_details = [] + name = "" + func_attrs = node.attrs + if func_attrs: + node_details = [ + "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() + ] + # "Composite" might from relay.transform.MergeComposite + if "Composite" in func_attrs.keys(): + name = func_attrs["Composite"] node_id = node_to_id[node] - node_info = [node_id, "Func", str(node.params)] + node_info = [node_id, f"Func {name}", "\n".join(node_details)] edge_info = [[node_to_id[node.body], node_id]] return node_info, edge_info @@ -167,5 +164,15 @@ def build_rules(self): tvm.ir.Op: self.op_node, } - def get_rules(self): - return self.render_rules + def get_node_edges(self, + node: relay.expr.ExprWithOp, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]]): + try: + graph_info, edge_info = self.render_rules[type(node)]( + node, relay_param, node_to_id + ) + except KeyError: + graph_info = [] + edge_info = [] + return graph_info, edge_info From 329110ad9669295352d835ea1cd7d58bef053b86 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Sun, 3 Oct 2021 16:06:28 +0000 Subject: [PATCH 10/16] need more refactor/doc/type hint --- python/tvm/contrib/relay_viz/__init__.py | 80 ++++++++----------- python/tvm/contrib/relay_viz/node_edge_gen.py | 5 +- 2 files changed, 37 insertions(+), 48 deletions(-) diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index c2bd1b2a551f..9efb6ffd1375 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -93,24 +93,14 @@ def _render(self, graph, node_to_id, relay_param): relay_param : Dict[string, NDarray] """ - for node, node_id in node_to_id.items(): - try: - graph_info, edge_info = self._ne_generator.get_node_edges( - node, relay_param, node_to_id - ) - if graph_info: - graph.node(*graph_info) - for edge in edge_info: - graph.edge(*edge) - except KeyError as excp: - unknown_type = "unknown" - unknown_info = f"Failed to render node: {type(node)}" - _LOGGER.warning("When invoking render rule for %s, " - "KeyError with args=%s is raised.", - type(node), - excp.args, - ) - graph.node(node_id, unknown_type, unknown_info) + for node in node_to_id: + graph_info, edge_info = self._ne_generator.get_node_edges( + node, relay_param, node_to_id + ) + if graph_info: + graph.node(*graph_info) + for edge in edge_info: + graph.edge(*edge) def render(self, filename: str = None) -> None: self._plotter.render(filename=filename) @@ -129,36 +119,34 @@ def get_plotter_and_generator(backend): if isinstance(backend, (tuple, list)) and len(backend) == 2: if not isinstance(backend[0], Plotter): raise ValueError(f"First elemnet of backend argument should be derived from {type(Plotter)}") - plotter = backend[0] + if not isinstance(backend[1], NodeEdgeGenerator): raise ValueError(f"Second elemnet of backend argument should be derived from {type(NodeEdgeGenerator)}") - ne_generator = backend[1] - return plotter, ne_generator - - if backend in PlotterBackend: - # Plotter modules are Lazy-imported to avoid they become a requirement of TVM. - # Basically we want to keep them as optional -- users can choose which plotter they want, - # and just install libraries required by that plotter. - if backend == PlotterBackend.BOKEH: - # pylint: disable=import-outside-toplevel - from ._bokeh import ( - BokehPlotter, - BokehNodeEdgeGenerator, - ) - - plotter = BokehPlotter() - ne_generator = BokehNodeEdgeGenerator() - - elif backend == PlotterBackend.TERMINAL: - # pylint: disable=import-outside-toplevel - from ._terminal import ( - TermPlotter, - TermNodeEdgeGenerator, - ) - plotter = TermPlotter() - ne_generator = TermNodeEdgeGenerator() + return backend + + if backend not in PlotterBackend: + raise ValueError(f"Unknown plotter backend {backend}") + + # Plotter modules are Lazy-imported to avoid they become a requirement of TVM. + # Basically we want to keep them as optional -- users can choose which plotter they want, + # and just install libraries required by that plotter. + if backend == PlotterBackend.BOKEH: + # pylint: disable=import-outside-toplevel + from ._bokeh import ( + BokehPlotter, + BokehNodeEdgeGenerator, + ) + plotter = BokehPlotter() + ne_generator = BokehNodeEdgeGenerator() + elif backend == PlotterBackend.TERMINAL: + # pylint: disable=import-outside-toplevel + from ._terminal import ( + TermPlotter, + TermNodeEdgeGenerator, + ) + plotter = TermPlotter() + ne_generator = TermNodeEdgeGenerator() + return plotter, ne_generator - return plotter, ne_generator - raise ValueError(f"Unknown plotter backend {backend}") diff --git a/python/tvm/contrib/relay_viz/node_edge_gen.py b/python/tvm/contrib/relay_viz/node_edge_gen.py index 3cad75da1516..d92d2fb7ef4a 100644 --- a/python/tvm/contrib/relay_viz/node_edge_gen.py +++ b/python/tvm/contrib/relay_viz/node_edge_gen.py @@ -19,6 +19,7 @@ from typing import ( Dict, Union, + Tuple, ) import tvm from tvm import relay @@ -32,7 +33,7 @@ class NodeEdgeGenerator(abc.ABC): def get_node_edges(self, node: relay.expr.ExprWithOp, relay_param: Dict[str, tvm.runtime.NDArray], - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]]): + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]]) -> Tuple[list, list]: pass @@ -173,6 +174,6 @@ def get_node_edges(self, node, relay_param, node_to_id ) except KeyError: - graph_info = [] + graph_info = [node_to_id[node], "unknown", f"failed to parse node: {type(node)}"] edge_info = [] return graph_info, edge_info From 0a2a02bc340dbd101c0cf2329a98812a5969a32d Mon Sep 17 00:00:00 2001 From: chiwwang Date: Sun, 3 Oct 2021 16:31:56 +0000 Subject: [PATCH 11/16] fix lint (todo: tests and tutorial) --- python/tvm/contrib/relay_viz/__init__.py | 47 ++++++++----------- python/tvm/contrib/relay_viz/_terminal.py | 17 +++++-- python/tvm/contrib/relay_viz/node_edge_gen.py | 43 ++++++++++------- python/tvm/contrib/relay_viz/plotter.py | 5 +- 4 files changed, 61 insertions(+), 51 deletions(-) diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index 9efb6ffd1375..e75d160428e8 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Relay IR Visualizer""" -import logging from typing import ( Dict, Tuple, @@ -27,7 +26,6 @@ from .plotter import Plotter from .node_edge_gen import NodeEdgeGenerator -_LOGGER = logging.getLogger(__name__) class PlotterBackend(Enum): """Enumeration for available plotters.""" @@ -39,10 +37,12 @@ class PlotterBackend(Enum): class RelayVisualizer: """Relay IR Visualizer""" - def __init__(self, - relay_mod: tvm.IRModule, - relay_param: Dict = None, - backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL): + def __init__( + self, + relay_mod: tvm.IRModule, + relay_param: Dict = None, + backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL, + ): """Visualize Relay IR. Parameters @@ -54,7 +54,7 @@ def __init__(self, backend: PlotterBackend or a tuple PlotterBackend: The backend of plotting. Default "terminal" Tuple: A tuple with two arguments. First is user-defined Plotter, - the second is user-defined RenderCallback + the second is user-defined NodeEdgeGenerator """ self._plotter, self._ne_generator = get_plotter_and_generator(backend) @@ -63,15 +63,16 @@ def __init__(self, global_vars = relay_mod.get_global_vars() graph_names = [] # If we have main function, put it to the first. - # Then main function can be shown at the top. + # Then main function can be shown on the top. for gv_name in global_vars: if gv_name.name_hint == "main": graph_names.insert(0, gv_name.name_hint) else: graph_names.append(gv_name.name_hint) + node_to_id = {} for name in graph_names: - node_to_id = {} + def traverse_expr(node): if node in node_to_id: return @@ -79,8 +80,8 @@ def traverse_expr(node): relay.analysis.post_order_visit(relay_mod[name], traverse_expr) graph = self._plotter.create_graph(name) - # shallow copy to prevent callback modify node_to_id - self._render(graph, node_to_id.copy(), self._relay_param) + self._render(graph, node_to_id, self._relay_param) + node_to_id.clear() def _render(self, graph, node_to_id, relay_param): """render nodes and edges to the graph. @@ -94,9 +95,7 @@ def _render(self, graph, node_to_id, relay_param): relay_param : Dict[string, NDarray] """ for node in node_to_id: - graph_info, edge_info = self._ne_generator.get_node_edges( - node, relay_param, node_to_id - ) + graph_info, edge_info = self._ne_generator.get_node_edges(node, relay_param, node_to_id) if graph_info: graph.node(*graph_info) for edge in edge_info: @@ -107,21 +106,15 @@ def render(self, filename: str = None) -> None: def get_plotter_and_generator(backend): - """Specify the Plottor and its NodeEdgeGenerator - - Parameters - ---------- - backend: PlotterBackend or a tuple - PlotterBackend: The backend of plotting. Default "bokeh" - Tuple: A tuple with two arguments. First is user-defined Plotter, \ - the second is user-defined RenderCallback - """ + """Specify the Plottor and its NodeEdgeGenerator""" if isinstance(backend, (tuple, list)) and len(backend) == 2: if not isinstance(backend[0], Plotter): - raise ValueError(f"First elemnet of backend argument should be derived from {type(Plotter)}") + raise ValueError(f"First element of backend should be derived from {type(Plotter)}") if not isinstance(backend[1], NodeEdgeGenerator): - raise ValueError(f"Second elemnet of backend argument should be derived from {type(NodeEdgeGenerator)}") + raise ValueError( + f"Second element of backend should be derived from {type(NodeEdgeGenerator)}" + ) return backend @@ -137,6 +130,7 @@ def get_plotter_and_generator(backend): BokehPlotter, BokehNodeEdgeGenerator, ) + plotter = BokehPlotter() ne_generator = BokehNodeEdgeGenerator() elif backend == PlotterBackend.TERMINAL: @@ -145,8 +139,7 @@ def get_plotter_and_generator(backend): TermPlotter, TermNodeEdgeGenerator, ) + plotter = TermPlotter() ne_generator = TermNodeEdgeGenerator() return plotter, ne_generator - - diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/_terminal.py index 1ecbca2aa9e9..4c6e5c0ca4ac 100644 --- a/python/tvm/contrib/relay_viz/_terminal.py +++ b/python/tvm/contrib/relay_viz/_terminal.py @@ -17,7 +17,6 @@ """Visualize Relay IR in AST text-form""" from collections import deque -import functools from .plotter import ( Plotter, @@ -74,6 +73,7 @@ def function_node(self, node, relay_param, node_to_id): # pylint: disable=unuse edge_info = [[node_to_id[node.body], node_id]] return node_info, edge_info + class Node: def __init__(self, node_type, other_info): self.type = node_type @@ -111,14 +111,20 @@ def edge(self, id_start, id_end): def render(self): """To draw a terminal graph""" lines = [] + seen_node = set() - @functools.lru_cache() def gen_line(indent, n_id): + if (indent, n_id) in seen_node: + return + seen_node.add((indent, n_id)) + conn_symbol = ["|--", "`--"] last = len(self._graph[n_id]) - 1 for i, next_n_id in enumerate(self._graph[n_id]): node = self._id_to_node[next_n_id] - lines.append(f"{indent}{conn_symbol[i==last]}{node.type} {node.other_info}") + lines.append( + f"{indent}{conn_symbol[1 if i==last else 0]}{node.type} {node.other_info}" + ) next_indent = indent next_indent += " " if (i == last) else "| " gen_line(next_indent, next_n_id) @@ -142,11 +148,12 @@ def create_graph(self, name): return self._name_to_graph[name] def render(self, filename): + """If filename is None, print to stdio. Otherwise, write to the filename.""" lines = [] for name in self._name_to_graph: lines.append(self._name_to_graph[name].render()) if filename is None: print("\n".join(lines)) else: - with open(filename, "w") as fp: - fp.write("\n".join(lines)) + with open(filename, "w") as out_file: + out_file.write("\n".join(lines)) diff --git a/python/tvm/contrib/relay_viz/node_edge_gen.py b/python/tvm/contrib/relay_viz/node_edge_gen.py index d92d2fb7ef4a..d635d5bbf82d 100644 --- a/python/tvm/contrib/relay_viz/node_edge_gen.py +++ b/python/tvm/contrib/relay_viz/node_edge_gen.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""RenderCallback interface""" +"""NodeEdgeGenerator interface""" import abc from typing import ( Dict, @@ -26,15 +26,25 @@ UNKNOWN_TYPE = "unknown" + class NodeEdgeGenerator(abc.ABC): """Abstract class generating nodes and edgs for Graph interface.""" @abc.abstractmethod - def get_node_edges(self, - node: relay.expr.ExprWithOp, - relay_param: Dict[str, tvm.runtime.NDArray], - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]]) -> Tuple[list, list]: - pass + def get_node_edges( + self, + node: relay.expr.ExprWithOp, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[list, list]: + """Function return node and edges consumed by Graph interface + The returned tuple containing two lists, the first list match + the interface of Graph.node(), i.e. `node_id`, `node_type`, and `node_detail`. + The secon list is the form: + [(node_id_start0, node_id_end0), ...] + where the tuple `(node_id_start0, node_id_end0)` represent an edge from + `node_id_start0` to `node_id_end0`. + """ class DefaultNodeEdgeGenerator(NodeEdgeGenerator): @@ -67,13 +77,12 @@ def var_node(self, node, relay_param, node_to_id): return node_info, edge_info def function_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + """Render rule for a relay function node""" node_details = [] name = "" func_attrs = node.attrs if func_attrs: - node_details = [ - "{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys() - ] + node_details = ["{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()] # "Composite" might from relay.transform.MergeComposite if "Composite" in func_attrs.keys(): name = func_attrs["Composite"] @@ -165,15 +174,15 @@ def build_rules(self): tvm.ir.Op: self.op_node, } - def get_node_edges(self, - node: relay.expr.ExprWithOp, - relay_param: Dict[str, tvm.runtime.NDArray], - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]]): + def get_node_edges( + self, + node: relay.expr.ExprWithOp, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[list, list]: try: - graph_info, edge_info = self.render_rules[type(node)]( - node, relay_param, node_to_id - ) + graph_info, edge_info = self.render_rules[type(node)](node, relay_param, node_to_id) except KeyError: - graph_info = [node_to_id[node], "unknown", f"failed to parse node: {type(node)}"] + graph_info = [node_to_id[node], UNKNOWN_TYPE, f"failed to parse node: {type(node)}"] edge_info = [] return graph_info, edge_info diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py index 5cd115d8ea49..419a41628f89 100644 --- a/python/tvm/contrib/relay_viz/plotter.py +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -18,6 +18,7 @@ import abc from typing import Union + class Graph(abc.ABC): """Abstract class for graph. @@ -41,7 +42,7 @@ def node(self, node_id: Union[int, str], node_type: str, node_detail: str) -> No """ @abc.abstractmethod - def edge(self, id_start: Union[int, str], id_end: Union[int, str]): + def edge(self, id_start: Union[int, str], id_end: Union[int, str]) -> None: """Add an edge to the underlying graph. Parameters @@ -74,7 +75,7 @@ def create_graph(self, name: str) -> Graph: """ @abc.abstractmethod - def render(self, filename:str) -> None: + def render(self, filename: str) -> None: """Render the graph as a file. Parameters From 6533aea25a65ee7cf80165fee2c48ebc6befcc56 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Mon, 11 Oct 2021 15:38:15 +0000 Subject: [PATCH 12/16] refactor, doc. TODO: tests, tutorial --- python/tvm/contrib/relay_viz/README.md | 18 +- python/tvm/contrib/relay_viz/__init__.py | 50 +++--- python/tvm/contrib/relay_viz/_bokeh.py | 29 ++-- python/tvm/contrib/relay_viz/_terminal.py | 115 +++++++++---- python/tvm/contrib/relay_viz/node_edge_gen.py | 158 ++++++++++++------ python/tvm/contrib/relay_viz/plotter.py | 30 +--- tests/python/contrib/test_relay_viz.py | 52 ++++++ 7 files changed, 306 insertions(+), 146 deletions(-) create mode 100644 tests/python/contrib/test_relay_viz.py diff --git a/python/tvm/contrib/relay_viz/README.md b/python/tvm/contrib/relay_viz/README.md index 5a189f85cbb3..bb6e964e8f07 100644 --- a/python/tvm/contrib/relay_viz/README.md +++ b/python/tvm/contrib/relay_viz/README.md @@ -24,6 +24,7 @@ This tool target to visualize Relay IR. 1. [Requirement](#Requirement) 2. [Usage](#Usage) 3. [Credits](#Credits) +4. [Design and Customization](#Design-and-Customization) ## Requirement @@ -47,7 +48,7 @@ pip install pydot bokeh==2.3.1 ``` from tvm.contrib import relay_viz mod, params = tvm.relay.frontend.from_onnx(net, shape_dict) -vizer = relay_viz.RelayVisualizer(mod, relay_param=params) +vizer = relay_viz.RelayVisualizer(mod, relay_param=params, backend=PlotterBackend.BOKEH) vizer.render("output.html") ``` @@ -57,4 +58,17 @@ vizer.render("output.html") 2. https://tvm.apache.org/2020/07/14/bert-pytorch-tvm -3. https://discuss.tvm.apache.org/t/rfc-visualizing-relay-program-as-graph/4825/17 \ No newline at end of file +3. https://discuss.tvm.apache.org/t/rfc-visualizing-relay-program-as-graph/4825/17 + +## Design and Customization + +This utility is composed of two parts: `node_edge_gen.py` and `plotter.py`. + +`plotter.py` define interfaces of `Graph` and `Plotter`. `Plotter` is responsible to render a collection of `Graph`. + +`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes/edges consumed by `Graph`. Further, this python module also provide a default implementation for common relay types. + +If customization is wanted for a certain relay type, we can implement the `NodeEdgeGenerator` interface, handling that relay type accordingly, and delegate other types to the default implementation. See `_terminal.py` for an example usage. + +These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes/edges to `Graph`. +Then, it render the plot by `Plotter.render()`. diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index e75d160428e8..658662264cf9 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -40,21 +40,16 @@ class RelayVisualizer: def __init__( self, relay_mod: tvm.IRModule, - relay_param: Dict = None, + relay_param: Union[None, Dict[str, tvm.runtime.NDArray]] = None, backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL, ): """Visualize Relay IR. Parameters ---------- - relay_mod : object - Relay IR module - relay_param: dict - Relay parameter dictionary - backend: PlotterBackend or a tuple - PlotterBackend: The backend of plotting. Default "terminal" - Tuple: A tuple with two arguments. First is user-defined Plotter, - the second is user-defined NodeEdgeGenerator + relay_mod : tvm.IRModule, Relay IR module + relay_param: None | Dict[str, tvm.runtime.NDArray], Relay parameter dictionary. Default `None`. + backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator], Default `PlotterBackend.TERMINAL`. """ self._plotter, self._ne_generator = get_plotter_and_generator(backend) @@ -71,35 +66,33 @@ def __init__( graph_names.append(gv_name.name_hint) node_to_id = {} - for name in graph_names: - def traverse_expr(node): - if node in node_to_id: - return - node_to_id[node] = len(node_to_id) + def traverse_expr(node): + if node in node_to_id: + return + node_to_id[node] = len(node_to_id) + for name in graph_names: + node_to_id.clear() relay.analysis.post_order_visit(relay_mod[name], traverse_expr) graph = self._plotter.create_graph(name) - self._render(graph, node_to_id, self._relay_param) - node_to_id.clear() + self._add_nodes(graph, node_to_id, self._relay_param) - def _render(self, graph, node_to_id, relay_param): - """render nodes and edges to the graph. + def _add_nodes(self, graph, node_to_id, relay_param): + """add nodes and to the graph. Parameters ---------- - graph : class plotter.Graph - - node_to_id : Dict[relay.expr, int] - - relay_param : Dict[string, NDarray] + graph : `plotter.Graph` + node_to_id : Dict[relay.expr, str | int] + relay_param : Dict[str, tvm.runtime.NDarray] """ for node in node_to_id: - graph_info, edge_info = self._ne_generator.get_node_edges(node, relay_param, node_to_id) - if graph_info: - graph.node(*graph_info) + node_info, edge_info = self._ne_generator.get_node_edges(node, relay_param, node_to_id) + if node_info is not None: + graph.node(node_info.node_id, node_info.node_type, node_info.node_detail) for edge in edge_info: - graph.edge(*edge) + graph.edge(edge.start, edge.end) def render(self, filename: str = None) -> None: self._plotter.render(filename=filename) @@ -122,8 +115,7 @@ def get_plotter_and_generator(backend): raise ValueError(f"Unknown plotter backend {backend}") # Plotter modules are Lazy-imported to avoid they become a requirement of TVM. - # Basically we want to keep them as optional -- users can choose which plotter they want, - # and just install libraries required by that plotter. + # Basically we want to keep them optional. Users can choose plotters they want to install. if backend == PlotterBackend.BOKEH: # pylint: disable=import-outside-toplevel from ._bokeh import ( diff --git a/python/tvm/contrib/relay_viz/_bokeh.py b/python/tvm/contrib/relay_viz/_bokeh.py index c65799294d20..58716510c91a 100644 --- a/python/tvm/contrib/relay_viz/_bokeh.py +++ b/python/tvm/contrib/relay_viz/_bokeh.py @@ -16,13 +16,25 @@ # under the License. """Bokeh backend for Relay IR Visualizer.""" import html -import logging import functools +import logging + +_LOGGER = logging.getLogger(__name__) import numpy as np -import pydot -from bokeh.io import output_file, save +try: + import pydot +except ImportError: + _LOGGER.critical("pydot library is required. You might want to run pip install pydot.") + raise + +try: + from bokeh.io import output_file, save +except ImportError: + _LOGGER.critical("bokeh library is required. You might want to run pip install bokeh.") + raise + from bokeh.models import ( ColumnDataSource, CustomJS, @@ -51,11 +63,8 @@ from .node_edge_gen import DefaultNodeEdgeGenerator -_LOGGER = logging.getLogger(__name__) - - -class BokehNodeEdgeGenerator(DefaultNodeEdgeGenerator): - pass +# Use default node/edge generator +BokehNodeEdgeGenerator = DefaultNodeEdgeGenerator class NodeDescriptor: @@ -199,7 +208,7 @@ def _get_node_attr(self, node_name, attr_name, default_val): class BokehGraph(Graph): - """Use Bokeh library to plot Relay IR.""" + """Use Bokeh library to plot networks, i.e. nodes and edges.""" def __init__(self): self._pydot_digraph = pydot.Dot(graph_type="digraph") @@ -451,7 +460,7 @@ def _get_graph_name(plot): class BokehPlotter(Plotter): - """Use Bokeh library to plot Relay IR.""" + """Render and save collections of class BokehGraph.""" def __init__(self): self._name_to_graph = {} diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/_terminal.py index 4c6e5c0ca4ac..3b7bd1b30f50 100644 --- a/python/tvm/contrib/relay_viz/_terminal.py +++ b/python/tvm/contrib/relay_viz/_terminal.py @@ -17,64 +17,113 @@ """Visualize Relay IR in AST text-form""" from collections import deque +from typing import ( + Dict, + Union, + Tuple, + List, +) + +import tvm +from tvm import relay from .plotter import ( Plotter, Graph, ) -from .node_edge_gen import DefaultNodeEdgeGenerator +from .node_edge_gen import ( + Node, + Edge, + NodeEdgeGenerator, + DefaultNodeEdgeGenerator, +) -class TermNodeEdgeGenerator(DefaultNodeEdgeGenerator): +class TermNodeEdgeGenerator(NodeEdgeGenerator): """Terminal nodes and edges generator.""" - def call_node(self, node, relay_param, node_to_id): + def __init__(self): + self._default_ne_gen = DefaultNodeEdgeGenerator() + + def get_node_edges( + self, + node: relay.expr.ExprWithOp, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[Union[Node, None], List[Edge]]: + """Generate node and edges consumed by TermGraph interfaces""" + if isinstance(node, relay.Call): + return self._call_node(node, node_to_id) + + if isinstance(node, relay.Let): + return self._let_node(node, node_to_id) + + if isinstance(node, relay.GlobalVar): + return self._global_var_node(node, node_to_id) + + if isinstance(node, relay.If): + return self._if_node(node, node_to_id) + + if isinstance(node, tvm.ir.Op): + return self._op_node(node, node_to_id) + + if isinstance(node, relay.Function): + return self._function_node(node, node_to_id) + + # otherwise, delegate to the default implementation + return self._default_ne_gen.get_node_edges(node, relay_param, node_to_id) + + def _call_node(self, node, node_to_id): node_id = node_to_id[node] - graph_info = [node_id, "Call", ""] - edge_info = [[node_to_id[node.op], node_id]] - args = [node_to_id[arg] for arg in node.args] - for arg in args: - edge_info.append([arg, node_id]) - return graph_info, edge_info - - def let_node(self, node, relay_param, node_to_id): + node_info = Node(node_id, "Call", "") + edge_info = [Edge(node_to_id[node.op], node_id)] + for arg in node.args: + arg_nid = node_to_id[arg] + edge_info.append(Edge(arg_nid, node_id)) + return node_info, edge_info + + def _let_node(self, node, node_to_id): node_id = node_to_id[node] - graph_info = [node_id, "Let", "(var, val, body)"] - edge_info = [[node_to_id[node.var], node_id]] - edge_info.append([node_to_id[node.value], node_id]) - edge_info.append([node_to_id[node.body], node_id]) - return graph_info, edge_info + node_info = Node(node_id, "Let", "(var, val, body)") + edge_info = [ + Edge(node_to_id[node.var], node_id), + Edge(node_to_id[node.value], node_id), + Edge(node_to_id[node.body], node_id), + ] + return node_info, edge_info - def global_var_node(self, node, relay_param, node_to_id): + def _global_var_node(self, node, node_to_id): node_id = node_to_id[node] - graph_info = [node_id, "GlobalVar", node.name_hint] + node_info = Node(node_id, "GlobalVar", node.name_hint) edge_info = [] - return graph_info, edge_info + return node_info, edge_info - def if_node(self, node, relay_param, node_to_id): + def _if_node(self, node, node_to_id): node_id = node_to_id[node] - graph_info = [node_id, "If", "(cond, true, false)"] - edge_info = [[node_to_id[node.cond], node_id]] - edge_info.append([node_to_id[node.true_branch], node_id]) - edge_info.append([node_to_id[node.false_branch], node_id]) - return graph_info, edge_info + node_info = Node(node_id, "If", "(cond, true, false)") + edge_info = [ + Edge(node_to_id[node.cond], node_id), + Edge(node_to_id[node.true_branch], node_id), + Edge(node_to_id[node.false_branch], node_id), + ] + return node_info, edge_info - def op_node(self, node, relay_param, node_to_id): + def _op_node(self, node, node_to_id): node_id = node_to_id[node] op_name = node.name - graph_info = [node_id, op_name, ""] + node_info = Node(node_id, op_name, "") edge_info = [] - return graph_info, edge_info + return node_info, edge_info - def function_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + def _function_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = [node_id, "Func", str(node.params)] - edge_info = [[node_to_id[node.body], node_id]] + node_info = Node(node_id, "Func", str(node.params)) + edge_info = [Edge(node_to_id[node.body], node_id)] return node_info, edge_info -class Node: +class TermNode: def __init__(self, node_type, other_info): self.type = node_type self.other_info = other_info.replace("\n", ", ") @@ -98,7 +147,7 @@ def node(self, node_id, node_type, node_detail): if node_id not in self._graph: self._graph[node_id] = [] - node = Node(node_type, node_detail) + node = TermNode(node_type, node_detail) self._id_to_node[node_id] = node def edge(self, id_start, id_end): diff --git a/python/tvm/contrib/relay_viz/node_edge_gen.py b/python/tvm/contrib/relay_viz/node_edge_gen.py index d635d5bbf82d..71ff13b3397e 100644 --- a/python/tvm/contrib/relay_viz/node_edge_gen.py +++ b/python/tvm/contrib/relay_viz/node_edge_gen.py @@ -20,6 +20,7 @@ Dict, Union, Tuple, + List, ) import tvm from tvm import relay @@ -27,8 +28,45 @@ UNKNOWN_TYPE = "unknown" +class Node: + """Node carry information used by `plotter.Graph` interface.""" + + def __init__(self, node_id: Union[int, str], node_type: str, node_detail: str): + self._node_id = node_id + self._node_type = node_type + self._node_detail = node_detail + + @property + def node_id(self) -> Union[int, str]: + return self._node_id + + @property + def node_type(self) -> str: + return self._node_type + + @property + def node_detail(self) -> str: + return self._node_detail + + +class Edge: + """Edge for `plotter.Graph` interface.""" + + def __init__(self, start_node: Union[int, str], end_node: Union[int, str]): + self._start_node = start_node + self._end_node = end_node + + @property + def start(self) -> Union[int, str]: + return self._start_node + + @property + def end(self) -> Union[int, str]: + return self._end_node + + class NodeEdgeGenerator(abc.ABC): - """Abstract class generating nodes and edgs for Graph interface.""" + """An interface class to generate nodes and edges information for Graph interfaces.""" @abc.abstractmethod def get_node_edges( @@ -36,14 +74,10 @@ def get_node_edges( node: relay.expr.ExprWithOp, relay_param: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], - ) -> Tuple[list, list]: - """Function return node and edges consumed by Graph interface - The returned tuple containing two lists, the first list match - the interface of Graph.node(), i.e. `node_id`, `node_type`, and `node_detail`. - The secon list is the form: - [(node_id_start0, node_id_end0), ...] - where the tuple `(node_id_start0, node_id_end0)` represent an edge from - `node_id_start0` to `node_id_end0`. + ) -> Tuple[Union[Node, None], List[Edge]]: + """Generate node and edges consumed by Graph interfaces + The returned tuple containing Node and a list of Edge instances. + Tuple[None, list[]] for null results. """ @@ -57,7 +91,12 @@ def __init__(self): self.render_rules = {} self.build_rules() - def var_node(self, node, relay_param, node_to_id): + def var_node( + self, + node: relay.expr.ExprWithOp, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[Union[Node, None], List[Edge]]: """Render rule for a relay var node""" node_id = node_to_id[node] name_hint = node.name_hint @@ -72,11 +111,16 @@ def var_node(self, node, relay_param, node_to_id): node_detail = "name_hint: {}\ntype_annotation: {}".format( name_hint, node.type_annotation ) - node_info = [node_id, node_type, node_detail] + node_info = Node(node_id, node_type, node_detail) edge_info = [] return node_info, edge_info - def function_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + def function_node( + self, + node: relay.expr.ExprWithOp, + _: Dict[str, tvm.runtime.NDArray], # relay_param + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[Union[Node, None], List[Edge]]: """Render rule for a relay function node""" node_details = [] name = "" @@ -87,11 +131,16 @@ def function_node(self, node, relay_param, node_to_id): # pylint: disable=unuse if "Composite" in func_attrs.keys(): name = func_attrs["Composite"] node_id = node_to_id[node] - node_info = [node_id, f"Func {name}", "\n".join(node_details)] - edge_info = [[node_to_id[node.body], node_id]] + node_info = Node(node_id, f"Func {name}", "\n".join(node_details)) + edge_info = [Edge(node_to_id[node.body], node_id)] return node_info, edge_info - def call_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + def call_node( + self, + node: relay.expr.ExprWithOp, + _: Dict[str, tvm.runtime.NDArray], # relay_param + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[Union[Node, None], List[Edge]]: """Render rule for a relay call node""" node_id = node_to_id[node] op_name = UNKNOWN_TYPE @@ -114,51 +163,58 @@ def call_node(self, node, relay_param, node_to_id): # pylint: disable=unused-ar else: op_name = str(type(node.op)).split(".")[-1].split("'")[0] - node_info = [node_id, f"Call {op_name}", "\n".join(node_detail)] + node_info = Node(node_id, f"Call {op_name}", "\n".join(node_detail)) args = [node_to_id[arg] for arg in node.args] - edge_info = [[arg, node_id] for arg in args] + edge_info = [Edge(arg, node_id) for arg in args] return node_info, edge_info - def let_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + def let_node( + self, + node: relay.expr.ExprWithOp, + _: Dict[str, tvm.runtime.NDArray], # relay_param + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[Union[Node, None], List[Edge]]: node_id = node_to_id[node] - node_info = [node_id, "Let", ""] - edge_info = [[node_to_id[node.value], node_id]] - edge_info.append([node_id, node_to_id[node.var]]) - return node_info, edge_info - - def global_var_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument - node_info = [] - edge_info = [] + node_info = Node(node_id, "Let", "") + edge_info = [Edge(node_to_id[node.value], node_id), Edge(node_id, node_to_id[node.var])] return node_info, edge_info - def if_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument - node_info = [] - edge_info = [] - return node_info, edge_info - - def tuple_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + def tuple_node( + self, + node: relay.expr.ExprWithOp, + _: Dict[str, tvm.runtime.NDArray], # relay_param + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[Union[Node, None], List[Edge]]: node_id = node_to_id[node] - node_info = [node_id, "Tuple", ""] - edge_info = [[node_to_id[field], node_id] for field in node.fields] + node_info = Node(node_id, "Tuple", "") + edge_info = [Edge(node_to_id[field], node_id) for field in node.fields] return node_info, edge_info - def tuple_get_item_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + def tuple_get_item_node( + self, + node: relay.expr.ExprWithOp, + _: Dict[str, tvm.runtime.NDArray], # relay_param + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[Union[Node, None], List[Edge]]: node_id = node_to_id[node] - node_info = [node_id, "TupleGetItem", "idx: {}".format(node.index)] - edge_info = [[node_to_id[node.tuple_value], node_id]] + node_info = Node(node_id, "TupleGetItem", "idx: {}".format(node.index)) + edge_info = [Edge(node_to_id[node.tuple_value], node_id)] return node_info, edge_info - def constant_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument + def constant_node( + self, + node: relay.expr.ExprWithOp, + _: Dict[str, tvm.runtime.NDArray], # relay_param + node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + ) -> Tuple[Union[Node, None], List[Edge]]: node_id = node_to_id[node] node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) - node_info = [node_id, "Const", "\n".join(node_detail)] + node_info = Node(node_id, "Const", "\n".join(node_detail)) edge_info = [] return node_info, edge_info - def op_node(self, node, relay_param, node_to_id): # pylint: disable=unused-argument - node_info = [] - edge_info = [] - return node_info, edge_info + def null(self, *_) -> Tuple[None, List[Edge]]: + return None, [] def build_rules(self): self.render_rules = { @@ -166,12 +222,12 @@ def build_rules(self): tvm.relay.expr.Call: self.call_node, tvm.relay.expr.Let: self.let_node, tvm.relay.expr.Var: self.var_node, - tvm.relay.expr.GlobalVar: self.global_var_node, - tvm.relay.expr.If: self.if_node, tvm.relay.expr.Tuple: self.tuple_node, tvm.relay.expr.TupleGetItem: self.tuple_get_item_node, tvm.relay.expr.Constant: self.constant_node, - tvm.ir.Op: self.op_node, + tvm.relay.expr.If: self.null, + tvm.relay.expr.GlobalVar: self.null, + tvm.ir.Op: self.null, } def get_node_edges( @@ -179,10 +235,12 @@ def get_node_edges( node: relay.expr.ExprWithOp, relay_param: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], - ) -> Tuple[list, list]: + ) -> Tuple[Union[Node, None], List[Edge]]: try: - graph_info, edge_info = self.render_rules[type(node)](node, relay_param, node_to_id) + node_info, edge_info = self.render_rules[type(node)](node, relay_param, node_to_id) except KeyError: - graph_info = [node_to_id[node], UNKNOWN_TYPE, f"failed to parse node: {type(node)}"] + node_info = Node( + node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}" + ) edge_info = [] - return graph_info, edge_info + return node_info, edge_info diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py index 419a41628f89..58ed2c02ebd9 100644 --- a/python/tvm/contrib/relay_viz/plotter.py +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -20,10 +20,7 @@ class Graph(abc.ABC): - """Abstract class for graph. - - Implement this interface for various graph libraries. - """ + """Abstract class for graph, which is composed of nodes and edges.""" @abc.abstractmethod def node(self, node_id: Union[int, str], node_type: str, node_detail: str) -> None: @@ -31,14 +28,9 @@ def node(self, node_id: Union[int, str], node_type: str, node_detail: str) -> No Parameters ---------- - node_id : Union[int, str] - Serve as the ID to the node. - - node_type : str - the type of the node. - - node_detail : str - the description of the node. + node_id : Union[int, str], Serve as the ID to the node. + node_type : str, the type of the node. + node_detail : str, the description of the node. """ @abc.abstractmethod @@ -47,19 +39,13 @@ def edge(self, id_start: Union[int, str], id_end: Union[int, str]) -> None: Parameters ---------- - id_start : Union[int, str] - the ID to the starting node. - - id_end : Union[int, str] - the ID to the ending node. + id_start : Union[int, str], the ID to the starting node. + id_end : Union[int, str], the ID to the ending node. """ class Plotter(abc.ABC): - """Abstract class for plotters. - - Implement this interface for various graph libraries. - """ + """Abstract class for plotters, rendering a collection of Graph interface.""" @abc.abstractmethod def create_graph(self, name: str) -> Graph: @@ -80,5 +66,5 @@ def render(self, filename: str) -> None: Parameters ---------- - filename : string + filename : string, see the definition of implemented class. """ diff --git a/tests/python/contrib/test_relay_viz.py b/tests/python/contrib/test_relay_viz.py new file mode 100644 index 000000000000..0ab393cf6d9c --- /dev/null +++ b/tests/python/contrib/test_relay_viz.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# from tvm import relay +# from tvm.contrib.relay_viz.node_edge_gen import DefaultNodeEdgeGenerator + + +def test_var(): + pass + + +def test_function(): + pass + + +def test_call(): + pass + + +def test_let(): + pass + + +def test_tuple(): + pass + + +def test_constant(): + pass + + +if __name__ == "__main__": + test_var() + test_function() + test_call() + test_let() + test_tuple() + test_constant() From f59a3fef57f5f74ddf1926546bbe17bb4e64e22d Mon Sep 17 00:00:00 2001 From: chiwwang Date: Sun, 17 Oct 2021 14:33:01 +0000 Subject: [PATCH 13/16] add testing for default parser. (TODO: tutorial) --- python/tvm/contrib/relay_viz/__init__.py | 2 +- python/tvm/contrib/relay_viz/_terminal.py | 4 +- python/tvm/contrib/relay_viz/node_edge_gen.py | 65 ++++++++----------- tests/python/contrib/test_relay_viz.py | 63 ++++++++++++++---- 4 files changed, 80 insertions(+), 54 deletions(-) diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index 658662264cf9..855adeb2e8ee 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -90,7 +90,7 @@ def _add_nodes(self, graph, node_to_id, relay_param): for node in node_to_id: node_info, edge_info = self._ne_generator.get_node_edges(node, relay_param, node_to_id) if node_info is not None: - graph.node(node_info.node_id, node_info.node_type, node_info.node_detail) + graph.node(node_info.identity, node_info.type_str, node_info.detail) for edge in edge_info: graph.edge(edge.start, edge.end) diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/_terminal.py index 3b7bd1b30f50..230bd88f5a52 100644 --- a/python/tvm/contrib/relay_viz/_terminal.py +++ b/python/tvm/contrib/relay_viz/_terminal.py @@ -48,9 +48,9 @@ def __init__(self): def get_node_edges( self, - node: relay.expr.ExprWithOp, + node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[Node, None], List[Edge]]: """Generate node and edges consumed by TermGraph interfaces""" if isinstance(node, relay.Call): diff --git a/python/tvm/contrib/relay_viz/node_edge_gen.py b/python/tvm/contrib/relay_viz/node_edge_gen.py index 71ff13b3397e..ffe54e741897 100644 --- a/python/tvm/contrib/relay_viz/node_edge_gen.py +++ b/python/tvm/contrib/relay_viz/node_edge_gen.py @@ -32,21 +32,21 @@ class Node: """Node carry information used by `plotter.Graph` interface.""" def __init__(self, node_id: Union[int, str], node_type: str, node_detail: str): - self._node_id = node_id - self._node_type = node_type - self._node_detail = node_detail + self._id = node_id + self._type = node_type + self._detail = node_detail @property - def node_id(self) -> Union[int, str]: - return self._node_id + def identity(self) -> Union[int, str]: + return self._id @property - def node_type(self) -> str: - return self._node_type + def type_str(self) -> str: + return self._type @property - def node_detail(self) -> str: - return self._node_detail + def detail(self) -> str: + return self._detail class Edge: @@ -71,9 +71,9 @@ class NodeEdgeGenerator(abc.ABC): @abc.abstractmethod def get_node_edges( self, - node: relay.expr.ExprWithOp, + node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[Node, None], List[Edge]]: """Generate node and edges consumed by Graph interfaces The returned tuple containing Node and a list of Edge instances. @@ -93,9 +93,9 @@ def __init__(self): def var_node( self, - node: relay.expr.ExprWithOp, + node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[Node, None], List[Edge]]: """Render rule for a relay var node""" node_id = node_to_id[node] @@ -117,9 +117,9 @@ def var_node( def function_node( self, - node: relay.expr.ExprWithOp, + node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[Node, None], List[Edge]]: """Render rule for a relay function node""" node_details = [] @@ -137,9 +137,9 @@ def function_node( def call_node( self, - node: relay.expr.ExprWithOp, + node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[Node, None], List[Edge]]: """Render rule for a relay call node""" node_id = node_to_id[node] @@ -168,22 +168,11 @@ def call_node( edge_info = [Edge(arg, node_id) for arg in args] return node_info, edge_info - def let_node( - self, - node: relay.expr.ExprWithOp, - _: Dict[str, tvm.runtime.NDArray], # relay_param - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: - node_id = node_to_id[node] - node_info = Node(node_id, "Let", "") - edge_info = [Edge(node_to_id[node.value], node_id), Edge(node_id, node_to_id[node.var])] - return node_info, edge_info - def tuple_node( self, - node: relay.expr.ExprWithOp, + node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[Node, None], List[Edge]]: node_id = node_to_id[node] node_info = Node(node_id, "Tuple", "") @@ -192,9 +181,9 @@ def tuple_node( def tuple_get_item_node( self, - node: relay.expr.ExprWithOp, + node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[Node, None], List[Edge]]: node_id = node_to_id[node] node_info = Node(node_id, "TupleGetItem", "idx: {}".format(node.index)) @@ -203,13 +192,13 @@ def tuple_get_item_node( def constant_node( self, - node: relay.expr.ExprWithOp, + node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[Node, None], List[Edge]]: node_id = node_to_id[node] node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) - node_info = Node(node_id, "Const", "\n".join(node_detail)) + node_info = Node(node_id, "Const", node_detail) edge_info = [] return node_info, edge_info @@ -220,21 +209,19 @@ def build_rules(self): self.render_rules = { tvm.relay.Function: self.function_node, tvm.relay.expr.Call: self.call_node, - tvm.relay.expr.Let: self.let_node, tvm.relay.expr.Var: self.var_node, tvm.relay.expr.Tuple: self.tuple_node, tvm.relay.expr.TupleGetItem: self.tuple_get_item_node, tvm.relay.expr.Constant: self.constant_node, - tvm.relay.expr.If: self.null, tvm.relay.expr.GlobalVar: self.null, tvm.ir.Op: self.null, } def get_node_edges( self, - node: relay.expr.ExprWithOp, + node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], - node_to_id: Dict[relay.expr.ExprWithOp, Union[int, str]], + node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[Node, None], List[Edge]]: try: node_info, edge_info = self.render_rules[type(node)](node, relay_param, node_to_id) diff --git a/tests/python/contrib/test_relay_viz.py b/tests/python/contrib/test_relay_viz.py index 0ab393cf6d9c..3e5b73eae0ec 100644 --- a/tests/python/contrib/test_relay_viz.py +++ b/tests/python/contrib/test_relay_viz.py @@ -15,38 +15,77 @@ # specific language governing permissions and limitations # under the License. -# from tvm import relay -# from tvm.contrib.relay_viz.node_edge_gen import DefaultNodeEdgeGenerator +import tvm +from tvm import relay +from tvm.contrib.relay_viz import node_edge_gen +from tvm.contrib.relay_viz.node_edge_gen import DefaultNodeEdgeGenerator + +# the testing focus on that DefaultNodeEdgeGenerator can +# parse Relay IR properly. def test_var(): - pass + ne_gen = DefaultNodeEdgeGenerator() + shape = (10, 10) + input_var = relay.var("input", shape=shape) + node, edges = ne_gen.get_node_edges(input_var, {}, {input_var: 1}) + assert node.identity == 1, "node_id should be 1." + assert "input" in node.detail, "detail should have name_hint." + assert str(shape) in node.detail, "detail should have shape." + assert len(edges) == 0, "relay.var doesn't cause any edge." def test_function(): - pass + ne_gen = DefaultNodeEdgeGenerator() + input_var = relay.var("input") + bias_var = relay.var("bias") + add_bias = relay.add(input_var, bias_var) + func = relay.Function([input_var, bias_var], add_bias) + node, edges = ne_gen.get_node_edges(func, {}, {func: 99, add_bias: 199}) + assert node.identity == 99, "node_id should be 99." + assert edges[0].start == 199, "edge.start should be node 199." + assert edges[0].end == 99, "edge.end should be node 99." def test_call(): - pass - - -def test_let(): - pass + ne_gen = DefaultNodeEdgeGenerator() + input_var = relay.var("input") + bias_var = relay.var("bias") + add_bias = relay.add(input_var, bias_var) + node, edges = ne_gen.get_node_edges(add_bias, {}, {add_bias: 1, input_var: 0, bias_var: 2}) + assert "add" in node.type_str, "node_type shuold contain op_name." + assert len(edges) == 2, "the length of edges should be 2, from two var to relay.add." def test_tuple(): - pass + ne_gen = DefaultNodeEdgeGenerator() + elemt0_var = relay.var("elemt0") + elemt1_var = relay.var("elemt1") + tup = relay.Tuple([elemt0_var, elemt1_var]) + node, edges = ne_gen.get_node_edges(tup, {}, {tup: 123, elemt0_var: 0, elemt1_var: 1}) + assert node.identity == 123, "node_id should be 123." + assert len(edges) == 2, "the length of edges should be 2, from two relay.var to tuple." + assert edges[0].start == 0 and edges[0].end == 123, "edges[0] should be 0 -> 123." + assert edges[1].start == 1 and edges[1].end == 123, "edges[1] should be 1 -> 123." def test_constant(): - pass + ne_gen = DefaultNodeEdgeGenerator() + arr = tvm.nd.array(10) + const = relay.Constant(arr) + node, edges = ne_gen.get_node_edges(const, {}, {const: 999}) + assert node.identity == 999, "node_id should be 999." + assert len(edges) == 0, "constant should not cause edges." + + arr = tvm.nd.array([[10, 11]]) + const = relay.Constant(arr) + node, edges = ne_gen.get_node_edges(const, {}, {const: 111}) + assert str(const.data.shape) in node.detail, "node_detail should contain shape." if __name__ == "__main__": test_var() test_function() test_call() - test_let() test_tuple() test_constant() From 9955d7c833d91e5d2f63dc326baedcdfcc95e316 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Mon, 25 Oct 2021 11:49:06 +0000 Subject: [PATCH 14/16] tutorial and doc --- docs/reference/api/python/contrib.rst | 10 ++ .../how_to/work_with_relay/using_relay_viz.py | 162 ++++++++++++++++++ python/tvm/contrib/relay_viz/README.md | 10 +- python/tvm/contrib/relay_viz/__init__.py | 36 ++-- .../contrib/relay_viz/{_bokeh.py => bokeh.py} | 19 +- python/tvm/contrib/relay_viz/node_edge_gen.py | 151 ++++++++-------- python/tvm/contrib/relay_viz/plotter.py | 28 ++- .../relay_viz/{_terminal.py => terminal.py} | 38 ++-- 8 files changed, 326 insertions(+), 128 deletions(-) create mode 100644 gallery/how_to/work_with_relay/using_relay_viz.py rename python/tvm/contrib/relay_viz/{_bokeh.py => bokeh.py} (98%) rename python/tvm/contrib/relay_viz/{_terminal.py => terminal.py} (85%) diff --git a/docs/reference/api/python/contrib.rst b/docs/reference/api/python/contrib.rst index 0eb3024c2d08..6cef11896753 100644 --- a/docs/reference/api/python/contrib.rst +++ b/docs/reference/api/python/contrib.rst @@ -92,6 +92,16 @@ tvm.contrib.random .. automodule:: tvm.contrib.random :members: +tvm.contrib.relay_viz +~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.relay_viz + :members: RelayVisualizer +.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.BOKEH +.. autoattribute:: tvm.contrib.relay_viz.PlotterBackend.TERMINAL +.. automodule:: tvm.contrib.relay_viz.plotter + :members: +.. automodule:: tvm.contrib.relay_viz.node_edge_gen + :members: tvm.contrib.rocblas ~~~~~~~~~~~~~~~~~~~ diff --git a/gallery/how_to/work_with_relay/using_relay_viz.py b/gallery/how_to/work_with_relay/using_relay_viz.py new file mode 100644 index 000000000000..827be8ba3c0f --- /dev/null +++ b/gallery/how_to/work_with_relay/using_relay_viz.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=line-too-long +""" +Use Relay Visualizer to Visualize Relay +============================================================ +**Author**: `Chi-Wei Wang `_ + +This is an introduction about using Relay Visualizer to visualize a Relay IR module. + +Relay IR module can contain lots of operations. Although individual +operations are usually easy to understand, they become complicated quickly +when you put them together. It could get even worse while optimiztion passes +come into play. + +This utility abstracts an IR module as graphs containing nodes and edges. +It provides a default parser to interpret an IR modules with nodes and edges. +Two renderer backends are also implemented to visualize them. + +Here we use a backend showing Relay IR module in the terminal for illustation. +It is a much more lightweight compared to another backend using `Bokeh `_. +See ``/python/tvm/contrib/relay_viz/README.md``. +Also we will introduce how to implement customized parsers and renderers through +some interfaces classes. +""" +from typing import ( + Dict, + Union, + Tuple, + List, +) +import tvm +from tvm import relay +from tvm.contrib import relay_viz +from tvm.contrib.relay_viz.node_edge_gen import ( + VizNode, + VizEdge, + NodeEdgeGenerator, +) +from tvm.contrib.relay_viz.terminal import ( + TermNodeEdgeGenerator, + TermGraph, + TermPlotter, +) + +###################################################################### +# Define a Relay IR Module with multiple GlobalVar +# ------------------------------------------------ +# Let's build an example Relay IR Module containing multiple ``GlobalVar``. +# We define an ``add`` function and call it in the main function. +data = relay.var("data") +bias = relay.var("bias") +add_op = relay.add(data, bias) +add_func = relay.Function([data, bias], add_op) +add_gvar = relay.GlobalVar("AddFunc") + +input0 = relay.var("input0") +input1 = relay.var("input1") +input2 = relay.var("input2") +add_01 = relay.Call(add_gvar, [input0, input1]) +add_012 = relay.Call(add_gvar, [input2, add_01]) +main_func = relay.Function([input0, input1, input2], add_012) +main_gvar = relay.GlobalVar("main") + +mod = tvm.IRModule({main_gvar: main_func, add_gvar: add_func}) + +###################################################################### +# Render the graph with Relay Visualizer on the terminal +# ------------------------------------------------------ +# The terminal backend can show a Relay IR module as in a text-form +# similar to `clang ast-dump `_. +# We should see ``main`` and ``AddFunc`` function. ``AddFunc`` is called twice in the ``main`` function. +viz = relay_viz.RelayVisualizer(mod, {}, relay_viz.PlotterBackend.TERMINAL) +viz.render() + +###################################################################### +# Customize Parser for Interested Relay Types +# ------------------------------------------- +# Sometimes the information shown by the default implementation is not suitable +# for a specific usage. It is possible to provide your own parser and renderer. +# Here demostrate how to customize parsers for ``relay.var``. +# We need to implement :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` interface. +class YourAwesomeParser(NodeEdgeGenerator): + def __init__(self): + self._org_parser = TermNodeEdgeGenerator() + + def get_node_edges( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + + if isinstance(node, relay.Var): + node = VizNode(node_to_id[node], "AwesomeVar", f"name_hint {node.name_hint}") + # no edge is introduced. So return an empty list. + ret = (node, []) + return ret + + # delegate other types to the original parser. + return self._org_parser.get_node_edges(node, relay_param, node_to_id) + + +###################################################################### +# Pass a tuple of :py:class:`tvm.contrib.relay_viz.plotter.Plotter` and +# :py:class:`tvm.contrib.relay_viz.node_edge_gen.NodeEdgeGenerator` instances +# to ``RelayVisualizer``. Here we re-use the Plotter interface implemented inside +# ``relay_viz.terminal`` module. +viz = relay_viz.RelayVisualizer(mod, {}, (TermPlotter(), YourAwesomeParser())) +viz.render() + +###################################################################### +# More Customization around Graph and Plotter +# ------------------------------------------- +# All ``RelayVisualizer`` care about are interfaces defined in ``plotter.py`` and +# ``node_edge_generator.py``. We can override them to introduce custimized logics. +# For example, if we want the Graph to duplicate above ``AwesomeVar`` while it is added, +# we can override ``relay_viz.terminal.TermGraph.node``. +class AwesomeGraph(TermGraph): + def node(self, node_id, node_type, node_detail): + # add original node first + super().node(node_id, node_type, node_detail) + if node_type == "AwesomeVar": + duplicated_id = f"duplciated_{node_id}" + duplicated_type = "double AwesomeVar" + super().node(duplicated_id, duplicated_type, "") + # connect the duplicated var to the original one + super().edge(duplicated_id, node_id) + + +# override TermPlotter to return `AwesomeGraph` instead +class AwesomePlotter(TermPlotter): + def create_graph(self, name): + self._name_to_graph[name] = AwesomeGraph(name) + return self._name_to_graph[name] + + +viz = relay_viz.RelayVisualizer(mod, {}, (AwesomePlotter(), YourAwesomeParser())) +viz.render() + +###################################################################### +# Summary +# ------- +# This tutorial demonstrates the usage of Relay Visualizer. +# The class :py:class:`tvm.contrib.relay_viz.RelayVisualizer` is composed of interfaces +# defined in ``plotter.py`` and ``node_edge_generator.py``. It provides a single entry point +# while keeping the possibility of implementing customized visualizer in various cases. +# diff --git a/python/tvm/contrib/relay_viz/README.md b/python/tvm/contrib/relay_viz/README.md index bb6e964e8f07..c1d7d2245249 100644 --- a/python/tvm/contrib/relay_viz/README.md +++ b/python/tvm/contrib/relay_viz/README.md @@ -28,6 +28,10 @@ This tool target to visualize Relay IR. ## Requirement +### Terminal Backend +1. TVM + +### Bokeh Backend 1. TVM 2. graphviz 2. pydot @@ -66,9 +70,9 @@ This utility is composed of two parts: `node_edge_gen.py` and `plotter.py`. `plotter.py` define interfaces of `Graph` and `Plotter`. `Plotter` is responsible to render a collection of `Graph`. -`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes/edges consumed by `Graph`. Further, this python module also provide a default implementation for common relay types. +`node_edge_gen.py` define interfaces of converting Relay IR modules to nodes and edges. Further, this python module provide a default implementation for common relay types. If customization is wanted for a certain relay type, we can implement the `NodeEdgeGenerator` interface, handling that relay type accordingly, and delegate other types to the default implementation. See `_terminal.py` for an example usage. -These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes/edges to `Graph`. -Then, it render the plot by `Plotter.render()`. +These two interfaces are glued by the top level class `RelayVisualizer`, which passes a relay module to `NodeEdgeGenerator` and add nodes and edges to `Graph`. +Then, it render the plot by calling `Plotter.render()`. diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index 855adeb2e8ee..1015f4781f6f 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -28,14 +28,26 @@ class PlotterBackend(Enum): - """Enumeration for available plotters.""" + """Enumeration for available plotter backends.""" BOKEH = "bokeh" TERMINAL = "terminal" class RelayVisualizer: - """Relay IR Visualizer""" + """Relay IR Visualizer + + Parameters + ---------- + relay_mod : tvm.IRModule + Relay IR module. + relay_param: None | Dict[str, tvm.runtime.NDArray] + Relay parameter dictionary. Default `None`. + backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator] + The backend used to render graphs. It can be a tuple of an implemented Plotter instance and + NodeEdgeGenerator instance to introduce customized parsing and visualization logics. + Default ``PlotterBackend.TERMINAL``. + """ def __init__( self, @@ -43,14 +55,6 @@ def __init__( relay_param: Union[None, Dict[str, tvm.runtime.NDArray]] = None, backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] = PlotterBackend.TERMINAL, ): - """Visualize Relay IR. - - Parameters - ---------- - relay_mod : tvm.IRModule, Relay IR module - relay_param: None | Dict[str, tvm.runtime.NDArray], Relay parameter dictionary. Default `None`. - backend: PlotterBackend | Tuple[Plotter, NodeEdgeGenerator], Default `PlotterBackend.TERMINAL`. - """ self._plotter, self._ne_generator = get_plotter_and_generator(backend) self._relay_param = relay_param if relay_param is not None else {} @@ -83,8 +87,10 @@ def _add_nodes(self, graph, node_to_id, relay_param): Parameters ---------- - graph : `plotter.Graph` + graph : plotter.Graph + node_to_id : Dict[relay.expr, str | int] + relay_param : Dict[str, tvm.runtime.NDarray] """ for node in node_to_id: @@ -102,11 +108,11 @@ def get_plotter_and_generator(backend): """Specify the Plottor and its NodeEdgeGenerator""" if isinstance(backend, (tuple, list)) and len(backend) == 2: if not isinstance(backend[0], Plotter): - raise ValueError(f"First element of backend should be derived from {type(Plotter)}") + raise ValueError(f"First element should be an instance derived from {type(Plotter)}") if not isinstance(backend[1], NodeEdgeGenerator): raise ValueError( - f"Second element of backend should be derived from {type(NodeEdgeGenerator)}" + f"Second element should be an instance derived from {type(NodeEdgeGenerator)}" ) return backend @@ -118,7 +124,7 @@ def get_plotter_and_generator(backend): # Basically we want to keep them optional. Users can choose plotters they want to install. if backend == PlotterBackend.BOKEH: # pylint: disable=import-outside-toplevel - from ._bokeh import ( + from .bokeh import ( BokehPlotter, BokehNodeEdgeGenerator, ) @@ -127,7 +133,7 @@ def get_plotter_and_generator(backend): ne_generator = BokehNodeEdgeGenerator() elif backend == PlotterBackend.TERMINAL: # pylint: disable=import-outside-toplevel - from ._terminal import ( + from .terminal import ( TermPlotter, TermNodeEdgeGenerator, ) diff --git a/python/tvm/contrib/relay_viz/_bokeh.py b/python/tvm/contrib/relay_viz/bokeh.py similarity index 98% rename from python/tvm/contrib/relay_viz/_bokeh.py rename to python/tvm/contrib/relay_viz/bokeh.py index 58716510c91a..6ea82188463e 100644 --- a/python/tvm/contrib/relay_viz/_bokeh.py +++ b/python/tvm/contrib/relay_viz/bokeh.py @@ -19,22 +19,9 @@ import functools import logging -_LOGGER = logging.getLogger(__name__) - import numpy as np - -try: - import pydot -except ImportError: - _LOGGER.critical("pydot library is required. You might want to run pip install pydot.") - raise - -try: - from bokeh.io import output_file, save -except ImportError: - _LOGGER.critical("bokeh library is required. You might want to run pip install bokeh.") - raise - +import pydot +from bokeh.io import output_file, save from bokeh.models import ( ColumnDataSource, CustomJS, @@ -63,6 +50,8 @@ from .node_edge_gen import DefaultNodeEdgeGenerator +_LOGGER = logging.getLogger(__name__) + # Use default node/edge generator BokehNodeEdgeGenerator = DefaultNodeEdgeGenerator diff --git a/python/tvm/contrib/relay_viz/node_edge_gen.py b/python/tvm/contrib/relay_viz/node_edge_gen.py index ffe54e741897..aaeea1d9ffd9 100644 --- a/python/tvm/contrib/relay_viz/node_edge_gen.py +++ b/python/tvm/contrib/relay_viz/node_edge_gen.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""NodeEdgeGenerator interface""" +"""NodeEdgeGenerator interface for :py:class:`tvm.contrib.relay_viz.plotter.Graph`.""" import abc from typing import ( Dict, @@ -28,7 +28,7 @@ UNKNOWN_TYPE = "unknown" -class Node: +class VizNode: """Node carry information used by `plotter.Graph` interface.""" def __init__(self, node_id: Union[int, str], node_type: str, node_detail: str): @@ -49,8 +49,8 @@ def detail(self) -> str: return self._detail -class Edge: - """Edge for `plotter.Graph` interface.""" +class VizEdge: + """Edges for `plotter.Graph` interface.""" def __init__(self, start_node: Union[int, str], end_node: Union[int, str]): self._start_node = start_node @@ -74,10 +74,29 @@ def get_node_edges( node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: - """Generate node and edges consumed by Graph interfaces - The returned tuple containing Node and a list of Edge instances. - Tuple[None, list[]] for null results. + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + """Generate node and edges consumed by Graph interfaces. + + Parameters + ---------- + node : relay.Expr + relay.Expr which will be parsed and generate a node and edges. + + relay_param: Dict[str, tvm.runtime.NDArray] + relay parameters dictionary. + + node_to_id : Dict[relay.Expr, Union[int, str]] + a mapping from relay.Expr to node id which should be unique. + + Returns + ------- + rv1 : Union[VizNode, None] + VizNode represent the relay.Expr. If the relay.Expr is not intended to introduce a node + to the graph, return None. + + rv2 : List[VizEdge] + a list of VizEdge to describe the connectivity of the relay.Expr. + Can be empty list to indicate no connectivity. """ @@ -88,59 +107,72 @@ class DefaultNodeEdgeGenerator(NodeEdgeGenerator): """ def __init__(self): - self.render_rules = {} - self.build_rules() + self._render_rules = {} + self._build_rules() - def var_node( + def get_node_edges( self, node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: + try: + node_info, edge_info = self._render_rules[type(node)](node, relay_param, node_to_id) + except KeyError: + node_info = VizNode( + node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}" + ) + edge_info = [] + return node_info, edge_info + + def _var_node( + self, + node: relay.Expr, + relay_param: Dict[str, tvm.runtime.NDArray], + node_to_id: Dict[relay.Expr, Union[int, str]], + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Render rule for a relay var node""" node_id = node_to_id[node] name_hint = node.name_hint - node_detail = "" + node_detail = f"name_hint: {name_hint}" node_type = "Var(Param)" if name_hint in relay_param else "Var(Input)" if node.type_annotation is not None: if hasattr(node.type_annotation, "shape"): shape = tuple(map(int, node.type_annotation.shape)) dtype = node.type_annotation.dtype - node_detail = "name_hint: {}\nshape: {}\ndtype: {}".format(name_hint, shape, dtype) + node_detail = f"name_hint: {name_hint}\nshape: {shape}\ndtype: {dtype}" else: - node_detail = "name_hint: {}\ntype_annotation: {}".format( - name_hint, node.type_annotation - ) - node_info = Node(node_id, node_type, node_detail) + node_detail = f"name_hint: {name_hint}\ntype_annotation: {node.type_annotation}" + node_info = VizNode(node_id, node_type, node_detail) edge_info = [] return node_info, edge_info - def function_node( + def _function_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Render rule for a relay function node""" node_details = [] name = "" func_attrs = node.attrs if func_attrs: - node_details = ["{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()] + node_details = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] # "Composite" might from relay.transform.MergeComposite if "Composite" in func_attrs.keys(): name = func_attrs["Composite"] node_id = node_to_id[node] - node_info = Node(node_id, f"Func {name}", "\n".join(node_details)) - edge_info = [Edge(node_to_id[node.body], node_id)] + node_info = VizNode(node_id, f"Func {name}", "\n".join(node_details)) + edge_info = [VizEdge(node_to_id[node.body], node_id)] return node_info, edge_info - def call_node( + def _call_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Render rule for a relay call node""" node_id = node_to_id[node] op_name = UNKNOWN_TYPE @@ -148,12 +180,12 @@ def call_node( if isinstance(node.op, tvm.ir.Op): op_name = node.op.name if node.attrs: - node_detail = ["{}: {}".format(k, node.attrs.get_str(k)) for k in node.attrs.keys()] + node_detail = [f"{k}: {node.attrs.get_str(k)}" for k in node.attrs.keys()] elif isinstance(node.op, relay.Function): func_attrs = node.op.attrs op_name = "Anonymous Func" if func_attrs: - node_detail = ["{}: {}".format(k, func_attrs.get_str(k)) for k in func_attrs.keys()] + node_detail = [f"{k}: {func_attrs.get_str(k)}" for k in func_attrs.keys()] # "Composite" might from relay.transform.MergeComposite if "Composite" in func_attrs.keys(): op_name = func_attrs["Composite"] @@ -163,71 +195,56 @@ def call_node( else: op_name = str(type(node.op)).split(".")[-1].split("'")[0] - node_info = Node(node_id, f"Call {op_name}", "\n".join(node_detail)) + node_info = VizNode(node_id, f"Call {op_name}", "\n".join(node_detail)) args = [node_to_id[arg] for arg in node.args] - edge_info = [Edge(arg, node_id) for arg in args] + edge_info = [VizEdge(arg, node_id) for arg in args] return node_info, edge_info - def tuple_node( + def _tuple_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: node_id = node_to_id[node] - node_info = Node(node_id, "Tuple", "") - edge_info = [Edge(node_to_id[field], node_id) for field in node.fields] + node_info = VizNode(node_id, "Tuple", "") + edge_info = [VizEdge(node_to_id[field], node_id) for field in node.fields] return node_info, edge_info - def tuple_get_item_node( + def _tuple_get_item_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: node_id = node_to_id[node] - node_info = Node(node_id, "TupleGetItem", "idx: {}".format(node.index)) - edge_info = [Edge(node_to_id[node.tuple_value], node_id)] + node_info = VizNode(node_id, f"TupleGetItem", "idx: {node.index}") + edge_info = [VizEdge(node_to_id[node.tuple_value], node_id)] return node_info, edge_info - def constant_node( + def _constant_node( self, node: relay.Expr, _: Dict[str, tvm.runtime.NDArray], # relay_param node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: node_id = node_to_id[node] - node_detail = "shape: {}, dtype: {}".format(node.data.shape, node.data.dtype) - node_info = Node(node_id, "Const", node_detail) + node_detail = f"shape: {node.data.shape}, dtype: {node.data.dtype}" + node_info = VizNode(node_id, "Const", node_detail) edge_info = [] return node_info, edge_info - def null(self, *_) -> Tuple[None, List[Edge]]: + def _null(self, *_) -> Tuple[None, List[VizEdge]]: return None, [] - def build_rules(self): - self.render_rules = { - tvm.relay.Function: self.function_node, - tvm.relay.expr.Call: self.call_node, - tvm.relay.expr.Var: self.var_node, - tvm.relay.expr.Tuple: self.tuple_node, - tvm.relay.expr.TupleGetItem: self.tuple_get_item_node, - tvm.relay.expr.Constant: self.constant_node, - tvm.relay.expr.GlobalVar: self.null, - tvm.ir.Op: self.null, + def _build_rules(self): + self._render_rules = { + tvm.relay.Function: self._function_node, + tvm.relay.expr.Call: self._call_node, + tvm.relay.expr.Var: self._var_node, + tvm.relay.expr.Tuple: self._tuple_node, + tvm.relay.expr.TupleGetItem: self._tuple_get_item_node, + tvm.relay.expr.Constant: self._constant_node, + tvm.relay.expr.GlobalVar: self._null, + tvm.ir.Op: self._null, } - - def get_node_edges( - self, - node: relay.Expr, - relay_param: Dict[str, tvm.runtime.NDArray], - node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: - try: - node_info, edge_info = self.render_rules[type(node)](node, relay_param, node_to_id) - except KeyError: - node_info = Node( - node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}" - ) - edge_info = [] - return node_info, edge_info diff --git a/python/tvm/contrib/relay_viz/plotter.py b/python/tvm/contrib/relay_viz/plotter.py index 58ed2c02ebd9..de8c24c39a40 100644 --- a/python/tvm/contrib/relay_viz/plotter.py +++ b/python/tvm/contrib/relay_viz/plotter.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Abstract class for plotters.""" +"""Abstract class used by :py:class:`tvm.contrib.relay_viz.RelayVisualizer`.""" import abc from typing import Union @@ -28,9 +28,14 @@ def node(self, node_id: Union[int, str], node_type: str, node_detail: str) -> No Parameters ---------- - node_id : Union[int, str], Serve as the ID to the node. - node_type : str, the type of the node. - node_detail : str, the description of the node. + node_id : Union[int, str] + Serve as the ID to the node. + + node_type : str + the type of the node. + + node_detail : str + the description of the node. """ @abc.abstractmethod @@ -39,8 +44,11 @@ def edge(self, id_start: Union[int, str], id_end: Union[int, str]) -> None: Parameters ---------- - id_start : Union[int, str], the ID to the starting node. - id_end : Union[int, str], the ID to the ending node. + id_start : Union[int, str] + the ID to the starting node. + + id_end : Union[int, str] + the ID to the ending node. """ @@ -53,11 +61,12 @@ def create_graph(self, name: str) -> Graph: Parameters ---------- - name : string, the name of the graph + name : str + the name of the graph Return ------ - Graph instance. + rv1: class Graph """ @abc.abstractmethod @@ -66,5 +75,6 @@ def render(self, filename: str) -> None: Parameters ---------- - filename : string, see the definition of implemented class. + filename : str + see the definition of implemented class. """ diff --git a/python/tvm/contrib/relay_viz/_terminal.py b/python/tvm/contrib/relay_viz/terminal.py similarity index 85% rename from python/tvm/contrib/relay_viz/_terminal.py rename to python/tvm/contrib/relay_viz/terminal.py index 230bd88f5a52..82332d964907 100644 --- a/python/tvm/contrib/relay_viz/_terminal.py +++ b/python/tvm/contrib/relay_viz/terminal.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Visualize Relay IR in AST text-form""" +"""Visualize Relay IR in AST text-form.""" from collections import deque from typing import ( @@ -33,8 +33,8 @@ ) from .node_edge_gen import ( - Node, - Edge, + VizNode, + VizEdge, NodeEdgeGenerator, DefaultNodeEdgeGenerator, ) @@ -51,7 +51,7 @@ def get_node_edges( node: relay.Expr, relay_param: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], - ) -> Tuple[Union[Node, None], List[Edge]]: + ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Generate node and edges consumed by TermGraph interfaces""" if isinstance(node, relay.Call): return self._call_node(node, node_to_id) @@ -76,50 +76,50 @@ def get_node_edges( def _call_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "Call", "") - edge_info = [Edge(node_to_id[node.op], node_id)] + node_info = VizNode(node_id, "Call", "") + edge_info = [VizEdge(node_to_id[node.op], node_id)] for arg in node.args: arg_nid = node_to_id[arg] - edge_info.append(Edge(arg_nid, node_id)) + edge_info.append(VizEdge(arg_nid, node_id)) return node_info, edge_info def _let_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "Let", "(var, val, body)") + node_info = VizNode(node_id, "Let", "(var, val, body)") edge_info = [ - Edge(node_to_id[node.var], node_id), - Edge(node_to_id[node.value], node_id), - Edge(node_to_id[node.body], node_id), + VizEdge(node_to_id[node.var], node_id), + VizEdge(node_to_id[node.value], node_id), + VizEdge(node_to_id[node.body], node_id), ] return node_info, edge_info def _global_var_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "GlobalVar", node.name_hint) + node_info = VizNode(node_id, "GlobalVar", node.name_hint) edge_info = [] return node_info, edge_info def _if_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "If", "(cond, true, false)") + node_info = VizNode(node_id, "If", "(cond, true, false)") edge_info = [ - Edge(node_to_id[node.cond], node_id), - Edge(node_to_id[node.true_branch], node_id), - Edge(node_to_id[node.false_branch], node_id), + VizEdge(node_to_id[node.cond], node_id), + VizEdge(node_to_id[node.true_branch], node_id), + VizEdge(node_to_id[node.false_branch], node_id), ] return node_info, edge_info def _op_node(self, node, node_to_id): node_id = node_to_id[node] op_name = node.name - node_info = Node(node_id, op_name, "") + node_info = VizNode(node_id, op_name, "") edge_info = [] return node_info, edge_info def _function_node(self, node, node_to_id): node_id = node_to_id[node] - node_info = Node(node_id, "Func", str(node.params)) - edge_info = [Edge(node_to_id[node.body], node_id)] + node_info = VizNode(node_id, "Func", str(node.params)) + edge_info = [VizEdge(node_to_id[node.body], node_id)] return node_info, edge_info From 6fcd64c718989f62a515123d1beb298b7a55ef16 Mon Sep 17 00:00:00 2001 From: chiwwang Date: Mon, 6 Dec 2021 15:29:06 +0000 Subject: [PATCH 15/16] add doc-string according to feedback --- python/tvm/contrib/relay_viz/__init__.py | 16 +++++++--- python/tvm/contrib/relay_viz/node_edge_gen.py | 32 +++++++++++++++---- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index 1015f4781f6f..6c84a5a1fd07 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -39,7 +39,7 @@ class RelayVisualizer: Parameters ---------- - relay_mod : tvm.IRModule + relay_mod: tvm.IRModule Relay IR module. relay_param: None | Dict[str, tvm.runtime.NDArray] Relay parameter dictionary. Default `None`. @@ -104,8 +104,16 @@ def render(self, filename: str = None) -> None: self._plotter.render(filename=filename) -def get_plotter_and_generator(backend): - """Specify the Plottor and its NodeEdgeGenerator""" +def get_plotter_and_generator( + backend: Union[PlotterBackend, Tuple[Plotter, NodeEdgeGenerator]] +) -> Tuple[Plotter, NodeEdgeGenerator]: + """Specify the Plottor and its NodeEdgeGenerator + + Parameters + ---------- + backend : PlotterBackend | Tuple[Plotter, NodeEdgeGenerator] + Backend used to generate nodes/edges and render them. + """ if isinstance(backend, (tuple, list)) and len(backend) == 2: if not isinstance(backend[0], Plotter): raise ValueError(f"First element should be an instance derived from {type(Plotter)}") @@ -115,7 +123,7 @@ def get_plotter_and_generator(backend): f"Second element should be an instance derived from {type(NodeEdgeGenerator)}" ) - return backend + return tuple(backend) if backend not in PlotterBackend: raise ValueError(f"Unknown plotter backend {backend}") diff --git a/python/tvm/contrib/relay_viz/node_edge_gen.py b/python/tvm/contrib/relay_viz/node_edge_gen.py index aaeea1d9ffd9..cc60266a4c2f 100644 --- a/python/tvm/contrib/relay_viz/node_edge_gen.py +++ b/python/tvm/contrib/relay_viz/node_edge_gen.py @@ -29,7 +29,17 @@ class VizNode: - """Node carry information used by `plotter.Graph` interface.""" + """Node carry information used by `plotter.Graph` interface. + + Parameters + ---------- + node_id: int | str + Unique identifier for this node. + node_type: str + Type of this node. + node_detail: str + Any supplement for this node such as attributes. + """ def __init__(self, node_id: Union[int, str], node_type: str, node_detail: str): self._id = node_id @@ -50,7 +60,15 @@ def detail(self) -> str: class VizEdge: - """Edges for `plotter.Graph` interface.""" + """Edges for `plotter.Graph` interface. + + Parameters + ---------- + start_node: int | str + The identifier of the node starting the edge. + end_node: int | str + The identifier of the node ending the edge. + """ def __init__(self, start_node: Union[int, str], end_node: Union[int, str]): self._start_node = start_node @@ -150,7 +168,7 @@ def _var_node( def _function_node( self, node: relay.Expr, - _: Dict[str, tvm.runtime.NDArray], # relay_param + _: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Render rule for a relay function node""" @@ -170,7 +188,7 @@ def _function_node( def _call_node( self, node: relay.Expr, - _: Dict[str, tvm.runtime.NDArray], # relay_param + _: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[VizNode, None], List[VizEdge]]: """Render rule for a relay call node""" @@ -203,7 +221,7 @@ def _call_node( def _tuple_node( self, node: relay.Expr, - _: Dict[str, tvm.runtime.NDArray], # relay_param + _: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[VizNode, None], List[VizEdge]]: node_id = node_to_id[node] @@ -214,7 +232,7 @@ def _tuple_node( def _tuple_get_item_node( self, node: relay.Expr, - _: Dict[str, tvm.runtime.NDArray], # relay_param + _: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[VizNode, None], List[VizEdge]]: node_id = node_to_id[node] @@ -225,7 +243,7 @@ def _tuple_get_item_node( def _constant_node( self, node: relay.Expr, - _: Dict[str, tvm.runtime.NDArray], # relay_param + _: Dict[str, tvm.runtime.NDArray], node_to_id: Dict[relay.Expr, Union[int, str]], ) -> Tuple[Union[VizNode, None], List[VizEdge]]: node_id = node_to_id[node] From bcca757e8cb2ceeb29f5d5e780fd75238a1cf9dc Mon Sep 17 00:00:00 2001 From: chiwwang Date: Mon, 6 Dec 2021 15:55:39 +0000 Subject: [PATCH 16/16] fix lint...? --- src/target/target_kind.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 5540c35a8f7e..b467c58fa8f7 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -402,7 +402,6 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option>("devices"); - /********** Registry **********/ TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);