diff --git a/python/tvm/contrib/relay_viz/__init__.py b/python/tvm/contrib/relay_viz/__init__.py index 78bbe0e594932..32814b577d0d5 100644 --- a/python/tvm/contrib/relay_viz/__init__.py +++ b/python/tvm/contrib/relay_viz/__init__.py @@ -61,11 +61,11 @@ def __init__( graph_names = [] # If we have main function, put it to the first. # Then main function can be shown on the top. - for gv in global_vars: - if gv.name_hint == "main": - graph_names.insert(0, gv.name_hint) + for gv_node in global_vars: + if gv_node.name_hint == "main": + graph_names.insert(0, gv_node.name_hint) else: - graph_names.append(gv.name_hint) + graph_names.append(gv_node.name_hint) node_to_id = {} # callback to generate an unique string-ID for nodes. diff --git a/python/tvm/contrib/relay_viz/interface.py b/python/tvm/contrib/relay_viz/interface.py index 5802574aa7b41..95beb4f32598c 100644 --- a/python/tvm/contrib/relay_viz/interface.py +++ b/python/tvm/contrib/relay_viz/interface.py @@ -158,25 +158,25 @@ def get_node_edges( ) -> Tuple[Union[VizNode, None], List[VizEdge]]: if isinstance(node, relay.Function): return self._function(node, node_to_id) - elif isinstance(node, relay.expr.Call): + if isinstance(node, relay.expr.Call): return self._call(node, node_to_id) - elif isinstance(node, relay.expr.Var): + if isinstance(node, relay.expr.Var): return self._var(node, relay_param, node_to_id) - elif isinstance(node, relay.expr.Tuple): + if isinstance(node, relay.expr.Tuple): return self._tuple(node, node_to_id) - elif isinstance(node, relay.expr.TupleGetItem): + if isinstance(node, relay.expr.TupleGetItem): return self._tuple_get_item(node, node_to_id) - elif isinstance(node, relay.expr.Constant): + if isinstance(node, relay.expr.Constant): return self._constant(node, node_to_id) # GlobalVar possibly mean another global relay function, # which is expected to in "Graph" level, not in "Node" level. - elif isinstance(node, (relay.expr.GlobalVar, tvm.ir.Op)): + if isinstance(node, (relay.expr.GlobalVar, tvm.ir.Op)): return None, [] - else: - viz_node = VizNode( - node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}" - ) - viz_edges = [] + + viz_node = VizNode( + node_to_id[node], UNKNOWN_TYPE, f"don't know how to parse {type(node)}" + ) + viz_edges = [] return viz_node, viz_edges def _var( diff --git a/python/tvm/contrib/relay_viz/terminal.py b/python/tvm/contrib/relay_viz/terminal.py index 522e382ebfdf4..7b72d9da43333 100644 --- a/python/tvm/contrib/relay_viz/terminal.py +++ b/python/tvm/contrib/relay_viz/terminal.py @@ -131,7 +131,6 @@ class TermGraph(VizGraph): """ def __init__(self, name: str): - # node_id: [ connected node_id] self._name = name # A graph in adjacency list form. # The key is source node, and the value is a list of destination nodes. @@ -171,10 +170,10 @@ def edge(self, viz_edge: VizEdge) -> None: A `VizEdge` instance. """ # Take CallNode as an example, instead of "arguments point to CallNode", - # we want "CallNode points to arguments" here. + # we want "CallNode points to arguments" in ast-dump form. # # The direction of edge is typically controlled by the implemented VizParser. - # We need reversion here simply because we leverage default parser implementation. + # Reverse start/end here simply because we leverage default parser implementation. if viz_edge.end in self._graph: self._graph[viz_edge.end].append(viz_edge.start) else: