Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Frontend] Span Filling ONNX #13767

Merged
merged 1 commit into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 124 additions & 9 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
shape_of,
try_resolve_var_to_const,
unbind,
set_span,
)

__all__ = ["from_onnx"]
Expand Down Expand Up @@ -556,6 +557,37 @@ def layer_norm(x, eps, gamma, beta):
return output


def get_source_name(node, type_dict):
"""A helper function to get source information of onnx nodes."""
if node.name:
return node.name
else:
op_idx = 0
if node.op_type in type_dict:
op_idx = type_dict[node.op_type] + 1
type_dict[node.op_type] = op_idx
# rewrite name property in case any revisiting occurs to current node
node.name = "{}_{}".format(node.op_type, str(op_idx))
return node.name


def get_source_name_from_parameter(expr, name_sep="."):
"""A helper function to get source information of graph node from parameter."""
if expr.span:
source_name = expr.span.source_name.name
# discard variable/parameter name to get span of op node
# e.g. conv2d.w -> conv2d
if isinstance(expr, _expr.Var):
postfix = f"{name_sep}{expr.name_hint}"
source_name = source_name[: -len(postfix)]
return source_name
return None


def make_parameter_span(source_name_list, name_sep="."):
return name_sep.join(source_name_list)


class OnnxOpConverter(object):
"""A helper class for holding onnx op converters."""

Expand Down Expand Up @@ -2712,10 +2744,13 @@ def _impl_v9(cls, inputs, attr, params):
else:
dtype = get_type(dtype)

in_shape = _op.shape_of(inputs[0])
node_source_name = get_source_name_from_parameter(inputs[0])
# since there exists multi-comsumer for the same expression
# invoke set_span here to prevent expr-rewritten in span-filling stage
in_shape = set_span(_op.shape_of(inputs[0]), node_source_name)
zeros = _op.zeros(in_shape, dtype)

dim = _op.take(in_shape, _op.const(0))
dim = set_span(_op.take(in_shape, _op.const(0)), node_source_name)

indices = _op.arange(_op.const(0), dim, dtype="int32")
ones = _op.full(_op.const(1), _op.reshape(dim, (1,)), dtype=dtype)
Expand Down Expand Up @@ -4128,7 +4163,10 @@ def cond_fn(*loop_inputs):
# Get the current graph proto and create a clone for the subgraph
graph_scope = GraphProto.current
subgraph_scope = GraphProto(
graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
graph_scope._shape,
graph_scope._dtype,
graph_scope._freeze_params,
graph_scope._op_type_dict,
)
# Load nodes from outer graph into inner graph.
subgraph_scope._nodes = graph_scope._nodes.copy()
Expand Down Expand Up @@ -4159,6 +4197,11 @@ def get_var(name, val, scan=False):
]
loop_vars += [get_var(body.input[i + 2].name, v) for i, v in enumerate(loop_deps)]
loop_var_names = [v.name_hint for v in loop_vars]
# get span information of loop body
body_source_name = get_source_name(body, subgraph_scope._op_type_dict)
# set span to inputs of loop body
for i, v in enumerate(loop_vars):
loop_vars[i] = set_span(v, make_parameter_span([v.name_hint, body_source_name]))

num_scan_outputs = len(body.output) - (1 + num_deps)

Expand Down Expand Up @@ -4287,9 +4330,19 @@ def _impl_v1(cls, inputs, attr, params):

# Create graph converters for both branches.
graph_scope = GraphProto.current
then_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params)
then_graph = GraphProto(
graph_scope._shape,
graph_scope._dtype,
graph_scope._freeze_params,
graph_scope._op_type_dict,
)
then_graph._nodes = graph_scope._nodes.copy()
else_graph = GraphProto(graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params)
else_graph = GraphProto(
graph_scope._shape,
graph_scope._dtype,
graph_scope._freeze_params,
graph_scope._op_type_dict,
)
else_graph._nodes = graph_scope._nodes.copy()

# Convert each branch to a relay expression.
Expand Down Expand Up @@ -4386,7 +4439,10 @@ def cond_fn(*loop_inputs):
# Get the current graph proto and create a clone for the subgraph
graph_scope = GraphProto.current
subgraph_scope = GraphProto(
graph_scope._shape, graph_scope._dtype, graph_scope._freeze_params
graph_scope._shape,
graph_scope._dtype,
graph_scope._freeze_params,
graph_scope._op_type_dict,
)
# Load nodes from outer graph into inner graph.
subgraph_scope._nodes = graph_scope._nodes.copy()
Expand Down Expand Up @@ -4440,6 +4496,12 @@ def get_var(name, val, scan=False):
loop_vars += [
get_var(body.input[i].name, v) for i, v in enumerate(inputs) if i < num_state_inputs
]
# get span information of scan body
body_source_name = get_source_name(body, subgraph_scope._op_type_dict)
# set span to inputs of scan body
for i, v in enumerate(loop_vars):
loop_vars[i] = set_span(v, make_parameter_span([v.name_hint, body_source_name]))

