From 7f8119eddb1c6665078a92cb84f2736f4b091aa0 Mon Sep 17 00:00:00 2001
From: Yong WU <55wuyong@163.com>
Date: Mon, 11 Feb 2019 14:32:32 -0800
Subject: [PATCH] [relay][frontend] TensorFlow saved model support

---
 python/tvm/relay/frontend/__init__.py       |   1 +
 python/tvm/relay/frontend/tensorflow.py     |  59 +++++--
 .../tvm/relay/frontend/tensorflow_parser.py | 153 ++++++++++++++++++
 3 files changed, 198 insertions(+), 15 deletions(-)
 create mode 100644 python/tvm/relay/frontend/tensorflow_parser.py

diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py
index dee3999ad3f10..3d2c512556a09 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/python/tvm/relay/frontend/__init__.py
@@ -14,3 +14,4 @@
 from .coreml import from_coreml
 from .caffe2 import from_caffe2
 from .tensorflow import from_tensorflow
+from .tensorflow_parser import TFParser

diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 82b4c5b9ca370..d1e5a73ad7d31 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -4,6 +4,7 @@
 from __future__ import print_function

 import logging
+import warnings

 # Numpy support
 import numpy as np
@@ -411,7 +412,7 @@ def _impl(inputs, attr, params):
 def _decode_image():
     def _impl(inputs, attr, params):
         # Image decode wrapper: Expecting user to feed decoded input to next layer drop this layer.
-        print("DecodeJpeg: It's a pass through, please handle preprocessing before input")
+        warnings.warn("DecodeJpeg: It's a pass through, please handle preprocessing before input")
         return inputs[0]
     return _impl
@@ -1172,6 +1173,7 @@ class GraphProto(object):
     def __init__(self):
         self._nodes = {}
         self._params = {}
+        self._input_shapes = {}
         self._output_shapes = {}
         self._num_param = 0
         self._num_rnn_layer = False
@@ -1223,36 +1225,55 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):
             raise NotImplementedError( \
                 "The following operators are not implemented: {}".format(missing_operators))

+        for node in graph.node:
+            if node.op == 'Placeholder':
+                if shape and node.name in shape:
+                    self._input_shapes[node.name] = list(shape[node.name])
+                    continue
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(node.attr['shape'].shape)
+                for idx, dim in enumerate(self._input_shapes[node.name]):
+                    if dim < 0:
+                        self._input_shapes[node.name][idx] = 1
+                        warnings.warn("Use 1 instead of -1 in shape of operator %s."
+                                      % node.name)
+
+            # Ignore user's input shape for non-placeholder nodes
+            elif node.op == 'Const':
+                tensor_value = node.attr['value'].tensor
+                self._input_shapes[node.name] = \
+                    tensor_util.TensorShapeProtoToList(tensor_value.tensor_shape)
+                if shape and node.name in shape:
+                    warnings.warn("Ignore the passed shape. Shape in graphdef "
+                                  "will be used for operator %s." % node.name)
+
         # Parse the nodes to re-create TF graph using Relay operators.
         for node in graph.node:
-            # Tensorflow doesn't have seperate list for params extraction.
+            # Tensorflow doesn't have separate list for params extraction.
             # Operator name 'Const' is treated as a parameter to build params dict.
             input_shapes = {}
             attr = self._parse_attr(node.attr)

-            #Variable converted to Const will not have only value attr
+            # Variable converted to Const will not have only value attr
             if 'value' in attr and node.op == 'Const':
-                tensor_value = attr['value']
-                self._output_shapes[node.name] = \
-                    [tensor_util.TensorShapeProtoToList( \
-                        tensor_value.tensor_shape)]
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
+            elif shape and node.name in shape:
+                # Give priority to user argument.
+                self._output_shapes[node.name] = [shape[node.name]]
             elif '_output_shapes' in attr:
                 self._output_shapes[node.name] = \
                     [tensor_util.TensorShapeProtoToList(tshape) \
                     for tshape in attr['_output_shapes']]
-            elif shape:
+            else:
                 # Keep the list indexable to avoid key error.
                 # Actual value will be filled after node creation.
                 self._output_shapes[node.name] = [None]
-            else:
-                raise NotImplementedError( \
-                    "Please freeze the graph with add_shapes=True")

             if node.op == "Placeholder":
-                self._output_shapes[node.name] = [shape[node.name]]
+                self._output_shapes[node.name] = [self._input_shapes[node.name]]
                 self._nodes[node.name] = [_expr.var(node.name,
-                                                    shape=self._output_shapes[node.name][0],
+                                                    shape=self._input_shapes[node.name],
                                                     dtype=attr['dtype'].name)]

             elif node.op == "Const":
@@ -1268,7 +1289,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

             else:
                 # Pass the parsed shapes instead
-                attr["_output_shapes"] = self._output_shapes[node.name]
+                attr["_output_shapes"] = output_shapes = self._output_shapes[node.name]

                 # Pass the node name too in attr
                 attr["_node_name"] = node.name
@@ -1295,7 +1316,7 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

                 op = self._convert_operator(node.op, inputs, attr, graph)

-                # Check is op is converted to param
+                # Check if op is converted to param
                 if isinstance(op, np.ndarray):
                     self._params[node.name] = tvm.nd.array(op)
                     op = [_expr.var(node.name,
@@ -1311,6 +1332,14 @@ def from_tensorflow(self, graph, layout="NHWC", shape=None, outputs=None):

                 self._nodes[node.name] = op

+                # Infer shapes even without specifying "add_shapes=True"
+                if output_shapes == [None]:
+                    out_type = ir_pass.infer_type(self._nodes[node.name][0])
+                    self._output_shapes[node.name] = [get_const_tuple(out_type.checked_type.shape)]
+
+                if self._output_shapes[node.name] and shape and node.name in shape:
+                    assert self._output_shapes[node.name] == list(shape[node.name])
+
         # Infer shapes if passed explicitely
         node_output = self._nodes[node.name]
         out_type = ir_pass.infer_type(node_output[0])
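Taken together, the `tensorflow.py` changes derive placeholder shapes from the GraphDef, let a user-supplied `shape` dict take priority, replace unknown (-1) placeholder dimensions with 1 (with a warning), and infer output shapes even when the graph was frozen without `add_shapes=True`. A minimal sketch of how the reworked shape handling is exercised; the file path and placeholder name are illustrative assumptions, not part of the patch:

```python
import tensorflow as tf
from tvm import relay

# Load a frozen GraphDef (hypothetical path).
graph_def = tf.GraphDef()
with open("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# The passed shape dict now takes priority for placeholders, and graphs
# frozen without add_shapes=True no longer raise NotImplementedError.
sym, params = relay.frontend.from_tensorflow(
    graph_def, layout="NHWC", shape={"input": (1, 224, 224, 3)})
```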
diff --git a/python/tvm/relay/frontend/tensorflow_parser.py b/python/tvm/relay/frontend/tensorflow_parser.py
new file mode 100644
index 0000000000000..9b745c9d02c9a
--- /dev/null
+++ b/python/tvm/relay/frontend/tensorflow_parser.py
@@ -0,0 +1,153 @@
+"""TF: Tensorflow parser"""
+from __future__ import absolute_import as _abs
+from __future__ import print_function
+import os
+from tensorflow.core.framework import graph_pb2
+from tvm.contrib import util
+
+
+class TFParser(object):
+    """A wrapper to handle TensorFlow model parsing. TensorFlow is required.
+
+    ```
+    parser = TFParser(model_dir)
+    graph = parser.parse()
+    ```
+
+    Parameters
+    ----------
+    model_dir : str
+        Path to a TensorFlow frozen pb file, or to a directory that contains
+        a saved model or checkpoints.
+    """
+
+    def __init__(self, model_dir):
+        self._tmp_dir = util.tempdir()
+        self._model_dir = model_dir
+        self._graph = graph_pb2.GraphDef()
+
+    def _set_graph(self, graph):
+        """Set Graph"""
+        self._graph = graph
+
+    def _get_graph(self):
+        """Get Graph"""
+        return self._graph
+
+    def _load_pb_file(self):
+        """Load a single pb file"""
+        graph = self._get_graph()
+        with open(self._model_dir, "rb") as f:
+            graph.ParseFromString(f.read())
+        return graph
+
+    def _get_tag_set(self):
+        """Return the tag set of the saved model; multiple MetaGraphs are not supported"""
+        try:
+            from tensorflow.contrib.saved_model.python.saved_model import reader
+        except ImportError:
+            raise ImportError(
+                "InputConfiguration: Unable to import saved_model.reader which is "
+                "required to get tag set from saved model.")
+        tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
+        return tag_sets[0]
+
+    def _get_output_names(self):
+        """Return the concatenated output names"""
+        try:
+            import tensorflow as tf
+        except ImportError:
+            raise ImportError(
+                "InputConfiguration: Unable to import tensorflow which is "
+                "required to restore from saved model.")
+        tags = self._get_tag_set()
+        with tf.Session() as sess:
+            meta_graph_def = tf.saved_model.loader.load(sess,
+                                                        tags,
+                                                        self._model_dir)
+            output_names = set()
+            for k in meta_graph_def.signature_def.keys():
+                outputs_tensor_info = meta_graph_def.signature_def[k].outputs
+                for output_tensor in outputs_tensor_info.values():
+                    output_names.add(output_tensor.name)
+            output_names = [i.replace(":0", "") for i in output_names]
+            return ",".join(output_names)
+
+    def _load_saved_model(self):
+        """Load the tensorflow saved model."""
+        try:
+            from tensorflow.python.tools import freeze_graph
+            from tensorflow.python.framework import ops
+            from tensorflow.python.framework import graph_util
+        except ImportError:
+            raise ImportError(
+                "InputConfiguration: Unable to import tensorflow which is "
+                "required to restore from saved model.")
+
+        saved_model_dir = self._model_dir
+        output_graph_filename = self._tmp_dir.relpath("tf_frozen_model.pb")
+        input_saved_model_dir = saved_model_dir
+        output_node_names = self._get_output_names()
+
+        input_binary = False
+        input_saver_def_path = False
+        restore_op_name = None
+        filename_tensor_name = None
+        clear_devices = True
+        input_meta_graph = False
+        checkpoint_path = None
+        input_graph_filename = None
+        saved_model_tags = ",".join(self._get_tag_set())
+
+        freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
+                                  input_binary, checkpoint_path, output_node_names,
+                                  restore_op_name, filename_tensor_name,
+                                  output_graph_filename, clear_devices, "", "", "",
+                                  input_meta_graph, input_saved_model_dir,
+                                  saved_model_tags)
+
+        with ops.Graph().as_default():
+            output_graph_def = graph_pb2.GraphDef()
+            with open(output_graph_filename, "rb") as f:
+                output_graph_def.ParseFromString(f.read())
+            output_graph_def = graph_util.remove_training_nodes(output_graph_def)
+            return output_graph_def
+
+    def _load_ckpt(self):
+        """TODO: Load checkpoint model."""
+        raise RuntimeError("InputConfiguration: Loading tf checkpoint model is "
+                           "not supported yet.")
+
+    def parse(self):
+        """Parse tensorflow models: checkpoints, saved models, and a single pb
+        file.
+        """
+        graph = None
+
+        if os.path.isdir(self._model_dir):
+            ckpt = os.path.join(self._model_dir, "checkpoint")
+            if not os.path.isfile(ckpt):
+                if not os.path.isdir(os.path.join(self._model_dir, "variables")):
+                    raise RuntimeError("InputConfiguration: Invalid model path.")
+                graph = self._load_saved_model()
+            else:
+                graph = self._load_ckpt()
+        elif os.path.isfile(self._model_dir):
+            # Only a .pb or .pbtxt suffix is valid.
+            if self._model_dir.endswith(".pb") or \
+               self._model_dir.endswith(".pbtxt"):
+                cur_dir = os.path.dirname(self._model_dir)
+            else:
+                raise RuntimeError("InputConfiguration: Invalid model format.")
+
+            # It is a saved model if a `variables` directory is present in the
+            # same directory as the pb or pbtxt file.
+            if os.path.isdir(os.path.join(cur_dir, "variables")):
+                self._model_dir = cur_dir
+                graph = self._load_saved_model()
+            else:
+                graph = self._load_pb_file()
+        else:
+            raise RuntimeError("InputConfiguration: Unrecognized model "
+                               "file or path.")
+
+        self._set_graph(graph)
+        return graph
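With the parser exported from `tvm.relay.frontend`, a saved model can be converted end to end. A minimal usage sketch, assuming `./saved_model_dir` is a TensorFlow saved model directory and the model has a placeholder named "input" (both are illustrative assumptions, not part of the patch):

```python
from tvm import relay
from tvm.relay.frontend import TFParser

# Accepts a frozen .pb file or a saved model directory; saved models are
# frozen internally via freeze_graph. Checkpoint loading is still a TODO.
parser = TFParser("./saved_model_dir")  # hypothetical path
graph_def = parser.parse()

sym, params = relay.frontend.from_tensorflow(
    graph_def, shape={"input": (1, 224, 224, 3)})
```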