[Relay][Frontend] Preserve Pytorch Span Names #16171

Merged: 10 commits, Dec 8, 2023
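
For context, the new preserve_pytorch_scopes keyword is exercised roughly like this (a minimal sketch based on the test added in this PR; the Sequential model and the "model_input" name are stand-ins):

import torch
import tvm

# Stand-in model; any traced or scripted torch.nn.Module works the same way.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, kernel_size=3), torch.nn.ReLU()).eval()
example = torch.zeros((1, 3, 64, 64), dtype=torch.float32)
traced = torch.jit.trace(model, example)

# Keyword added by this PR; it defaults to False, which keeps the existing naming.
mod, params = tvm.relay.frontend.from_pytorch(
    traced,
    [("model_input", (1, 3, 64, 64))],
    preserve_pytorch_scopes=True,
)
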
119 changes: 100 additions & 19 deletions python/tvm/relay/frontend/pytorch.py
@@ -21,6 +21,8 @@
"""PT: PyTorch frontend."""
import functools
import itertools
from abc import ABC
from typing import Dict
import math
import re
import sys
@@ -137,7 +139,9 @@ def _is_int_seq(seq):
class PyTorchOpConverter:
"""A helper class for holding PyTorch op converters."""

def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
def __init__(
self, prelude, default_dtype, use_parser_friendly_name=False, preserve_pytorch_scopes=False
):
self.prelude = prelude
self.default_dtype = default_dtype
self.create_convert_map()
@@ -146,6 +150,7 @@ def __init__(self, prelude, default_dtype, use_parser_friendly_name=False):
self.op_type_dict = {} # map from op type to its presenting order
self.current_op = [] # stack for recording current processing op
self.use_parser_friendly_name = use_parser_friendly_name
self.preserve_pytorch_scopes = preserve_pytorch_scopes

# this incrementally infers the type, see the comments on the type visitor
# above.
@@ -4204,7 +4209,11 @@ def report_missing_conversion(self, op_names):
def convert_block(self, block, outputs):
"""Translate Torch "Block", used for prim::If and prim::Loop"""
ops = _get_operator_nodes(
block.nodes(), self.source_map, self.op_type_dict, self.use_parser_friendly_name
block.nodes(),
self.source_map,
self.op_type_dict,
self.use_parser_friendly_name,
self.preserve_pytorch_scopes,
)
ret_names = _get_input_names(block.returnNode())
return self.convert_operators(ops, outputs, ret_names)
@@ -4771,33 +4780,84 @@ def _get_constant(node):
return None


def _rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name):
"""Rewrite debug name of node outputs with its operator type"""
class NodeNamer(ABC):
"""Name each node and output edge in the relay graph"""

def _get_source_name(op_type):
def __init__(self, op_counter_dict: Dict[str, int]):
self.op_counter_dict = op_counter_dict

def increment_counter(self, identifier: str) -> int:
op_idx = 0
if op_type in op_type_dict:
op_idx = op_type_dict[op_type] + 1
op_type_dict[op_type] = op_idx
return "_".join([op_type, str(op_idx)])
if identifier in self.op_counter_dict:
op_idx = self.op_counter_dict[identifier] + 1
self.op_counter_dict[identifier] = op_idx
return op_idx

# get source name of operator and rename all of its outputs
def get_node_source_name(self, node) -> str:
raise NotImplementedError()

def get_node_output_name(self, node_src_name: str, index: int) -> str:
raise NotImplementedError()


class DefaultNodeKindNamer(NodeNamer):
"""
Namer that uses a default naming based on the "type"/kind of node
# e.g. node.kind(): aten::adaptive_max_pool2d
# node_src_name -> aten::adaptive_max_pool2d_x
# output_1 -> aten::adaptive_max_pool2d_x_0
# output_2 -> aten::adaptive_max_pool2d_x_1
"""

def get_node_source_name(self, node) -> str:
op_idx = self.increment_counter(node.kind())
return "_".join([node.kind(), str(op_idx)])

def get_node_output_name(self, node_src_name: str, index: int) -> str:
return "_".join([node_src_name, str(index)])


class PytorchScopePreservingNamer(NodeNamer):
"""
Namer that uses the Pytorch scope to name nodes.
eg. node could be called "bert.encoder.layer.11.output.dense"
"""

def get_node_source_name(self, node) -> str:
# This works per the scope naming in Pytorch 2.0 and beyond.
scope_name_parts = node.scopeName().split("/")
imp_parts = [part.split("::")[-1] for part in scope_name_parts]
node_src_name = ".".join([part for part in imp_parts if part])
return node_src_name

def get_node_output_name(self, node_src_name: str, index: int) -> str:
op_idx = self.increment_counter(node_src_name)
return "_".join([node_src_name, str(op_idx), str(index)])


def _rename_outputs(
node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
):
"""Rewrite debug name of node outputs with its operator type"""
namer = (
PytorchScopePreservingNamer(op_type_dict)
if preserve_pytorch_scopes
else DefaultNodeKindNamer(op_type_dict)
)
# get source name of operator and rename all of its outputs
if node.kind() != "prim::GetAttr":
node_src_name = _get_source_name(node.kind())
node_src_name = namer.get_node_source_name(node)
for index, output in enumerate(node.outputs()):
output.setDebugName("_".join([node_src_name, str(index)]))
name = namer.get_node_output_name(node_src_name, index)
output.setDebugName(name)
# update source map
# if use_parser_friendly_name is True: e.g. prim::Constant_0 -> prim__Constant_0
if use_parser_friendly_name:
node_src_name = re.sub(r":|\.", "_", node_src_name)
source_map[node] = node_src_name


def _debug_rename(graph, use_parser_friendly_name):
def _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes):
"""Returns map between node and source name"""
source_map, op_type_dict = {}, {}
prim_with_blocks = ["prim::If", "prim::Loop"]
@@ -4809,13 +4869,21 @@ def _traverse_graph(nodes):
if node.kind() in prim_with_blocks:
for block in node.blocks():
_traverse_graph(block.nodes())
_rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
_rename_outputs(
node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
)

_traverse_graph(graph.nodes())
return source_map


def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_friendly_name=False):
def _get_operator_nodes(
nodes,
source_map=None,
op_type_dict=None,
use_parser_friendly_name=False,
preserve_pytorch_scopes=False,
Review comment (Member): I think it can be set to True by default, can be done as a follow up.
):
"""Returns torch IR nodes that need conversion to Relay"""
ops, should_rename_graph = [], all([source_map, op_type_dict]) is not None

@@ -4825,7 +4893,9 @@ def _get_operator_nodes(nodes, source_map=None, op_type_dict=None, use_parser_fr
continue

if should_rename_graph:
_rename_outputs(node, source_map, op_type_dict, use_parser_friendly_name)
_rename_outputs(
node, source_map, op_type_dict, use_parser_friendly_name, preserve_pytorch_scopes
)

if node.outputsSize() > 1:
node_name = "_".join(_get_output_names(node))
@@ -5080,6 +5150,7 @@ def from_pytorch(
use_parser_friendly_name=False,
keep_quantized_weight=False,
export_renamed_c_graph_path=None,
preserve_pytorch_scopes=False,
):
"""Load PyTorch model in the form of a scripted PyTorch model and convert into relay.
The companion parameters will be handled automatically.
@@ -5127,6 +5198,10 @@ def from_pytorch(
During the conversion, variable names in torch._C.Graph will be assigned based on their op
types. The exported text file can be the reference to spans.

preserve_pytorch_scopes : bool
When naming the nodes in the Relay graph, use the "scope name" from the Pytorch model.
If false, a default namer is used that does not preserve the Pytorch scope names.

Returns
-------
mod : tvm.IRModule
@@ -5141,7 +5216,9 @@ def from_pytorch(
prelude = Prelude(mod)
enable_lower_all_tuples = True

converter = PyTorchOpConverter(prelude, default_dtype, use_parser_friendly_name)
converter = PyTorchOpConverter(
prelude, default_dtype, use_parser_friendly_name, preserve_pytorch_scopes
)

graph = script_module.graph.copy()

@@ -5173,7 +5250,7 @@ def from_pytorch(

# rename _C.Graph here for constructing meaningful source name of graph nodes
# by doing so, we could Use source_map as the reference to rename model parameters
source_map = _debug_rename(graph, use_parser_friendly_name)
source_map = _debug_rename(graph, use_parser_friendly_name, preserve_pytorch_scopes)
param_vars, tensors, packed_param_map, param_debug_name_map = convert_params(
graph, params, source_map, use_parser_friendly_name
)
@@ -5201,7 +5278,11 @@ def from_pytorch(
converter.update_convert_map(qnn_torch.convert_map)

operator_nodes = _get_operator_nodes(
graph.nodes(), converter.source_map, converter.op_type_dict, use_parser_friendly_name
graph.nodes(),
converter.source_map,
converter.op_type_dict,
use_parser_friendly_name,
preserve_pytorch_scopes,
)
ret_name = _get_input_names(graph.return_node())
outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
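
For reference, the scope-based source name is derived by splitting node.scopeName() on "/" and keeping the attribute part after each "::". A standalone sketch of that logic, using a made-up scope string (the exact scopeName() format depends on the PyTorch version; the converter targets PyTorch 2.0 and later per the docstring above):

def scope_to_source_name(scope_name: str) -> str:
    # Mirrors PytorchScopePreservingNamer.get_node_source_name: keep the part
    # after "::" in each "/"-separated segment and drop empty pieces.
    parts = [segment.split("::")[-1] for segment in scope_name.split("/")]
    return ".".join(part for part in parts if part)

# Hypothetical scope string in the "ClassName::attribute" per-segment form.
scope = "SimpleTwoConvModule::/NestedConvModule::image_block1/Conv2d::conv"
print(scope_to_source_name(scope))  # -> image_block1.conv
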
106 changes: 106 additions & 0 deletions tests/python/frontend/pytorch/test_span_naming.py
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
# pylint: disable=missing-function-docstring, redefined-builtin, use-implicit-booleaness-not-comparison
"""Tests to ensure span names are correctly populated when importing Pytorch"""
from torch import nn
import torch
import tvm


class NestedConvModule(nn.Module):
"""Module that performs Conv2d and relu activation"""

def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.relu = nn.ReLU()

def forward(self, x):
x = self.relu(self.conv(x))
return x


class NestedFinalModule(nn.Module):
"""Simple module that adds 2 inputs"""

def forward(self, x, y):
return x + y


class SimpleTwoConvModule(nn.Module):
"""
ML model that performs 2 convolutions and adds them together.
All operations are inside nested modules to make scope names interesting.
"""

def __init__(self):
super().__init__()
# First convolutional module
self.image_block1 = NestedConvModule(in_channels=3, out_channels=64)
# Second convolutional module
self.image_block2 = NestedConvModule(in_channels=64, out_channels=64)
self.final_block = NestedFinalModule()

def forward(self, x):
# Forward pass through the first convolutional module
x1 = self.image_block1(x)
# Forward pass through the second convolutional module
x2 = self.image_block2(x1)
# Add the outputs of the two convolutional modules
return self.final_block(x1, x2)


def test_pytorch_scope_based_span_names():
model = SimpleTwoConvModule()
sample_input = torch.zeros((1, 3, 64, 64), dtype=torch.float32)
with torch.no_grad():
traced_torch_model = torch.jit.trace(model, sample_input)
import_input = [("model_input", (1, 3, 64, 64))]
relay_model_ir, relay_model_params = tvm.relay.frontend.from_pytorch(
traced_torch_model, import_input, preserve_pytorch_scopes=True
)
# If specified, we are preserving the pytorch named spans
for block in [1, 2]:
for key in ["weight", "bias"]:
assert f"image_block{block}.conv.{key}" in relay_model_params.keys()
# Manually check all span names since asserting structural equality is not sufficient
current_call = relay_model_ir["main"].body
assert current_call.op.name == "add"
assert current_call.span is not None and current_call.span.source_name.name == "final_block"
current_call = current_call.args[1]
for block in [2, 1]:
assert current_call.op.name == "nn.relu"
assert (
current_call.span is not None
and current_call.span.source_name.name == f"image_block{block}.relu"
)
current_call = current_call.args[0]
assert current_call.op.name == "nn.bias_add"
assert (
current_call.span is not None
and current_call.span.source_name.name == f"image_block{block}.conv"
)
current_call = current_call.args[0]
assert current_call.op.name == "nn.conv2d"
assert (
current_call.span is not None
and current_call.span.source_name.name == f"image_block{block}.conv"
)
current_call = current_call.args[0]
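
As a usage note, the effect of the flag can be checked by walking the imported Relay calls and printing each span's source name; the test above expects scope-based names such as image_block1.conv, while the default namer yields kind-based names of the form aten::<op>_<idx>, as described in the DefaultNodeKindNamer docstring. A sketch (assuming relay_model_ir comes from a from_pytorch call with preserve_pytorch_scopes=True):

import tvm
from tvm import relay

def print_span_names(mod):
    """Print the span source name recorded on every call node in main."""
    def visit(expr):
        if isinstance(expr, relay.Call) and expr.span is not None:
            op_name = expr.op.name if isinstance(expr.op, tvm.ir.Op) else "<fn>"
            print(op_name, "->", expr.span.source_name.name)
    relay.analysis.post_order_visit(mod["main"].body, visit)

# e.g. print_span_names(relay_model_ir) prints entries like:
#   nn.conv2d -> image_block1.conv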