loop_vars += scan_output_vars
body_input_var_names = ["iter"] + [body.input[i].name for i in range(len(body.input))]

Expand Down Expand Up @@ -6197,11 +6259,16 @@ class GraphProto:
at compile time and helps in making models static if certain inputs represent
attributes relay would traditionally consider compile-time constants.

op_type_dict: Dict[str, int]
Dictionary for span filling usage. If the name property of op was not set
op_type_dict will provide an alternative by combining literal op type with
its presenting order

"""

current = None

def __init__(self, shape, dtype, freeze_params=False):
def __init__(self, shape, dtype, freeze_params=False, op_type_dict=None):
self._nodes = {}
self._params = {}
self._inputs = {}
Expand All @@ -6213,6 +6280,7 @@ def __init__(self, shape, dtype, freeze_params=False):
self._dtype = dtype
self.opset = None
self._freeze_params = freeze_params
self._op_type_dict = op_type_dict

def __enter__(self):
self._old_manager = GraphProto.current
Expand Down Expand Up @@ -6365,6 +6433,9 @@ def _construct_nodes(self, graph):
for node in graph.node:
op_name = node.op_type
attr = self._parse_attr(node.attribute)
# Fill in span of inputs
node_source_name = get_source_name(node, self._op_type_dict)
self._set_parameter_span(node, node_source_name)
# Create and populate input list.
inputs = onnx_input()
for i in node.input:
Expand All @@ -6389,6 +6460,8 @@ def _construct_nodes(self, graph):
else:
op = _expr.TupleWrapper(fold_constant(op.astuple()), len(op))

op = set_span(op, node_source_name)

if outputs_num > 1:
# ONNX supports optional outputs for some nodes.
# This block searches for missing outputs in the ONNX graph
Expand Down Expand Up @@ -6427,6 +6500,19 @@ def _construct_nodes(self, graph):
for k, i in zip(list(node_output), range(len(node_output))):
self._nodes[k] = op[i]

def _set_parameter_span(self, node, node_source_name):
for i in node.input:
if i != "":
name = self._renames.get(i, i)
expr = self._nodes.get(name)
# relay.Var -> inputs / params
# relay.Constant -> freezed params / built-in constants
if isinstance(expr, (relay.Var, relay.Constant)):
expr_with_span = set_span(expr, make_parameter_span([node_source_name, name]))
self._nodes[name] = expr_with_span
if name in self._inputs:
self._inputs[name] = expr_with_span

def _parse_value_proto(self, value_proto):
"""Parse ValueProto or raw str."""
try:
Expand Down Expand Up @@ -6506,8 +6592,28 @@ def _fix_outputs(self, op_name, outputs):
return outputs


def export_model(location, graph):
"""Convert the graph to an onnx model and export it to the location."""
import datetime
import os

from onnx import save, helper

if not os.path.exists(location):
os.makedirs(location)
time_stamp = datetime.datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
model = helper.make_model(graph)
save(model, os.path.join(location, "tvm_exported_model_{}.onnx".format(time_stamp)))


def from_onnx(
model, shape=None, dtype="float32", opset=None, freeze_params=True, convert_config=None
model,
shape=None,
dtype="float32",
opset=None,
freeze_params=True,
convert_config=None,
export_node_renamed_model_path=None,
):
"""Convert a ONNX model into an equivalent Relay Function.

Expand Down Expand Up @@ -6553,6 +6659,12 @@ def from_onnx(
True to convert qualified onnx `matmul` to `nn.batch_matmul` strict to NT format
(transpose_a=False, transpose_b=True).

export_node_renamed_model_path : str, optional
Export the node renamed onnx model to the path.
Some models do not contain names in their nodes. During the conversion, if names of nodes
are empty, new names will be assigned based on their op types. The exported model can be the
reference to spans.

Returns
-------
mod : tvm.IRModule
Expand All @@ -6577,7 +6689,7 @@ def from_onnx(
warnings.warn(str(e))
except ImportError:
pass
g = GraphProto(shape, dtype, freeze_params)
g = GraphProto(shape, dtype, freeze_params, op_type_dict={})
graph = model.graph

try:
Expand Down Expand Up @@ -6607,6 +6719,9 @@ def from_onnx(
with g:
mod, params = g.from_onnx(graph, opset)

if export_node_renamed_model_path:
export_model(export_node_renamed_model_path, graph)

if freeze_params:
mod = relay.transform.DynamicToStatic()(mod)

Expand Down
Loading